神經網路中利用矩陣進行反向傳播運算的實質

訓練神經網路模型時,為了優化目標函數,我們需要不斷地迭代更新網路中的權值,而這一過程是通過反向傳播演算法(Backpropagation,BP)實現的。在神經網路中,訓練樣本和權值參數都被表示為矩陣的形式,因為這樣更利於反向傳播的計算。

之前學習反向傳播演算法的時候一直有誤解,認為它需要用到大量的矩陣求導,但仔細理解後發現實際上用到的還是標量的求導,只不過用矩陣表示出來了而已。

本文中通過遞推的方法,用矩陣來形象化地表示神經網路模型訓練中反向傳播的過程,並從單個輸入樣本逐步擴展到多個輸入樣本(mini-batch)。

單個輸入樣本計算

對於形如L=fleft(Y
ight)=fleft(XW
ight)=fleft(sum_{i}^{n}w_ix_i
ight)的目標函數來說(省略偏置項b,因為它可以被整合進X中),若中間項Y取單一值y,則其可以表示為兩個向量相乘的形式

egin{align}left[y
ight]=left[egin{matrix}x_1&x_2&cdots&x_nend{matrix}
ight]	imesleft[egin{matrix}w_1\w_2\vdots\w_nend{matrix}
ight]end{align}

若中間值取多維left[egin{matrix}y_1 & y_2 & ... & y_nend{matrix}
ight],則可以表示為兩個矩陣相乘的形式

egin{align}left[egin{matrix}y_1&y_2&cdots&y_nend{matrix}
ight]=left[egin{matrix}x_1&x_2&cdots&x_nend{matrix}
ight]	imesleft[egin{matrix}w_{11}&w_{12}&cdots&w_{1c}\w_{21}&w_{22}&cdots&w_{2c}\vdots&vdots&ddots&vdots\w_{n1}&w_{n2}&cdots&w_{nc}end{matrix}
ight]end{align}

對於其中每一個目標值,y_c=sum_{i=1}^{n}w_{ic}x_i=w_{1c}x_1+w_{2c}x_2+cdots+w_{nc}x_n。那麼如果想要求得y_cw_{ic}的導數,只需要列出公式

egin{align}W

那麼對於參數矩陣的列向量W_c來說

frac{partial y_c}{partial W_c}=X^T

假設目標函數L對於Y的導數為Y,那麼L對於W_cW的列向量)的偏導數W_c則為

egin{align}W

W_c

那麼L對於W的偏導數W則可以通過矩陣表示為

egin{align}W

W

多個輸入樣本計算

在神經網路中,我們通常採用mini-batch的方法進行訓練,對於含有m個樣本的mini-batch來說

egin{align}left[egin{matrix}y_{11}&y_{12}&cdots&y_{1c}\y_{21}&y_{22}&cdots&y_{2c}\vdots&vdots&ddots&vdots\y_{m1}&y_{m2}&cdots&y_{mc}end{matrix}
ight]=left[egin{matrix}x_{11}&x_{12}&cdots&x_{1n}\x_{21}&x_{22}&cdots&x_{2n}\vdots&vdots&ddots&vdots\x_{m1}&x_{m2}&cdots&x_{mn}end{matrix}
ight]	imesleft[egin{matrix}w_{11}&w_{12}&cdots&w_{1c}\w_{21}&w_{22}&cdots&w_{2c}\vdots&vdots&ddots&vdots\w_{n1}&w_{n2}&cdots&w_{nc}end{matrix}
ight]end{align}

其中,W_c可表示為

egin{align}W

W_cY_cY的列向量),

W表示為

egin{align}W

W

快速計算方法

其實還有一種簡便的方法可以推導上面的公式,對於Y=XW,假設Y的維度是M 	imes CX的維度是M 	imes NW的維度是N 	imes C,那麼可以利用維度的關係進行導數的計算。Y的維度必然是M 	imes C,那麼W的維度必然是N 	imes C且與X有關,那麼必有W,同理必有X

參考

  • CS231n Convolutional Neural Networks for Visual Recognition

推薦閱讀:

用人工神經網路求解微分方程
邏輯與神經之間的橋
9行Python代碼搭建神經網路
M.1.0 神經網路的數學基礎-前言

TAG:神经网络 | 深度学习DeepLearning |