Rocket Training: 一種提升輕量網路性能的訓練方法
Motivation
工業界,一些在線模型,對響應時間提出非常嚴苛的要求,從而一定程度上限制了模型的複雜程度。模型複雜程度的受限可能會導致模型學習能力的降低從而帶來效果的下降。
目前有2種思路來解決這個問題:一方面,可以在固定模型結構和參數的情況下,用計算數值壓縮來降低inference時間,同時也有設計更精簡的模型以及更改模型計算方式的工作,如Mobile Net和ShuffleNet等工作;另一方面,利用複雜的模型來輔助一個精簡模型的訓練,測試階段,利用學習好的小模型來進行推斷。這兩種方案並不衝突,在大多數情況下第二種方案可以通過第一種方案進一步降低inference時間,同時,考慮到相對於嚴苛的在線響應時間,我們有更自由的訓練時間,有能力訓練一個複雜的模型,所以我們採用第二種思路,來設計了我們的方法。
Our Approach
如圖所示,訓練階段,我們同時學習兩個網路:Light Net 和Booster Net, 兩個網路共享部分信息。我們把大部分的模型理解為表示層學習和判別層學習,表示層學習的是對輸入信息做一些高階處理,而判別層則是和當前子task目標相關的學習,我們認為表示層的學習是可以共享的,如multi task learning中的思路。所以在我們的方法里,共享的信息為底層參數(如圖像領域的前幾個卷積層,NLP中的embedding), 這些底層參數能一定程度上反應了對輸入信息的基本刻畫。
整個訓練過程,網路的loss如下:
Loss包含三部分:第一項,為light net對ground truth的學習,第二項,為booster net對ground truth的學習,第三項,為兩個網路softmax之前的logits的均方誤差(MSE),該項作為hint loss, 用來使兩個網路學習得到的logits盡量相似。
Co-Training
兩個網路一起訓練,從而booster net的 會全程監督小網路 的學習,一定程度上,booster net指導了light net整個求解過程,這與一般的teacher-student 範式下,學習好大模型,僅用大模型固定的輸出作為soft target來監督小網路的學習有著明顯區別,因為booster net的每一次迭代輸出 雖然不能保證對應一個和label非常接近的預測值,但是到達這個解之後一定能找到最終收斂的解 。
Hint Loss
Hint Loss這一項在SNN-MIMIC中採用的是和我們一致的對softmax之前的logits做L2 Loss:
Hinton的KD方法是在softmax之後做KL散度,同時加入了一個RL領域常用的超參temperature T:
也有一個半監督的工作再softmax之後接L2 Loss:
大家都沒有給出一個合理的解釋為什麼要用這個Loss,而是僅僅給出實驗結果說明這個Loss在他們的方法中表現得好。KD的paper中提出在T足夠大的情況下,KD的Loss 是等價於 的。我們在論文里做了一個稍微細緻的推導,發現這個假設T足夠大使得 成立的情況下,梯度也是一個無窮小,沒有意義了。同時我們在paper的appendix里 在一些假設下 我們從最大似然的角度證明了 的合理性。
Gradient Block
由於booster net有更多的參數,有更強的擬合能力,我們需要給他更大的自由度來學習,盡量減少小網路對他的拖累,我們提出了gradient block的技術,該技術的目的是,在第三項hint loss進行梯度回傳時,我們固定booster net獨有的參數 不更新,讓該時刻,大網路前向傳遞得到的 ,來監督小網路的學習,從而使得小網路向大網路靠近。
在預測階段,拋棄booster net獨有的部分,剩下的light net獨自用於推斷。整個過程就像火箭發射,在開始階段,助推器(booster)載著衛星(light net)共同前進,助推器(booster)提供動力,推進衛星(light net)的前行,第二階段,助推器(booster)被丟棄,只剩輕巧的衛星(light net)獨自前行。所以,我們把我們的方法叫做Rocket Launching。
Experiments
實驗方面,我們驗證了方法中各個子部分的必要性。同時在公開數據集上,我們還與幾個teacher-student方法進行對比,包括Knowledge Distillation(KD), Attention Transfer(AT)。 為了與目前效果出色的AT進行公平比較,我們採用了和他們一致的網路結構寬殘差網路(WRN)。 實驗網路結構如下:
(a) bottom rocket net on wide residual net
(b) interval rocket net on wide residual net
紅色+黃色表示light net, 藍色+紅色表示booster net。(a)表示兩個網路共享最底層的block,符合我們一般的共享結構的設計。(b)表示兩網路共享每個group最底層的block,該種共享方式和AT在每個group之後進行attention transfer的概念一致。
我們通過各種對比實驗,驗證了參數共享和梯度固定都能帶來效果的提升:
通過可視化實驗,我們觀察到,通過我們的方法,light net能學到booster net的底層group的特徵表示。
除了自身方法效果的驗證,在公開數據集上,我們也進行了幾組實驗。
在CIFAR-10上, 我們嘗試不同的網路結構和參數共享方式,我們的方法均顯著優於已有的teacher-student的方法。在多數實驗設置下,我們的方法疊加KD,效果會進一步提升。
這裡WRN-16-1,0.2M 表示wide residual net, 深度為16,寬度為1,參數量為0.2M.
同時在CIFAR-100和SVHN上,取得了同樣優異的表現:
真實應用
同時,在阿里展示廣告數據集上,我們的方法,相比單純跑light net, 可以將GAUC提升0.3%.
如圖:
我們的線上模型在後面的全連接層只要把參數量和深度同時調大,就能有一個提高,但是在線的時候有很大一部分的計算耗時消耗在全連接層(embedding 只是一個取操作耗時隨參數量增加並不明顯),所以後端一個深而寬的模型直接上線壓力會比較大。表格里列出了我們的模型參數對比以及離線的效果對比:
最後附上我們的paper 和code地址:
paper: https://arxiv.org/abs/1708.04106
code: zhougr1993/Rocket-Launching
我們來自阿里媽媽精準定向業務線的演算法團隊 歡迎各位英豪加入~
推薦閱讀:
※2017年歷史文章匯總|深度學習
※brox近期論文
※計算機視覺常見領域問題概要(深度學習)
※計算機視覺部分演算法最佳解釋
※3D卷積神經網路Note01
TAG:深度學習DeepLearning | 機器學習 | 計算機視覺 |