SWATS演算法剖析(自動切換adam與sgd)
SWATS是ICLR在2018的高分論文,提出的一種自動由Adam切換為SGD而實現更好的泛化性能的方法。
論文名為Improving Generalization Performance by Switching from Adam to SGD,下載地址為:https://arxiv.org/abs/1712.07628。
作者指出,基於歷史梯度平方的滑動平均值的如adam等演算法並不能收斂到最優解,因此在泛化誤差上可能要比SGD等方法差,因此提出了一種轉換機制,試圖讓演算法自動在經過一定輪次的adam學習後,轉而由SGD去執行接下來的操作。
演算法本身思想很簡單,就是採用adam這種無需操心learning rate的方法,在開始階段進行梯度下降,但是在學習到一定階段後,由SGD接管。這裡前面的部分與常規的adam實現區別不大,重要的是在切換到sgd後,這個更新的learning rate如何計算。 整個演算法步驟流程如下:
熟悉adam的應該能熟悉藍色的部分,這個就是adam的原生實現過程。
作者比較trick的地方就是14行到24行這一部分。這一部分作者做了部分推導,作為最後的切換learning rate。
演算法的整個實現邏輯並不複雜,這裡列出自己實現時遇到的一些問題。
填坑 & 問題
- 在上面的演算法流程第12行,有個,這個在整個流程中未介紹如何實現,本人閱讀論文後,發現應該是學習率衰減的設計。一如很多深度學習策略一樣,這裡可以設置經過若干輪迭代後,學習率降為原來的1/N。在論文中,作者使用了在150輪後,將學習速率減少10倍。即。
- 上面說了的更新,我們通過公式推導,其實能發現和有一定的關係,自己代碼實現的版本,發現切換的時機很大程度上和有關,因為切換涉及到第17行的一個比較過程,和本身都與相關,當降一個量級時,|本身也會更接近。其有些類似正比關係,因此一般都是在經過一定輪次的衰減後,才能觸發SGD切換時機。這一點目前本人實現驗證是這樣,未深入推理。
- 這個還有個坑,就是實現該演算法,開始不太清楚這個k到底指的是epoch,還是指的經歷的batch數量。最後按照常規學習率衰減應該是按照epoch來算的,因此推測其k應該為epoch。
- 還有和大坑是作為學習率,在切換到SGD後應一直不變,該值為標量,因此應該如常用eta等學習率一樣,為正值,因此需要在17行加個約束,即。(該場景難以復現,之前有次更新發現不設置為正值時,導致切換sgd後準確度大減)
總結
通過若干的對比,該論文變相增加了一些超參數,所以實際使用有待商榷。自己的數據集上經常就在還未滿足切換條件就已經收斂了。 目前已做了相應的實現,放在scalaML中,位置為https://github.com/sloth2012/scalaML/blob/master/src/main/scala/com/lx/algos/ml/optim/GradientDescent/SWATS.scala,使用見https://github.com/sloth2012/scalaML/blob/master/src/test/scala/com/lx/algos/ml/GradientDescentTest.scala。最後想要查看切換過程的話,建議將early_stop設置為false,然後將學習率衰減係數設置低一點。 代碼目前僅支持二分類。
推薦閱讀:
※【乾貨】我是怎麼用四個月時間速成全棧機器學習的
※機器學習演算法如何調參?這裡有一份神經網路學習速率設置指南
※強化學習——簡介
※Hands-On ML,CH2:房價預測
※機器學習常見演算法分類匯總
TAG:機器學習 |