深度學習——分類之Inception v2——Batch Normalization

論文:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

作者:Sergey Ioffe, Christian Szegedy

ImageNet Top5錯誤率:4.8%

Inception v2主要就是在Inception v1地基礎上,把LRN扔了,然後加了Batch Normalization,訓練快很多,效果也好一點,在之後的網路中,Batch Normalization得到了廣泛的運用,比如ResNet裡面,很有用啊。

中心思想:希望內部的激活分布好一點,就手動把它轉換到好一點的分布,簡單除暴還很work。

動機

以前學習率太低、對網路初始化要求高、訓練帶有飽和激活函數的網路時尤其困難(比如Sigmoid激活),歸根結底,還是每一層輸出的分布,不受你控制,隨著網路逐漸加深,一個零點幾的值乘以一個零點幾的值,越來越小,很可能就遇到梯度消失的問題。

我們希望每一層的輸入都漂亮規整,然後網路就可以輕鬆愉快地進行訓練,就像一個一個的淺層網路單獨在訓練一樣。

漂亮地輸入是什麼樣地呢?0均值1方差就很漂亮。想要輸入0均值1方差,那做個變換好了,變換到0均值1方差,這差不多就是Batch Normalization做地事情了。

公式

公式如下:

每個訓練的mini-batch,每一個神經元(對FC來說)的輸出或每一個卷積核(對Conv來說)的輸出激活圖,計算均值和方差,然後變換 x_ihat{x}_i 即可,現在,0均值1方差啦。

假設mini-batch的size為m,則:

FC:每個輸入樣本,每個神經元,輸出一個標量,m個樣本就輸出m個標量,然後求均值和方差就好了;

Conv:每個輸入樣本,每個卷積核,輸出一個激活圖,大小為WxH,所有位置都做當同對待,m個樣本輸出m個WxH大小的激活圖,一共mxWxH個標量,然後求均值和方差就好了;

影響的消除

但是,並不一定變換到0均值1方差就是最優選擇了,所以,引入了兩個參數γ和β,對於輸入的變化,計算了均值和方差之後,實際上就是減均值的差除以標準差,即減一個數再除一個數,那麼反過來,乘一個數再加一個數,就可以完全消除影響:

y=frac{x-a}{b}=>x=y*b+a

這就是作者論文中公式的最後一項。而變換係數γ和β是兩個可學習參數。如果不需要Batch Normalization的話,可以自動學習消除掉BN的影響,相當於恆等映射了,這樣就完全無害了啊,只有好處,沒有壞處,網路覺得這個東西不好,那自動學習γ和β就可以消除掉其影響,變為一個identity恆等映射。

Train & Inference

在訓練的時候,均值和方差可以通過每次輸入的樣本來計算得到,但是每次輸入的樣本都不一樣,所以計算得到的均值和方差都不一樣,實際上使用的時候通過滑動平均來平滑一下:

new_mean = old_mean*momentum + computed_mean*(1-momentum)nnew_var = old_var*momentum + computed_var*(1-momentum)n

PyTorch裡面的momentum默認是0.1。

上面的均值和方差會被記住,然後測試的時候,就使用記住的均值和方差,而不再通過輸入計算(因為輸入可能只有一張而非batch,並且希望對於所有輸入同等對待而不因為不同的輸入而有不同的結果)

Inception v2

然後作者在Inception的基礎上,加了Batch Normalization,和其它一些小修改,包括用兩個3x3卷積替換5x5卷積,移除LRN,添加BN,平均池化最大池化混用等,具體參見下面的表格:

Batch Normalization完全無害,並且解決了對於初始化的嚴重依賴,最重要的是增大學習率,用14倍的速度快速達到相同的精度。


推薦閱讀:

人工智慧與中國象棋的頂尖高手對弈的結果有哪些?
【活動預告】第二屆設計與人工智慧會議:設計與數據智能的探索
案例分析(大數據)數據港SH.603881:數據中心業務擴張迅猛,產業升級迫在眉睫
如何獨立完成一個人工智慧項目? - 以 AlphaGo 為例 - 九章免費講座預告

TAG:神经网络 | 人工智能 | 深度学习DeepLearning |