梯度,梯度下降,隨機梯度下降,batch梯度下降
之前也被問到了梯度下降和隨機梯度下降的對比,所以根據之前挖的線性回歸求解的坑,現在來看看梯度下降這一種數值解法。
(1)首先講講梯度下降法:
梯度下降法是一種數值方法進行極小值求解的方式,通過每一輪不停的迭代,直到最終收斂在極值點(如果函數是凸的,最後收斂在最小值點)。
(1梯度的定義:
梯度的定義是從方嚮導數引出來的,首先假設如果f(x,y)連續可微,那方嚮導數的定義:
式子(1)
我們現在已知在點 對於x偏導數 和對於y的偏導數 ,於是我們可以表示方嚮導數為式子(2),具體的推到請自己翻書了:
式子(2)
(2為什麼沿著梯度方向變化最快
根據我們高中的知識,導數表示斜率,也可以表示變化率,斜率越大,曲線越陡峭,函數值變化的越快,因此現在我們想知道,那個方向上導數最大,也就是說,那個沿著哪個方向走最陡。顯然我們先把式子(2)寫成向量的形式:
式子(3)
要向使得Df(x,y)最大,上訴點乘可以寫成餘弦的形式,即
式子(4)
其中 是向量A和向量B的夾角,於是我們知道,對於給定的點 , 為定值,因此我們只需要 最大即可,因此,當A和B同向時候, 最大為1,因此我們可以知道,而向量A的方向我們通常稱為梯度方向,沿著梯度方向,增加最快,因此,沿著梯度反方向,下降最快。
(3梯度下降法
上邊我們解釋了,為什麼沿著梯度反方向,下降最快,下邊來講講梯度下降法:
剛剛我們講了,沿著梯度反方向函數值下降最快,因此,如果我們迭代求解,每一次沿著當前所處點的梯度負方向下降,那麼,最終會收斂到一個極小值。如果我們以線性回歸為例,由於MSE函數,即均方損失函數是一個凸函數,因此最終會收斂到全局最小值。對於要優化的函數
式子(5)
我們需要求解參數 (一個M維的向量,是訓練樣本X的特徵數目,也就是x也是個M維向量),現在對於我們需要求解的其中第j個維度 ,我們有第k輪的迭代值 ,現在我們求k+1輪的值:
式子(6)
其中 為學習的步長,這個值可以預先設定。最終我們迭代到收斂即可,當然對於這裡收斂的定義我們有兩種方式:
- 認為 不超過某個上界 就行,不過不常用
- 通常最常用的是loss function的差不超過某個 ,即
不超過 就行。
(4圖解梯度下降
下邊是一個梯度下降的圖:
再給出一個函數的等值曲線,我們可以直觀的看到梯度下降過程(梯度方向為曲線上某點的法向量方向)
(2)現在來講講隨機梯度下降法和batch梯度下降
上邊我們可以看到,用梯度下降法時候,每次都要計算所有樣本(i=1...N),這樣十分消耗內存,計算慢,收斂也慢,但是用梯度下降法能確保最終收斂的點一定是極值點或者最值點(函數是凸的情況下)。
那我們是不是有更快的方法使得梯度下降法收斂呢,答案就是隨機梯度下降法和batch梯度下降法,所謂隨機梯度下降,我們就是在計算每一輪的迭代值的時候,不需要用所有的樣本點,隨機選取其中一個就行,因此以線性回歸為例,我們得到每一輪的迭代式子變成:
式子(7)
這樣我們就能很快的求解完每一輪迭代,當然理論告訴我們,如果樣本足夠多,採樣足夠隨機,隨機梯度下降的結果和梯度下降的結果一樣。不過現實是,並不是那麼美好的,最終等高線為例,我們可以得到下邊的圖:
我們對比兩幅登高線圖,可以明顯的發現:
- 梯度下降是每回按照迭代點所在等高線法向量方向前進,雖然計算的慢,收斂的慢,但是迭代的輪數比較少,走的很明確,而且最終會收斂到極小值點,凸函數的話會收斂到最小值點。
- 隨機梯度下降每回走的就比較隨意,迭代輪數比較多,但是有點就是算的快,收斂的快,但是缺點就是,最終收斂的地方,可能不是極小值點,可能收斂(這裡我們定義的收斂就是前後兩次迭代的差值小於某個上限)在最後一直在極小值附近的點。
而batch梯度下降,故名思意就是從所以樣本中選一批,也不像隨機梯度下降一樣選一個,取了一個二者折中的辦法。
omega_{j+1}^{i}= omega_{j}^{i}-alphaast2/Nastsum_{i = 1}^{N}{(y_{i} - ast x_{i}x_{a}^{b}) } omega_{j+1}^{i}= omega_{j}^{i}-alphaast2/Nastsum_{i = 1}^{N}{(y_{i} - ast x_{i}x_{a}^{b}) } omega_{j+1}^{i}= omega_{j}^{i}-alphaast2/Nastsum_{i = 1}^{N}{(y_{i} - ast x_{i}x_{a}^{b}) }向著梯度反方向行走,最終是可以收斂的全局最小點的。
那麼對於我們線性回歸需要解的參數 (一個M維的向量,和x的維度一樣),我們選擇其中的第i維的分量 ,我們不斷的迭代,假設已經知道第j次迭代的值,我們可以求j+1次的值:
推薦閱讀:
※結合google facets進行機器學習數據可視化
※【機器學習】SVM理解及推導
※數以萬計的廣告中,客戶為什麼會點擊你的這幅?
※為什麼要對特徵進行縮放(歸一化)
※Pytorch用GPU到底能比CPU快多少?