[DNN] 嘗試理解深度神經網路的Large-batch魔咒

最近貴司的「一小時訓練ImageNet」論文在國內外各種刷屏(research.fb.com/publica),看了一下,確實非常實用主義的文章,介紹很多有用的trick,包括系統實現上的很多坑都覆蓋到了。其中談到加速訓練的難點之一是:需要用到更大的mini-batch size,而這通常會降低準確率,所以他們通過linear-scaling learning rate解決了這個問題。看到這裡我對於這個難點產生了疑問——batch size越大,不應該訓練的方差越小,隨機性越小,從而能夠更準確地擬合數據集么?

從一個對深度學習接觸不多的人(比如我)的角度,這點確實有點反直覺。當batch-size不斷增大,直到跟數據集一樣大的時候,SGD (Stochastic Gradient Descent)就變成了最樸素的GD,一次梯度更新會掃描一遍所有的數據來算梯度。看教科書和在CMU上Machine Learning的時候被灌輸的理念都是:SGD相對於GD,或者小的batch相對於大的batch,有助於更快收斂,但是準確度會下降。為什麼到了深度神經網路這裡就反過來了呢?

我的第一猜想是:神經網路的函數空間非常non-convex,所以mini-batch越小就越容易不斷跳出local minima,尋找更好的最小值。但是自己馬上感覺這個猜想有很多漏洞,不能自圓其說,所以我去查證了一下其他人的分析——Facebook的論文原文有提及過這個問題,以及ICLR 2017上也有一篇論文針對這個問題分析了一下。有趣的事,兩篇文章的觀點並不相同,Facebook的論文還輕踩了對方一下說「根據我們跑的實驗這事不是你們說的那樣兒的」。由於ICLR 的論文先出來,我們先看看它怎麼說:

ICLR 2017: ON LARGE-BATCH TRAINING FOR DEEP LEARNING:nGENERALIZATION GAP AND SHARP MINIMA openreview.net/pdf?

這篇文章主要研究了「為什麼Large batch size會讓錯誤率提高」的問題,提出了四個可能的猜想:

(i) LB methods over-fit the model;

(ii)nLB methods are attracted to saddle points;

(iii) LB methods lack the explorative properties of SBnmethods and tend to zoom-in on the minimizer closest to the initial point;

(iv) SB and LB methodsnconverge to qualitatively different minimizers with differing generalization properties.

然後通過實驗,得出了支持(iii)和(iv)的證據。也就是說,主要是兩點原因:

1) LB (Large-Batch) 方法探索性太差,容易在離起始點附近很近的地方停下來

2) LB和SB由於訓練方式上的差異,最終會導致它們最終收斂的點具有一些數學屬性的差異

#1 很好理解,跟我前面的猜想有點類似。這裡著重談談#2 - 文章談到,LB方法會收斂到Sharp-minimum,而SB方法會收斂到Flat-minimum。這兩種minimum的差別如圖所示:

在同樣的Bias下,明顯Flat的曲線比Sharp的曲線更加接近真實情況,所以Flat Minimum的generalization performance更好。

然後,基於這個假設,他們給出的解決方案是:先用SB方法訓練幾個epoch,讓它先探索一下,找到一個比較Flat的區域,再用LB方法慢慢收斂到正確的地方。論文給出了performance vs. # of epoch trained with SB,但個人感覺不是很有說服力。。。

Facebook: Accurate, Large Minibatch SGD:nTraining ImageNet in 1 Hour

再回到Facebook這篇文章,作者認為,LB之所以不work,不是因為上面那篇論文提到的泛化能力的問題,而主要是一個optimization issue(我的理解是優化過程/優化演算法的問題)。文章沒有給出理論分析, 而是直接給出了實驗數據:首先,這篇論文是基於「Linear Scaling Learning Rate」來做的,簡單來說,假如說原來batch size是256,learning rate是0.1;那麼當把batch size設成8192的時候,learning rate就設成3.2 。batch size翻多少倍,learning rate就翻多少倍。然後,基於這個方法,論文作者發現,如果用LB方法,剛開始就用很大的learning rate的話,效果其實是很差的;但是,只要剛開始把LR設小點,後來逐步把LR提高到正常的大小,那麼效果拔群,LB能夠得到跟SB幾乎一毛一樣的training curve,以及基本相同的準確度。

基於這個觀察,作者認為,LB不work的主要原因是

large minibatch sizes arenchallenged by optimization difficulties in early training

(至於為什麼,這個跟Linear Scaling Learning Rate的assumption有關:簡單來說,就是Linear Scaling Learning Rate這個trick是基於一定的assumption的,而這個assumption在網路權重急劇變化的時候——也就是剛開始訓練的時候——是不成立的。所以,一開始就應用那麼大的learning rate會出事。我解釋的不是很清楚,具體可以去看原論文)

總結

上篇兩篇論文各有千秋:ICLR那篇著重理論分析,用漂亮的實驗驗證了Sharp-minimum和Flat-minimum的區別,啟發性非常大,但是給出的解決方案不是很令人信服;Facebook這篇直接從實戰經驗出發,實驗和解釋都比較令人信服,不過理論上相對弱些。

對於兩者的Claim,其實不能說誰對誰錯,因為兩者的實驗方法不一樣;ICLR那篇沒有應用Linear Scaling Learning Rate而是直接應用了ADAM來作為optimizer,得出的結果跟Facebook的肯定不能直接相比。如果ICLR那篇論文的作者可以使用Facebook的方法論重新跑實驗的話,說不定得出的結論會有很大不同。甚至說,雙方的結論其實不完全互斥,而是可以被統一成一個理論(比如我現在拍腦袋想的:剛開始訓練的時候,Large-batch得出來的梯度不準確,所以如果設的learning rate太大,就更加容易陷入Sharp-minimum出不來,從而影響到後面的優化,之類之類的)。


推薦閱讀:

鋼鐵俠3裡面的賈維斯系統是個什麼構造的?
分散式系統理論基礎 - 一致性、2PC和3PC
Zeppelin:一個分散式KV存儲平台之概述

TAG:分布式系统 | 深度学习DeepLearning | Facebook |