人工智慧理論 | 逃離鞍點

多年以後,面對這篇文章Lemma 16的證明一臉懵逼的時候,我將會想起Alexander Coward教我saddle point的那個遙遠的下午。

和大部分人一樣,我在微積分課里接觸到「鞍點」這個概念的時候,並不知道關於它人們還有那麼多沒有搞懂的地方。鞍點在多變數微積分里是一個很簡單的概念。我們都知道對於一個可微分的函數,其導數為0的地方叫做critical point。對於convex/concave的函數來說,critical point要麼是最小值要麼是最大值,然而對於一個普通的函數來說,導數等於0還有可能意味著第三種情況,那就是鞍點。最直觀的講解這三種情況的方法當然是看圖。我的三維畫圖能力一直都是渣渣,連個立方體都畫不好,不過好在這樣的圖片到處都是,所以請看下圖,從左到右分別是局部最小值,局部最大值和鞍點。

當然這是兩個變數的函數的情況,再往上就沒法畫出來只能自行腦補了。從圖裡也能看出來,簡單來說,局部最小值就是說在這個地方,你沿著每一個方向移動一點都會使得函數的值變大,而在鞍點處,沿著某些方向走能使函數值上升,另外一些方向則能使函數值下降(或不變)。然後大家肯定也學過如何判定一個critical point是屬於哪種情況的辦法,概括起來就是看Hessian矩陣的eigenvalue:如果所有的eigenvalue都是正的,那就是局部最小值;所有的eigenvalue都是負的,那就是局部最大值;eigenvalue有正有負(或者有0),那就是鞍點。看到這裡你應該能領悟到,對於一個普通函數來說,鞍點遠遠比想像中的要多,因為比如有一個100個變數的函數,假設他Hessian的每一個eigenvalue分別有50%的概率是正或者負(當然這個假設並不是很合理),那麼是不是意味著一個critical point極大概率是一個鞍點,而不是優化問題里我們想要的最大或最小值?那假如一個函數的眾多critical points中,絕大部分都是鞍點,那麼這有可能會給優化帶來很大的問題。

首先我們要明白為什麼我們需要研究如何逃離鞍點。在現實中,非凸優化的目標是收斂到一個局部最小值。之所以不追求全局最小值,一方面是因為找全局最小值顯然是一個NP-Hard問題,另一方面是因為很多研究表明,對於一些常見的問題,比如矩陣補全,張量分解,甚至最受關注的神經網路,找到一個局部最小值和找到一個全局最小值的效果是差不多的。可以理解為在每一個局部最小值處,函數的值都差不多(對於有些問題甚至可以證明每一個局部最小值都是全局最小值)。然而鞍點就完全不是這麼回事了。簡單的一維例子,y=x^3這個函數,x=0是一個critical point,然而你從0處往左走,這個函數就趨向於負無窮。所以一個演算法如果收斂於鞍點,那是絕對不可以接受的。[1]這篇文章就指出了鞍點可能是訓練神經網路的一個主要障礙。

那麼問題來了,我們現在訓練神經網路,一般用的是gradient descent演算法或者其變種,然而gradient descent在普通函數上似乎並沒有對收斂到哪裡有任何的保證,因為對gradient descent的理論分析是基於gradient的norm收斂到0的,而gradient的norm為0隻能保證這是一個critical point。那我們有什麼辦法能逃離鞍點呢?在討論這個問題之前,我們先對「逃離鞍點」做一個正式的定義。嚴格來說,逃離鞍點意味著優化演算法將收斂到一個ε-second-order-critical-point,其定義如下:

是說,gradient的norm很小,同時Hessian的最小的一個eigenvalue要大於一個很小的負數。這和數學上對局部最小值的定義略微有些不用(要求gradient的norm是0,Hessian最小的eigenvalue大於0),當然這是因為證明收斂的時候我們一般都只能獲取一個關於ε的結果,很難說他剛好能收斂到某個數。因為ε可以任意小,現實中可以把這個就理解成局部最小值。

過去的一些年裡,對於一些特定的非凸優化問題,人們通過一些辦法,比如特殊的初始點,獲得了可以收斂到最小值的結果,但是對於general的問題並不通用。如何通用的解決這個問題呢?一個很簡單思路就是利用Hessian的信息。我們之前說過,分類一個critical point的標準就是看它Hessian的eigenvalues,那我們計算Hessian,識別出鞍點,同時在鞍點處沿著可以下降的方向前進不就行了?在傳統上,人們認為這是可以逃離鞍點的唯一辦法,畢竟不算Hessian,你連是不是鞍點都不知道,何談逃離。這方面的結果最好的是我們熟悉的Yurii Nestrov提出的一個演算法[2],大致是牛頓法加上cubic regularization,思路其實非常簡單。然而基於牛頓法的演算法需要在每一步求出Hessian的inverse,這個計算量無疑是很可怕的。一個很自然的延伸就是用Hessian-vector product,正如從牛頓法進化到Quasi-Newton Method。這樣每一步的計算量下降了,但是需要的步數卻帶上了一個關於問題的維度d的項(這一點一會還要細說)

更嚴重的是,現在的問題維度過高,比如神經網路通常都是上百萬個變數,別說Hessian了,就連Hessian-vector product都不現實,唯一在計算時間上現實的就是以SGD為基礎的演算法。順便一提,關於stochastic Quasi-Newtown的研究並不是沒有,然而收斂性質並沒有明顯的好於SGD。也就是說,我們必須要只依賴gradient的信息來逃離鞍點

可能你和我一樣,看到這裡的時候是不相信的。畢竟只憑gradient,連是不是鞍點都不知道,怎麼逃離?然而大自然就是這麼的神奇。現在問題的背景以及介紹完畢,馬上進入真正的重點部分。

首先說一個令人驚訝的結果,這個結果就是喬神Michael Jordan和Benjamin Recht兩位大佬合作的一篇論文[3]的標題:Gradient Descent Converges to Minimizers! 也就是說gradient descent自己就會收斂到最小值,並不會被鞍點卡住。這個結果可是夠令人震驚的。當然這篇文章也略微有些標題黨,他想說的是,在一些合理的假設下,gradient descent不被鞍點卡住的概率為1。學過測度論的同學都知道,概率為1和「這件事一定會發生」的微妙區別。當然這個結果也足夠強了。具體來說,他用拓撲裡面的Stable Manifold theorem證明了所有能讓gradient descent收斂到鞍點的初始值的測度為0,也就是說如果我們隨機選擇初始值,那麼幾乎可以肯定我們可以收斂到最小值。一個簡單的例子(出自Nestrov著名的lecture notes),考慮以下這個函數

這個函數有三個critical points,(0, 0),(0, -1) 以及 (0, 1),第一個是鞍點,後兩個是最小值。簡單的計算可以得出,如果初始點在剛好在x軸上,那麼gradient descent會收斂於鞍點,

除此之外都會收斂到兩個最小值之一。我們知道x軸在R^2里是零測度的,所以隨機初始值幾乎可以保證收斂到最小值。

看到這裡,你不由得問喬神,這個結論這麼牛逼,那我們能多快的收斂到最小值呢?那喬神就要告訴你四個字:

因為這個證明方法並沒有給出任何關於iteration次數的上限。你可以把這理解成讓gradient descent永遠跑下去,總有一天會收斂到最小值的。

這不是坑爹嗎,說了和沒說一樣。然而對於基礎的gradient descent的分析似乎也就只能到這一步為止了。那麼有沒有辦法能給這個時間加上一個bound呢?篇幅所限,我們下篇文章繼續說。

作者簡介:

Xavier,本科就讀於加州大學伯克利分校,應用數學和統計學雙專業。目前於芝加哥大學就讀計算數學專業博士。本科階段的學習和研究主要集中於偏微分方程的數值解以及醫學圖像分析。博士階段主要興趣是非凸優化及其在機器學習中的應用,以及計算機視覺的相關問題。

編輯:蜜汁醬,Echo


推薦閱讀:

TAG:人工智慧演算法 | 人工智慧 | 科技 |