如何理解 natural gradient descent?

natural gradient descent 就是在把gradient乘上一個fisher information matrix的逆,感覺有點像牛頓法。但intuition上要用到distribution的黎曼結構,不是很能理解。https://arxiv.org/pdf/1412.1193.pdf 這篇paper介紹了natural gradient 但沒有看得特別明白。有誰能簡單地介紹下?


將分布之間的Fisher information matrix (FIM)看成是統計流形上的黎曼度量,然後用流形上的最速下降方向作為搜索方向,就是自然梯度。這是一種概念上比較簡潔漂亮的處理方式,但顯然不是一種最容易理解的方式,很多人第一次接觸到的時候都是有些懵的。

思路上更簡單直接的方式可能是從約束優化來理解。考慮一個基本的函數 f,考慮概率分布 p(	heta) 上的優化問題

min E[f] = int f(x)p(x|	heta)dx

如果我們想找參數化的分布 p(	heta+delta 	heta) ,使得E(f)的改進程度最大,最直觀的方法自然是直接對E(f)做一步梯度下降。但是由於 p(	heta)p(	heta+delta 	heta) 是概率分布,他們之間的距離不是用參數之間的歐式距離來定義的(簡單來說,沿梯度下降一步之後的 	heta 可能不滿足分布參數的要求,比如正態分布的協方差矩陣變得不正定了),而是用分布之間的KL-divergence來定義的 I(p(	heta + delta 	heta)|p(	heta)) = int ln left(frac{ p(x|	heta + delta 	heta)} {p(x|	heta)} 
ight) p(x|	heta +delta 	heta) d x =: KL(p( 	heta + delta 	heta)||p( 	heta)).

由於這個KL-div 不對稱,它不滿足距離的定義。同時由於 delta 	heta 比較小,我們可以對此式展開做二階近似

I(p(	heta + delta 	heta)|p(	heta)) =frac{1}{2}sum_{i,j} I_{ij}(	heta)delta	heta_idelta	heta_j + O(|delta	heta|^3),

其中 I_{ij} 就是Fisher information matrix 的分量,換句話說,FIM就是KL-div的二階近似

I_{ij} = int frac{partial ln p(x|	heta)}{partial 	heta_i}frac{partial ln p(x|	heta)}{partial 	heta_j}p(x|	heta) dx

回到原來的優化問題,我們面對的問題變成了

E[f(x)|	heta+delta	heta] = E[f(x)|	heta]+sum_i frac{partial E[f|	heta]}{partial 	heta_i}delta 	heta_i + cdots\ s.t. ~~KL(p(	heta+delta	heta)||p(	heta)) =epsilon

將上面的KL-div的二階近似帶入,構造Lagrange 函數,就有

L(delta	heta, lambda) = sum_i frac{partial E[f|	heta]}{partial 	heta_i} delta 	heta_i + lambdaleft( epsilon - frac{1}{2}sum_{i,j} I_{ij}(	heta)delta	heta_i	heta_j 
ight).

此式可以寫成矩陣形式 L(delta	heta, lambda) = 
abla_{	heta}^T E[f|	heta] delta 	heta + lambdaleft( epsilon - frac{1}{2} mathbf{ delta 	heta}^T I ~ delta 	heta 
ight). 對此式稍作推導,就得到最速下降方向

hat{delta	heta} =- alpha I^{-1}(	heta)
abla_{	heta}E[f|	heta]

這裡的 alpha 是一個 epsilon 	o 0 的無窮小量。這個方向就是所謂的自然梯度方向。

可以看到,這裡的推導沒有用到任何微分幾何和黎曼度量的概念,唯一用到的就是概率分布之間的KL-div 和它的二階近似,然後套用約束優化的拉格朗日乘子,也就無所謂「自然」了。當然,這裡的推導會比黎曼度量-自然梯度 更加技術化一些,技術化的東西相對來說不容易推廣。

自然梯度和牛頓法是有關聯的,在某些特殊情況下可以認為是Gauss-Newton法的近似。


擴展一點:

後來 John Schulman 等人的工作,https://arxiv.org/pdf/1502.05477.pdf 中的TRPO 演算法就採取了一種近似 Natural Gradient Descrent 的方法,參見 Appendix C。


本科的時候稍微做過一下下相關的東西,所以來略微了解一些些,不過現在現在轉行專註做RL了,還是受益很多來自bayesian那一套。

理解比較粗,大概就是把歐式度量換成了黎曼度量。也就是進行了warp。好處是和坐標系無關。是一個二階方法。難點在於計算fisher的逆。但是很多model本身有很好的structure,導致可以很高效地近似,比如neural network。

具體可以參考kfac的paper以及一篇natural gradient的review。可以搜一下James martens,看看同期文章。

natural gradient在RL裡面比較有用,特別是在exploration上。可以參考VIME,裡面推出來的公式就是natural gradient。

以及david blei 的SVI裡面也有講到,裡面是假設了所有的prior是exp family的分布,然後推著推著發現用coordinate ascent (註:已修正)的方式更新parameter是自帶那個natural gradient的…

natural gradient的大佬有一位叫Amari的日本人,在微分幾何上有很高的造詣… 最近好像有本新書叫information geometry?? 不太了解,之前實驗室的師兄推薦過,買了一本一直沒時間看。


從kl-divergence到fisher information matrix之間的推導,高票答案已經說明的很好了。

我稍微補充下開頭和結尾,擴充下這個topic的big picture。具體到細節我也不是很理解,大家可以通過給出的鏈接及鏈接里索引的paper再一步步擴充。

如果解釋有錯誤歡迎指正。


開頭:KL的引入

假設我們通過sgd來更新參數,每次更新都是通過找出原始網路的loss對於每個參數的gradient再乘以一個步長(learning rate)以求新的網路的loss更低。但這種方式並不對預測結果有任何保證,新的網路可能會在參數上與原網路很接近,但預測結果上大不相同。

而我們可以引入新/舊網路的預測相似度作為一種制約,找出同等相似情況下loss最低的方向進行更新。兩個分布之間相似度的計算可以通過kl divergence來度量,這也正是KL的引入。更多的細節請參考

A intuitive explanation of natural gradient descent?

kvfrans.com

關於natural gradient和傳統gradient的關係可以參考原作者的另一篇文章

What is the natural gradient, and how does it work??

kvfrans.com

結尾:natural gradient的優點

假設我們找到了natural gradient的真實方向(fisher information matrix只是其中一種二階逼近的方法?)我們可以保證每次更新的穩定性,這對於一系列robotics/control的任務極為重要,對於一些基於pre-trained model進行fine-tune或是擴展的任務也是有極大幫助的,至少可以「解決」例如catastrophic forgetting等問題。


推薦閱讀:

BAT機器學習面試1000題系列(136-140題)
[讀論文]Big Batch SGD: Automated Inference using Adaptive Batch Sizes
為什麼梯度的負方向是局部下降最快的方向?
深入淺出--梯度下降法及其實現
瞎談CNN:通過優化求解輸入圖像

TAG:微分幾何 | 梯度下降 | 深度學習DeepLearning |