標籤:

SWATS演算法剖析(自動切換adam與sgd)

SWATS是ICLR在2018的高分論文,提出的一種自動由Adam切換為SGD而實現更好的泛化性能的方法。

論文名為Improving Generalization Performance by Switching from Adam to SGD,下載地址為:arxiv.org/abs/1712.0762

作者指出,基於歷史梯度平方的滑動平均值的如adam等演算法並不能收斂到最優解,因此在泛化誤差上可能要比SGD等方法差,因此提出了一種轉換機制,試圖讓演算法自動在經過一定輪次的adam學習後,轉而由SGD去執行接下來的操作。

演算法本身思想很簡單,就是採用adam這種無需操心learning rate的方法,在開始階段進行梯度下降,但是在學習到一定階段後,由SGD接管。這裡前面的部分與常規的adam實現區別不大,重要的是在切換到sgd後,這個更新的learning rate如何計算。 整個演算法步驟流程如下:

熟悉adam的應該能熟悉藍色的部分,這個就是adam的原生實現過程。

作者比較trick的地方就是14行到24行這一部分。這一部分作者做了部分推導,Lambda=lambda_k/(1-{eta_2}^k)作為最後的切換learning rate。

演算法的整個實現邏輯並不複雜,這裡列出自己實現時遇到的一些問題。

填坑 & 問題

  1. 在上面的演算法流程第12行,有個alpha_k,這個在整個流程中未介紹如何實現,本人閱讀論文後,發現應該是學習率衰減的設計。一如很多深度學習策略一樣,這裡可以設置經過若干輪迭代後,學習率降為原來的1/N。在論文中,作者使用了在150輪後,將學習速率減少10倍。即alpha_{k+1}=left{egin{matrix} {alpha_k/10}& if(k\%150==0)\ alpha_k & alpha_0=alpha end{matrix}
ight.
  2. 上面說了alpha_k的更新,我們通過公式推導,其實能發現lambda_kalpha_k有一定的關係,自己代碼實現的版本,發現切換的時機很大程度上和alpha_k有關,因為切換涉及到第17行的一個比較過程,lambda_kgamma_k本身都與alpha_k相關,當alpha_k降一個量級時,|frac{lambda_k}{1-{eta_2}^k}-gamma_k|本身也會更接近epsilon。其有些類似正比關係,因此一般都是在經過一定輪次的衰減後,才能觸發SGD切換時機。這一點目前本人實現驗證是這樣,未深入推理。
  3. 這個alpha_k還有個坑,就是實現該演算法,開始不太清楚這個k到底指的是epoch,還是指的經歷的batch數量。最後按照常規學習率衰減應該是按照epoch來算的,因此推測其k應該為epoch。
  4. 還有和大坑是Lambda作為學習率,在切換到SGD後應一直不變,該值為標量,因此應該如常用eta等學習率一樣,為正值,因此需要在17行加個約束,即frac{lambda_k}{1-{eta_2}^k}>0。(該場景難以復現,之前有次更新發現不設置為正值時,導致切換sgd後準確度大減)

總結

通過若干的對比,該論文變相增加了一些超參數,所以實際使用有待商榷。自己的數據集上經常就在還未滿足切換條件就已經收斂了。 目前已做了相應的實現,放在scalaML中,位置為github.com/sloth2012/sc,使用見github.com/sloth2012/sc。最後想要查看切換過程的話,建議將early_stop設置為false,然後將學習率衰減係數設置低一點。 代碼目前僅支持二分類。


推薦閱讀:

【乾貨】我是怎麼用四個月時間速成全棧機器學習的
機器學習演算法如何調參?這裡有一份神經網路學習速率設置指南
強化學習——簡介
Hands-On ML,CH2:房價預測
機器學習常見演算法分類匯總

TAG:機器學習 |