標籤:

反向傳播演算法和梯度下降理解

所有人都在說AI、CNN、RNN、GAN好像所有人都會TensorFlow、Torch、Keras、MxNet…..的時候我還在寫這種東西,應該會被鄙視的吧。

梯度下降

絕大部分的機器學習演算法最後都是最優化一個目標函數,而梯度下降演算法是尋找一個函數局部最優解的有效方法。我做了個例子,用Softmax實現mnist上的分類器(手寫數字識別),具體代碼參見GitHub.Softmax 的詳細介紹可以參照ufldl 使用交叉熵作為誤差函數Softmax 的目標函數可以用下圖中的公式表示。

其中mm 是訓練樣本數量,kk 是分類數, hetaθ 是要求的參數。

我對誤差函數直觀的理解就是:模型對所有樣本的預測結果與真實結果之間的差異的期望,預測和真實結果之間的差異用交叉熵來計算的。如果把負號放到求和符號的最裡面,第二個求和符號計算的是交叉熵,第一個求和符號再乘以1/m則計算的是期望。如果換做是二次誤差函數只是把預測和真實之間的差異的度量方式改了。

理論上來說誤差越小越好(不考慮過擬合),如何求函數的最小值呢?首先想到的是直接求解,通常這是行不同的,對於計算機來說方便的一種方法就是遍歷,我可以遍歷所有的可能取值然後選最小的,在這裡的 hetaθ 是一個784X10 的參數矩陣,要遍歷所有可能的實數自然不行。

梯度下降演算法給出了一種搜索方法,首先隨機初始化 hetaθ 然後沿著梯度方向(求極大值)和梯度相反的方向(求極小值,也就是梯度下降方向)來尋找參數。那為何梯度方向是函數變化最大的方向呢?這要扯到方嚮導數了。對於三維空間來說,偏導數是函數沿坐標軸方向上的變化率,則方嚮導數是函數沿著任意方向的變化率。方嚮導數是可以通過偏導數和方向餘弦計算的到的。

梯度是偏導數組成的向量,比如f(x,y,z)f(x,y,z) 在點P(x_0, y_0, z_0)P(x0,y0,z0)的梯度為向量(f_x(P) , f_y(P) , f_z(P))(fx(P),fy(P),fz(P))梯度本身是向量所以才有所謂的梯度下降或上升方向這種說法。從方嚮導數和梯度直接的關係可以推出梯度方向是函數增長最快的方向,其實這種問題可以不糾結,因為和代碼實現關係不大。

知道了梯度反方向是函數下降最快的方向,所以可以每次都在參數上減去對應的偏導數。

如何計算偏導數呢?對上圖中的J( heta)J(θ) 求偏導數,根據求導法則我們可以把求導放到第一個求和符號的裡面,或者直接把第一個求和符號拆開,整個等式的右邊就會是很多加號相連的,所以代碼實現的時候可以求每一個樣本的那部分再加總。如過理解這個上面說的這幾句,對於隨機梯度下降,或只使用少量樣本的批量梯度下降為什麼work其實也相對比較容易了。(體會:深入理解公式是代碼實現的前提)

反向傳播演算法

上面知道求解最優問題的時候,如果使用梯度下降,需要計算函數的梯度,反向傳播演算法解釋計算神經網路中誤差函數梯度的一種方法。

在網路(這裡指的是多次全連接網路)中上一層的輸出是下一次的輸入,導致這個求導過程及其複雜,聰明的人想出了一個求導方法即反向傳播,至於怎麼想出來的 Michael Nielsen 寫的書里說It』s just a lot of hard work ,這裡寫的也參考這本說的第二部分

反向傳播的大體思路就是找到一個中間變數,通過鏈式法則來求解誤差對參數的偏導數。書里總結了四個公式我覺得真的是太好了,書里證明了前兩個公式,後兩個看完應該也能證了。

公式里的Z^lZl 是第l 層激活函數的輸入(就是wx+b之後的結果,這樣應該比較清楚),delta^lδl是誤差函數對z^lzl的偏導。

第一個公式是誤差函數對最後一層的z的偏導Delta CΔC是誤差函數,對最後一層激活函數的輸出的偏導,odot⊙後面部分是激活函數對z求導。

第二個公式是delta^lδl 在各層之間的遞推關係

第三個和第四個公式是我們最終要計算的。怎麼得到第三個和第四個公式其實是比較簡單了,方法就是把誤差函數C對w和b的求導拆成C對z 的導數乘z對w或b 的導數(連上法則)。

如果真的理解了,想想下面的問題。當網路使用不同的誤差函數,或激活函數,會影響到公式的哪些地方?,代碼實現上會有哪裡差異?

誤差函數的變化只會影響第一個公式,也就是最後一層誤差的部分,激活函數的變化為影響第一和第二個公式

我讀了很多對反向傳播演算法原理解釋的文章,我發現搞懂這四個公式是對代碼實現和原理理解最好的方法,當然你或許會有其他的方法。

另外作者的代碼實現也非常的好,我也寫了自己的版本,雖然看過作者的代碼,自己重寫的時候還是遇到很多問題,調了好久。

——————————————————

知乎不能顯示公式我表示很遺憾


推薦閱讀:

Cousera deeplearning.ai筆記 — 淺層神經網路(Shallow neural network)
全面理解word2vec
有關NLP的比賽
機器學習數學:梯度下降法
深入機器學習系列21-最大熵模型

TAG:機器學習 |