標籤:

過擬合與正則化

前面小節都是在以較為簡單的一次函數在擬合數據,但是現實情況不一定都是這麼簡單了,這一節將以多項式回歸為引例,來討論一下訓練過程中經常出現的過擬合現象,然後重點闡述一下正則化方法在解決過擬合問題中的應用。

首先創建數據,並用sklearn的多項式回歸處理,得到擬合曲線。

多項式回歸得到的預測函數為  2.04598368+1.06157405x+0.52050879x^2 ,這與我們實際設置的函數  2+x+0.5x^2 相差不大,擬合效果較好。但是實際情況中,我們並不知道數據的分布情況,現在我如果用30階的函數去擬合,會是什麼效果呢?

可以看出,模型在盡量多的覆蓋數據點,以至於有一些「過」了,當數據里有一些雜訊時,模型也會去擬合雜訊點,這樣顯然就不會得到精準的模型,這也就是出現了「過擬合」現象。過擬合的模型在訓練集中會表現的很好,但是在測試集中會表現的一塌糊塗。所以,採用交叉測試的方法可以檢測模型是否過擬合。另外,繪製學習曲線(Learning curves)也可以檢測過擬合。

那麼,在學習過程中,什麼情況會造成過擬合呢?一般來說,可以分為以下四個情況:

① 數據量太少。

② 隨機雜訊高

③ 確定性雜訊高

④ 過量的VC維

關於這些原因的具體討論可以去看看台灣大學林軒田老師的機器學習基石的教程。

在機器學習中,無論是分類還是回歸,都可能存在由於特徵過多而導致的過擬合問題。當然解決的辦法有

(1)減少特徵,留取最重要的特徵。

(2)懲罰不重要的特徵的權重。

為了減小過擬合發生的情況,除了通過清洗數據來減小雜訊以外,另一種常用的方法就是正則化(Regularization)。通常情況下,我們不知道應該懲罰哪些特徵的權重取值。通過正則化方法可以防止過擬合,提高泛化能力。

具體處理方法就是在損失函數後面加上一個正則項,這裡以L2正則項為例。

 Jleft( 	heta 
ight) = frac{1}{2m}left( sum_{i=1}^m{left( h_{	heta}left( x^{left( i 
ight)} 
ight) -y^{left( i 
ight)} 
ight) ^2}+lambda sum_{j=1}^m{	heta _{j}^{2}} 
ight)

注意是從1開始的。對其求偏導後得到

frac{partial Jleft( 	heta 
ight)}{partial 	heta _j}=frac{1}{m}sum_{i=1}^m{x_jleft( h_{	heta}left( x 
ight) -y 
ight)}+frac{lambda}{m}	heta _j

梯度下降表達式如下

其中,  left( 1-frac{alpha lambda}{m} 
ight) <1 ,也就是說,  	heta _j 的權值得到了衰減,那麼為什麼權值衰減就能防止過擬合呢 ?

首先,我們要知道一個法則-奧卡姆剃刀,用更少的東西做更多事。從某種意義上說,更小的權值就意味著模型的複雜度更低,對數據的擬合更好。貼一張網上看到的圖,解釋的比較好。

試想一下,在上圖中,如果不加正則化項,那麼最優參數對應的等高線離中心點的距離可能會更近,加入正

則化項後使得訓練出的參數對應的等高線離中心點的距離不會太近,也不會太遠。從而避免了過擬合。

使用L2正則化的回歸模型我們叫做Ridge Regression,而使用L1正則項的,叫做Lasso Regression,兩個都用的就叫做Elastic Net,這三種正則化回歸的損失函數表示為:

Ridge Regression:

 Jleft( 	heta 
ight) =MSEleft( 	heta 
ight) +frac{alpha}{2}sum_{i=1}^n{	heta _{i}^{2}}

Lasso Regression:

Jleft( 	heta 
ight) =MSEleft( 	heta 
ight) +alpha sum_{i=1}^n{left| 	heta _i 
ight|}

Elastic Net:

 Jleft( 	heta 
ight) =MSEleft( 	heta 
ight) +ralpha sum_{i=1}^n{left| 	heta _i 
ight|}+frac{1-r}{2}alpha sum_{i=1}^n{	heta _{i}^{2}}

如圖,為Ridge Regression在不同參數下的擬合情況:

參考

[1]Hands On Machine Learning with Scikit Learn and TensorFlow

[2]台灣大學林軒田——機器學習基石

[3]L2正則化方法

推薦閱讀:

馬庫斯:DeepMind新出的機器心智網路不錯,但有誤導性
《機器學習基石》課程學習總結(一)
2018AI學習清單丨150個最好的機器學習和Python教程
CS259D:數據挖掘與網路安全講義筆記

TAG:機器學習 |