如何評價DeepMind的DNI?

DeepMind最近弄了一個DNI(Decoupled Neural Interfaces using Synthetic Gradients)

Paper: https://arxiv.org/abs/1608.05343

主題相關:

  • 最前沿:神經網路訓練方法大革新,反向傳播訓練不再唯一 - 智能單元 - 知乎專欄

  • 深度 | DeepMind官方深度解讀:使用合成梯度的解耦神經介面(附論文)


DeepMind的這篇論文腦洞的確有點大,剛開始看得時候可能一下子接受不來,這不,在那篇專欄《最前沿:神經網路訓練方法大革新,反向傳播訓練不再唯一》下面甚至有一些嘲諷之聲(很慚愧本人也在其列),覺得DeepMind搞的這個「用另外的網路來生成梯度」實在是太玄學了。但是仔細看完論文之後,我發現它一點都不玄學,可以聯繫到強化學習裡面的Q-Learning演算法。雖然它大概替代不了FCN、CNN上的BP,但是卻非常有可能替代掉BPTT(back propogation through time),在RNN系列模型上大放光芒。下面具體講講我對它的理解。

一、模型解釋

想像我們在一個深度神經網路的中間隨便砍一刀,砍完後前面有N層,後面有M層,這前後兩個大塊分別稱為A網路和B網路,那麼AB之間的通訊就包括從A到B的激活(activation)傳播,以及從B到A的梯度(gradient)傳播。於是我們可以把A和B網路分別視為一個函數fa和fb,其中A接受最底層的輸入x,輸出激活值h,而B接收這個激活值h外加最高層要預測的真實目標t做為輸入,輸出給A的梯度g:

(式子1) h = fa(x)
(式子2) gt = fb(h, t), gt表示gradient_true

關於上面這個抽象函數fb要說的一點是,fb實際上對應了三個階段:前向傳播、在loss層根據預測目標和真實目標計算梯度、後向傳播。每一個階段都可以用一個確定性的函數來描述,所以合起來抽象為fb。

好了,既然fb是一個確定性的函數,那就可以用一個確定性神經網路來近似擬合對不對?於是我們就把一個新的網路模塊M放在A和B中間,讓它去近似fb。這樣子用來傳給A的梯度就不是經過B網路前向傳播、loss層計算、後向傳播這樣一個冗長的過程計算出來的,而是直接由新的模塊M根據A的輸出h和真實目標t計算給出:

(式子3) gp = fm(h, t) ~= fb(h, t), gp表示gradient_predict

現在一個很重要問題來了:我們怎麼讓這個網路模塊M真的能夠近似擬合fb函數這樣一個計算過程?或者說怎麼模塊M本身怎麼訓練?從最直接的角度考慮,M要擬合的實際梯度我們是可以得到的,就是上面式子1中的gt,那麼就根據這個真實目標和預測出來的目標(式子3中的gp)按照square error loss去算M模塊本身的更新梯度就好了。

但是這樣又產生一個嚴重的問題:真實梯度gt的計算本身不也是要經過一個冗長的過程嗎?如果要用真實梯度來訓練模塊M的話,那豈不是最終又回到原來的樣子了?這可萬萬不行。那麼DeepMind是怎麼解決這個問題的呢?就是用後面一層的預測梯度回傳來代替真實梯度。具體來說,這一層的近似模塊M1給出的預測梯度是gp1,下一層的近似模塊M2給出的預測梯度是gp2,假裝gp2是從更後面的網路回傳過來的真實梯度gt2,回傳到這一層得到gp3,最後以gp3作為ground truth來更新M1。

為什麼這麼做能work?畢竟一開始的時候所有的M模塊都不可靠,用M2預測的梯度來幫助更新M1完全不合理啊?我們假設神經網路總共有5層,對應地有5個模塊M,從一開始的時候M1到M5都不可靠,但是隨著訓練的進行,最上層M5用的是真實的ground truth來更新的,它最先達到可靠的狀態,接著由於M5可靠了,那麼基於M5訓練的M4也會開始可靠,以此類推,直到所有的M模塊都變得可靠。這樣就從直觀上解釋了為什麼可以用後面一層的M模塊來幫忙更新這一層的M模塊。

不過,這樣的解釋是假設了整個神經網路的dynamics相對穩定,而實際上訓練的過程中網路主體參數都在更新,變來變去,那麼M模塊想要去近似的那個過程函數fb根本就不是固定的,所以它可能永遠都追不上真實的fb。但是在短時間內我們可以假設整個網路的dynamics沒有太大的變化,就像微積分裡面假設一個高次函數的鄰域近似是一次函數。所以在調參合理的情況下,M模塊可以不被真實的fb甩得太遠,不遠不近地跟著,這樣的M產生的梯度雖然無法跟真實的梯度完全一樣,但也夠用了。

其實這種做法早在強化學習裡面就有了,叫做Q-learning。Q-learning的做法是說,我可以近似預測這一步和下一步的reward值,同時也知道從下一步的reward值可以怎樣反推出這一步的reward值(這個反推本身是真實的而不是近似的),那麼我就把這個反推出來的、不那麼可靠的reward值當成真的,然後來更新給出這一步reward值得模塊參數。

那麼更進一步,M預測梯度只根據上一層的激活值h,不利用真實目標t,行不行?就是把式子3變成下面這樣:

(式子4) gp = fm(h) ~= fb(h, t), gp表示gradient_predict

可想而知,這樣近似擬合的難度更大了,從論文實驗部分來看,這種做法最終模型的性能也確實低於式子3的做法。所以我們知道了真實目標t還是得用上才行。

最後,論文又大膽地把前向傳播也給decoupled了,也就是用一個和M類似的小網路模塊N來近似前K層網路的前向計算過程,這樣後面的層就不用慢慢等前面的層一個個計算完才能接著往前算了。同樣地,這個近似模塊N本身也要訓練,用的ground truth就是上一層的N預測的激活值向前傳播一層之後的結果。一切都和上述的後向decoupled的原理一樣。

這一步邁得比前面所說的大得多。因為從計算複雜度的角度上講,單單decouple後向的話,只能帶來常數級別的加速,即從O(2N)到O(N),因為僅僅是後向的時候做到並行,前向的時候還是得串列一層層算過去;但是同時decouple前向和後向的話,理論上每個層之間都可以完全並行了,複雜度瞬間變成O(1)。當然如果考慮到實際情況中GPU的內存容量,可能沒法完全並行,加速的比例就沒這麼誇張了。

二、應用價值

接著說一下為什麼我覺得這個Decoupled Neural Interface在非RNN系模型上應用價值不大,而在RNN上很有潛力。首先像FCN、CNN這樣的架構實際用的時候也不會搭得太深,雖說residual network能夠搭到一兩百層,但是感覺實際應用中也就搭個幾十層吧,層數不深的時候DNI帶來的時間增益可能並不是那麼明顯。其次在FCN、CNN上用DNI的時候,每層中間的M模塊肯定不是參數共享的,因為每一層的梯度預測任務不具有同構性,更不用說它們的輸入輸出維度不一致了,所以每個M模塊是參數獨立的。而且論文裡面的實驗結果也顯示出用了DNI的最終模型性能是比用傳統BP要差的,只能說是comparable,當然這個性能損失也在意料之中,畢竟不是用真的梯度去更新嘛。

相比之下,首先是RNN系模型在時間維上夠「深」,非常需要DNI來幫它解耦不同層之間的依賴,恰恰是其用武之地;其次是RNN在時間維度上具有同構性,每個M模塊需要做的任務是等價的,所以它們完全可以共享參數,這樣一個step上的M更新了就是所有的M更新了,或者用另一個角度來說,M可以彙集各個time step上的信息來更新自己,得到的自身的梯度更穩定,訓練更快。

三、局限

首先一個明顯的局限就是用了DNI之後模型性能會下降一些,但是不大。

然後讀到這裡,我們也可以對「DNI能夠替代BP演算法」這樣一個說法有一個準確的判斷了。如果我們把BP認為是一個從loss層計算出梯度後向傳播來得到每一層梯度的過程(狹義),那麼DNI的確可以替代BP;但如果我們把相鄰兩層之間的後向梯度計算也叫BP的話(廣義),那麼DNI是離不了它的,因為M模塊的訓練就以來於相鄰兩層之間的梯度後向傳播。

由此也可以看出,DNI目前能夠應用的場景局限於本來就能夠end-to-end differential的架構,它並不能支撐那種中間某一個地方斷開、不可導的架構,因為它要求於相鄰兩層之間具有可導性。所謂的「預測梯度」,說到底不是憑空預測,而是根據回傳得到的梯度去預測的,只不過會稍微滯後於網路真實的dynamics,這裡面並沒有什麼玄學。


hi all,

I also implement a Tensorflow version of DNI.

GitHub - andrewliao11/DNI-tensorflow: DNI(Decoupled Neural Interfaces using Synthetic Gradients) implementation with Tensorflow

I think it"s a novel idea to decouple the relation btn layers, and it"s very cool, isn"t it?


哈哈哈哈這是我這學期project題目,把它放到分散式上和Asynchronous SGD做比較…悶聲作大死…這個奇葩模型感覺一個worker慢了整個模型就藥丸…敬請期待十二月的作死結果…


推薦閱讀:

語音識別對單個聲學模型需要做 viterbi beam嗎?
語言學研究對機器的語音識別或者語音輸出有實質性的幫助嗎?
在對時間序列進行分類時,隱馬爾科夫模型、人工神經網路和支持向量機這三種模型哪種更合適,為什麼?
如何編寫易被複用的,高質量的機器學習演算法代碼?有哪些這樣的代碼示例?請舉例:代碼,原文,你發表的文獻。

TAG:機器學習 | 認知科學 | 神經網路 | 深度學習DeepLearning |