[讀論文]Big Batch SGD: Automated Inference using Adaptive Batch Sizes
大家好我是zyy,本人是機器學習和深度學習的初學愛好者,想跟大家一起分享我的學習經驗,大家一起交流。我寫的東西不一定全對,但肯定是我一步一步走出來的坑,嚼爛了的經驗,可以供大家直接「吸收」
我的文章主要會涉及各種機器學習和深度學習演算法的推導和輪子的實現,以及一些小的應用demo,偶爾還會有一些論文的演算法實現。
文中出現的所有代碼都可以在我的GitHub上找到。
GitHub
題外話
今天先寫題外話。上周寫完那兩篇文章後,漲粉漲的厲害,各種各樣的人都關注了,有這個行業的博士,有工業界的老前輩,還有跟我一樣跨界,想進入這個領域的學者等等。如今接觸數據挖掘機器學習已經兩年多了,這兩年多不說學了多少東西,回想起來只有自己努力鑽研的情景,周圍沒有志同道合的人,只能自己探索,有了困難無人討論,有了突破看懂某個模型無人訴說喜悅,這種感覺太慘了。
如今看到這麼多人跟我一起熱愛這個東西,看到你們喜歡我的文章,我非常開心,但感到更多的是責任,如何去把我的經驗、我看見的聽見的好玩的東西變成通俗易懂的東西帶給大家,一直是我思考的問題。昨天我接到了一個阿里的電面,他問了我一個問題,如何把一個複雜的模型解釋給跟你對接的業務人員呢?由於沒有經驗,我也是胡亂說了一通,今天回想起來,還為此感到遺憾。現在我有了個答案,我要做的就是這個啊,讓更多的人,不光是做研究的人、做應用的人,還有看熱鬧的人,都能看懂機器學習的模型,通曉其中的意思。我由於自身條件的限制,不能在這個領域做出突破,那我就帶領更多人入門罷,這就是我為這個領域做出的貢獻。
所以,下次有人問我這個問題時,我會說:
「那就讓他們來看我的專欄吧。」
對了有人說我拿封面圖騙粉...
啊呸,,你要這麼說我要報警了呢(╯‵□′)╯︵┴─┴
[讀論文]Big Batch SGD: Automated Inference using Adaptive Batch Sizes
今天帶來訓練神經網路的另外的trick,batch size的trick。
我們一般的模型可以寫為:
是以分布抽出來得數據,刻畫了一組參數在數據上的表現好壞,也就是損失函數,為模型的期望損失。
當很大的時候,我們採用隨機梯度下降的方式來逐步減小loss,在迭代的第次時,我們選擇一組(batch)的數據來計算梯度:
理想情況下,有,即我們隨機一組的數據的梯度的期望為整體數據的期望,但是一組數據裡面肯定會有一定的雜訊數據,一組數據的梯度方向不一定朝向整體的梯度方向,所以隨機梯度下降會有震蕩的表現,收斂的時候也不是很穩定。所以說,當接近最優值時,我們應該讓步長儘可能小,讓它不至於衝過頭,錯過極值點。但是傳統的梯度下降方法,這些都是固定的超參數,是人為給的,然後本文就給出了梯度具體調整的策略。
然後就是大家喜聞樂見的推公式過程了。想看具體證明的看原文附錄,反正我都自己驗證過了,大家放心看。
首先我們有
這個不等式說明了batch的梯度誤差與batch的梯度的關係,理論上講這個誤差的期望為:
等號右邊就是的方差,所以我們就得到了梯度和batch損失的方差的關係。
後面作者通過一些關於bound的證明,推出了:
所以有,,
作者提出了第一個改進:
- 計算;
- 選擇步長;
- 進行更新。然後作者又加入了backtracking line search,中文好像叫回溯線搜索,BLS法。這個方法的主要思想是給定初始步長後,然後判斷不斷調整,找到最優步長後再進行參數更新。演算法太長,我直接貼截圖了:
作者後面又提了一種用Barzilai-Borwein估計的方法,把它當成一個二次規劃問題進行求解的,我沒太仔細看推導分析(時間不多,我也沒看太懂),重點來看一下怎麼做吧。
後面的是由上面的公式算出來的步長,然後與上一次迭代的步長加權相加。作者在一次循環中進行兩次迭代,前一次如之前提到的迭代方式,後一次迭代先從判斷batch的大小開始,對學習率進行加權修正,之後再進行迭代。重複上述過程。依作者的觀點他這麼做有這些好處:
- 不需要人為選擇迭代步長的超參數,更大的batch會減小雜訊影響。
- BLS演算法結合最大batch的方法可以做出較優的迭代決策。
- 高階的梯度下降方法需要在下降方向上做較多的矯正,近似計算可以減少這方面的工作量。
- 對於非凸的問題,每次迭代的計算複雜度是線性的,在接近極值點時,漸進的梯度也會逐漸消失,趨於收斂。
說說我的看法吧,這篇文章是對SGD進行了方方面面的小改進,不管是batch size還是learning rate,都有一些比較實用的改進,可以縮小整體訓練時長。但是一方面這些改進會縮小整體訓練周期,另一方面會增加每次訓練的計算複雜度。大家在做小規模數據的時候,不一定要用這些,殺雞焉用牛刀啊,要自己選擇適合的學習演算法。
推薦閱讀:
※機器學習筆記8 —— 邏輯回歸模型的代價函數和梯度下降演算法
※一文看懂常用的梯度下降演算法
※瞎談CNN:通過優化求解輸入圖像
※神經網路之梯度下降與反向傳播(上)
※梯度下降法快速教程 | 第三章:學習率衰減因子(decay)的原理與Python實現
TAG:深度学习DeepLearning | 论文 | 梯度下降 |