如何訓練模型?(2)——梯度下降法
梯度下降法
上一節講到,我們使用最小二乘法得到的閉式解去計算線性回歸模型的損失函數——均方根誤差(RMSE)的最小值,也就是所謂的訓練模型,從而得到模型最優參數。但是最小二乘法如上節寫到具有諸多局限性,這一節將介紹另外一種方法——梯度下降法。
- 為什麼梯度下降法可以用來訓練模型?
- 批量梯度下降法(BGD)
- 隨機梯度下降法(SGD)
- 小批量梯度下降法(MBGD)
- 總結和思考
為什麼梯度下降法可以用來訓練模型?
我們訓練模型就是為了最小化損失函數,也就是一個求最值的過程。根據微積分的知識我們知道,梯度向量從幾何意義上講,就是函數增加最快的地方。比如函數 ,對 求偏導數就得到其梯度向量 ,或者表示為 或 ,在點 處,沿著其梯度向量方向 即為函數 增加最快的方向,換句話說,在這個方向上可以更快的找到函數的最大值,相對地,如果沿著梯度的反方向 就更快找到函數最小值。
但是值得注意的是,用梯度下降法找到的不一定是最小值,可能是極小值也可能是一個鞍點。因為梯度下降是尋找的是一個不動點(fixed point)!只有你的函數是凸函數的時候,極小值才為最小值。而我們現在討論的均方根誤差函數就是一個凸函數。
下面這張圖就很直觀的展示梯度下降的過程,假象我們在一座大山上的某處位置,由於我們不知道怎麼下山,於是決定走一步算一步,也就是在每走到一個位置的時候,求解當前位置的梯度,沿著梯度的負方向,也就是當前最陡峭的位置向下走一步,然後繼續求解當前位置梯度,向這一步所在位置沿著最陡峭最易下山的位置走一步。這樣一步步的走下去,一直走到覺得我們已經到了山腳。當然這樣走下去,有可能我們不能走到山腳,而是到了某一個局部的山峰低處。
批量梯度下降法(BGD)
接著上一節說到,一般線性回歸函數的假設函數為
對應的損失函數(均方誤差函數)為
其中的 是為了計算結果簡潔加上去的,不影響最後結果。對損失函數求偏導數得到
由此得到對所有據點求偏導數,累加為
批量梯度下降法,是梯度下降法最常用的形式,具體做法也就是在更新參數時使用所有的樣本來進行更新
其中 為步長,也就是每次下降的跨度,這個參數的設定很重要,步長太大,會導致迭代過快,甚至有可能錯過最優解。步長太小,迭代速度太慢,很長時間演算法都不能結束。所以演算法的步長需要多次運行後才能得到一個較為優的值。
其Python實現的代碼為
從以上介紹可以看出BGD可以得到全局最優解,但每一次迭代都會用到訓練集中所有的數據,在做數據分析的時候,往往數據量比較大,即m很大,其時間開銷可想而知。為了解決這個問題,就引入了隨機梯度下降法(SGD)。
隨機梯度下降法(SGD)
SGD和BGD不同點在於SGD每次迭代的時候並不是將所有數據點累加起來,而是隨機選取一個點來迭代更新,其更新函數為
其Python實現的代碼為
SGD每次迭代只用了一個數據,所以其訓練速度是非常快的,但另一方面,由於其搜索過程是隨機取數據的,看上去比較盲目,需要迭代的次數就相對更多了,而且往往無法達到全局最優。介紹的兩種演算法各有千秋,那能不能在它們之間取個折中呢?於是有了小批量梯度下降法(MBGD)。
小批量梯度下降法(MBGD)
更新函數不再像SGD那樣只用一個數據點來迭代,MBGD每次迭代選擇n(n<m)個數據點,更新函數為:
用Python實現起來很方便,只需在SGD代碼中稍作修改,這裡我就不貼出來了。MBGD在BGD和SGD之間做了一個折中,可以說秉承了兩種優點。
總結和思考
前面說到,對於非凸函數而言,梯度下降法找到的並不一定是最小值,可能是極小值,還可能是鞍點。對於這些情況,BGD表現是很差的,一旦陷入局部最小值,就難以擺脫出來,而SGD在這一點上表現得好一些,因為是每次迭代是隨機使用數據點,這種盲目性反而使得其優化過程不容易陷入局部最優處,但是它也難以達到全局最優點,只能很接近。MBGD的提出綜合了兩者的優點,使得SGD收斂的更加平穩。但是MBGD仍然存在一些挑戰!
① 難以選擇合適的學習率。學習率太小,網路收斂得太慢,而學習率太大,又會出現損失函數在最小點附近擺動,而無法達到最小點。雖說可以在學習中,逐漸改變學習率(慢慢減小),但這是人為事先設定好的,不能很好的適應數據的內在規律。
② 對不同特徵量採用不同的學習率。我們對特徵向量中的所有的特徵都採用了相同的學習率,如果訓練數據十分稀疏並且不同特徵的變化頻率差別很大,這時候對變化頻率慢得特徵採用大的學習率而對變化頻率快的特徵採用小的學習率是更好的選擇。
③ 這些梯度下降方法難以逃脫」鞍點」,鞍點既不是最大點也不是最小點,在這個點附近,所有方向上的梯度都接近於0,這些梯度下降演算法很難逃離它。關於逃離鞍點的介紹,可以看看這篇博文——演算法優化之道:避開鞍點
後續有學者提出不少優化演算法,能夠較好的解決上述問題,在這裡我就不再一一介紹,我也是在查資料的時候偶然看到這篇博客博客的介紹,有興趣可以深入研究一下。
本文涉及的完整代碼:https://github.com/wildwind0/Machine-Learning
參考:
[Machine Learning] 梯度下降法的三種形式BGD、SGD以及MBGD
梯度下降(Gradient Descent)小結
推薦閱讀:
※尋找全局最小值和防止過擬合之間是不是矛盾的?
※梯度下降法和高斯牛頓法的區別在哪裡,各自的優缺點呢?
※神經網路之梯度下降與反向傳播(上)
※梯度上升演算法與梯度下降演算法求解回歸係數怎麼理解?
※一文看懂常用的梯度下降演算法