M.2.2 多維函數、導數與梯度迭代演算法
1. 從一維函數到多維函數性質
本節中準備講將多維函數的問題,很多人先入為主的感覺一維函數是最好理解的,這是因為我們最初遇到的函數都是一維函數,其有個簡單的性質就是有一個輸出就對應一個輸出:
這裡對於一維函數而言有一個展開:
在前面數學章節中我們已經介紹了多維函數,但是介紹的內容有一些複雜,計算機科學作為一門應用科學可能完全用不到太多複雜的概念,因此這裡多維函數我們只需用到一些導數相關的概念。
對於多維函數而言,我們將其也記為一維函數的形式,但是注意這裡的自變數是多維的 ,這裡為了方便將其寫成列向量的形式。
重新寫以下函數:
這並沒有什麼不同,對於機器學習而言就在於我們如何構建這個函數 ,但是函數展開過程則有了區別:
這裡可以進行簡化書寫:
這裡的H就是Hessian矩陣。這裡可以看的函數梯度對應於一維函數的一階導數,H對應於一維函數的二階導數。有些書為了更方便將函數梯度寫為g(x):
2. 優化演算法轉換
如果我們將函數僅僅展開成一階導數:
可以發現,當 ,如果每次都去dx=-g(x)的話,這樣函數是一步一步減少的,這也就是我們所說的最速下降法:
這裡取了一個 是因為梯度長度非坐標長度,過長會面臨迭代失穩。
#最速下降法示意nimport numpy as npnimport matplotlib.pyplot as pltndef func3(x, y):n return 4 + x**2 - 2 * y + 2*y**2 - x*2 - x*yndef dfunc3(x, y):n return 2*x-y-2, 4*y-x-2nx, y = 2.5, 4nx1 = np.linspace(0,4,20)nx2 = np.linspace(0,4,20)nx1,x2=np.meshgrid(x1,x2)nu, v = dfunc3(x1, x2)nplt.quiver(x1,x2,-u,-v)nfor itr in range(200):n gx, gy = dfunc3(x, y)n xo,yo=x,yn x+=-0.1*gxn y+=-0.1*gyn plt.plot([xo,x],[yo,y])n plt.scatter([xo],[yo])n if(itr%20==0):n print("%.f %.5f %.5f"%(x, y, func3(x, y)))nplt.show()n
對於牛頓法而言,其優化目標變為函數增量為0:
由此可以對兩邊dx進行求偏導所以:
由此迭代顯而易見,下面粘貼個代碼:
#牛頓法nimport numpy as npnimport matplotlib.pyplot as pltndef func3(x, y):n return 4 + x**2 - 2 * y + 2*y**2 - x*2 - x*yndef dfunc3(x, y):n return 2*x-y-2, 4*y-x-2ndef Hessian(x, y):n return np.array([[4, -1],[-1, 2]])nx, y = 2.5, 4nx1 = np.linspace(0,4,40)nx2 = np.linspace(0,4,40)nx1,x2=np.meshgrid(x1,x2)nu, v = dfunc3(x1, x2)nplt.quiver(x1,x2,-u,-v)nfor itr in range(200):n gx, gy = dfunc3(x, y)n H = Hessian(x, y)n iH=np.linalg.inv(H)n v=np.array([[gx],[gy]])n hg=np.dot(iH, v)n gx, gy = hg[0,0], hg[1, 0]n xo,yo=x,yn x+=-0.1*gxn y+=-0.1*gyn plt.plot([xo,x],[yo,y])n plt.scatter([xo],[yo])n if(itr%20==0):n print("%.f %.5f %.5f"%(x, y, func3(x, y)))nplt.show()n
牛頓法雖然收斂速度快,但是需要計算函數的二階導數,計算複雜度很大,而且Hessian矩陣無法保持正定,因此提出了**擬牛頓法**,這個方法的基本思想是:不用二階偏導構造近似的H正定矩陣。
以下內容為轉載,並不是重要內容,這裡列出是為了幫助熟悉多維函數求導過程。
擬牛頓條件
對於我們的多維函數展開,兩邊同時除以dx:
這就是擬牛頓條件,H與H逆可以近似的用向量表示,通常近似用B與D表示
所以構造迭代:
帶入:
對應位相等:
轉換一下:
寫成程序:
#擬牛頓法nimport numpy as npnimport matplotlib.pyplot as pltndef func3(x, y):n return 4 + x**2 - 2 * y + 2*y**2 - x*2 - x*yndef dfunc3(x, y):n return np.array([[2*x-y-2],[4*y-x-2]])nnD=np.eye(2)nx1 = np.linspace(0,4,40)nx2 = np.linspace(0,4,40)nx1,x2=np.meshgrid(x1,x2)ng=dfunc3(x1, x2)nplt.quiver(x1,x2,-g[0,0],-g[1,0])nx, y = 2.5, 4ng=dfunc3(x, y)nfor itr in range(200):n g = dfunc3(x, y)n d = -np.dot(D, g)n s = 0.1*dn gx, gy = s[0,0], s[1, 0]n xo, yo=x, yn x+=float(gx)n y+=float(gy)n g2=dfunc3(x, y)n ys=g2-gn div=1/np.dot(ys.T, s)n mt1=np.eye(2)-np.dot(s,ys.T)*divn mt2=np.eye(2)-np.dot(ys,s.T)*divn mt3=np.dot(s,s.T)*divn D=np.dot(mt1,np.dot(D,mt2))+mt3n plt.plot([xo,x],[yo,y])n plt.scatter([xo],[yo])n if(itr%20==0):n print("%.f %.5f %.5f"%(x, y, func3(x, y)))nplt.show()n
推薦閱讀:
※李宏毅機器學習2016 第十七講 遷移學習
※數學 · CNN · 從 NN 到 CNN
※深度學習入門系列,用白話文的方式讓你看得懂學的快(第七章、第八章)