如何理解 natural gradient descent?
natural gradient descent 就是在把gradient乘上一個fisher information matrix的逆,感覺有點像牛頓法。但intuition上要用到distribution的黎曼結構,不是很能理解。https://arxiv.org/pdf/1412.1193.pdf 這篇paper介紹了natural gradient 但沒有看得特別明白。有誰能簡單地介紹下?
將分布之間的Fisher information matrix (FIM)看成是統計流形上的黎曼度量,然後用流形上的最速下降方向作為搜索方向,就是自然梯度。這是一種概念上比較簡潔漂亮的處理方式,但顯然不是一種最容易理解的方式,很多人第一次接觸到的時候都是有些懵的。
思路上更簡單直接的方式可能是從約束優化來理解。考慮一個基本的函數 ,考慮概率分布 上的優化問題
如果我們想找參數化的分布 ,使得E(f)的改進程度最大,最直觀的方法自然是直接對E(f)做一步梯度下降。但是由於 和 是概率分布,他們之間的距離不是用參數之間的歐式距離來定義的(簡單來說,沿梯度下降一步之後的 可能不滿足分布參數的要求,比如正態分布的協方差矩陣變得不正定了),而是用分布之間的KL-divergence來定義的
由於這個KL-div 不對稱,它不滿足距離的定義。同時由於 比較小,我們可以對此式展開做二階近似
其中 就是Fisher information matrix 的分量,換句話說,FIM就是KL-div的二階近似
回到原來的優化問題,我們面對的問題變成了
將上面的KL-div的二階近似帶入,構造Lagrange 函數,就有
此式可以寫成矩陣形式 對此式稍作推導,就得到最速下降方向
這裡的 是一個 的無窮小量。這個方向就是所謂的自然梯度方向。
可以看到,這裡的推導沒有用到任何微分幾何和黎曼度量的概念,唯一用到的就是概率分布之間的KL-div 和它的二階近似,然後套用約束優化的拉格朗日乘子,也就無所謂「自然」了。當然,這裡的推導會比黎曼度量-自然梯度 更加技術化一些,技術化的東西相對來說不容易推廣。
自然梯度和牛頓法是有關聯的,在某些特殊情況下可以認為是Gauss-Newton法的近似。
擴展一點:
後來 John Schulman 等人的工作,https://arxiv.org/pdf/1502.05477.pdf 中的TRPO 演算法就採取了一種近似 Natural Gradient Descrent 的方法,參見 Appendix C。
本科的時候稍微做過一下下相關的東西,所以來略微了解一些些,不過現在現在轉行專註做RL了,還是受益很多來自bayesian那一套。
理解比較粗,大概就是把歐式度量換成了黎曼度量。也就是進行了warp。好處是和坐標系無關。是一個二階方法。難點在於計算fisher的逆。但是很多model本身有很好的structure,導致可以很高效地近似,比如neural network。
具體可以參考kfac的paper以及一篇natural gradient的review。可以搜一下James martens,看看同期文章。
natural gradient在RL裡面比較有用,特別是在exploration上。可以參考VIME,裡面推出來的公式就是natural gradient。
以及david blei 的SVI裡面也有講到,裡面是假設了所有的prior是exp family的分布,然後推著推著發現用coordinate ascent (註:已修正)的方式更新parameter是自帶那個natural gradient的…
natural gradient的大佬有一位叫Amari的日本人,在微分幾何上有很高的造詣… 最近好像有本新書叫information geometry?? 不太了解,之前實驗室的師兄推薦過,買了一本一直沒時間看。
從kl-divergence到fisher information matrix之間的推導,高票答案已經說明的很好了。
我稍微補充下開頭和結尾,擴充下這個topic的big picture。具體到細節我也不是很理解,大家可以通過給出的鏈接及鏈接里索引的paper再一步步擴充。
如果解釋有錯誤歡迎指正。
開頭:KL的引入
假設我們通過sgd來更新參數,每次更新都是通過找出原始網路的loss對於每個參數的gradient再乘以一個步長(learning rate)以求新的網路的loss更低。但這種方式並不對預測結果有任何保證,新的網路可能會在參數上與原網路很接近,但預測結果上大不相同。
而我們可以引入新/舊網路的預測相似度作為一種制約,找出同等相似情況下loss最低的方向進行更新。兩個分布之間相似度的計算可以通過kl divergence來度量,這也正是KL的引入。更多的細節請參考
A intuitive explanation of natural gradient descent?kvfrans.com關於natural gradient和傳統gradient的關係可以參考原作者的另一篇文章
What is the natural gradient, and how does it work??kvfrans.com結尾:natural gradient的優點
假設我們找到了natural gradient的真實方向(fisher information matrix只是其中一種二階逼近的方法?)我們可以保證每次更新的穩定性,這對於一系列robotics/control的任務極為重要,對於一些基於pre-trained model進行fine-tune或是擴展的任務也是有極大幫助的,至少可以「解決」例如catastrophic forgetting等問題。
推薦閱讀:
※BAT機器學習面試1000題系列(136-140題)
※[讀論文]Big Batch SGD: Automated Inference using Adaptive Batch Sizes
※為什麼梯度的負方向是局部下降最快的方向?
※深入淺出--梯度下降法及其實現
※瞎談CNN:通過優化求解輸入圖像
TAG:微分幾何 | 梯度下降 | 深度學習DeepLearning |