【機器學習】Cross-Validation(交叉驗證)詳解

本文章部分內容基於之前的一篇專欄文章:統計學習引論

在機器學習里,通常來說我們不能將全部用於數據訓練模型,否則我們將沒有數據集對該模型進行驗證,從而評估我們的模型的預測效果。為了解決這一問題,有如下常用的方法:

1.The Validation Set Approach

第一種是最簡單的,也是很容易就想到的。我們可以把整個數據集分成兩部分,一部分用於訓練,一部分用於驗證,這也就是我們經常提到的訓練集(training set)和測試集(test set)。

例如,如上圖所示,我們可以將藍色部分的數據作為訓練集(包含7、22、13等數據),將右側的數據作為測試集(包含91等),這樣通過在藍色的訓練集上訓練模型,在測試集上觀察不同模型不同參數對應的MSE的大小,就可以合適選擇模型和參數了。

不過,這個簡單的方法存在兩個弊端。

1.最終模型與參數的選取將極大程度依賴於你對訓練集和測試集的劃分方法。什麼意思呢?我們再看一張圖:

右邊是十種不同的訓練集和測試集劃分方法得到的test MSE,可以看到,在不同的劃分方法下,test MSE的變動是很大的,而且對應的最優degree也不一樣。所以如果我們的訓練集和測試集的劃分方法不夠好,很有可能無法選擇到最好的模型與參數。

2.該方法只用了部分數據進行模型的訓練

我們都知道,當用於模型訓練的數據量越大時,訓練出來的模型通常效果會越好。所以訓練集和測試集的劃分意味著我們無法充分利用我們手頭已有的數據,所以得到的模型效果也會受到一定的影響。

基於這樣的背景,有人就提出了Cross-Validation方法,也就是交叉驗證。

2.Cross-Validation

2.1 LOOCV

首先,我們先介紹LOOCV方法,即(Leave-one-out cross-validation)。像Test set approach一樣,LOOCV方法也包含將數據集分為訓練集和測試集這一步驟。但是不同的是,我們現在只用一個數據作為測試集,其他的數據都作為訓練集,並將此步驟重複N次(N為數據集的數據數量)。

如上圖所示,假設我們現在有n個數據組成的數據集,那麼LOOCV的方法就是每次取出一個數據作為測試集的唯一元素,而其他n-1個數據都作為訓練集用於訓練模型和調參。結果就是我們最終訓練了n個模型,每次都能得到一個MSE。而計算最終test MSE則就是將這n個MSE取平均。

y_i比起test set approach,LOOCV有很多優點。首先它不受測試集合訓練集劃分方法的影響,因為每一個數據都單獨的做過測試集。同時,其用了n-1個數據訓練模型,也幾乎用到了所有的數據,保證了模型的bias更小。不過LOOCV的缺點也很明顯,那就是計算量過於大,是test set approach耗時的n-1倍。

為了解決計算成本太大的弊端,又有人提供了下面的式子,使得LOOCV計算成本和只訓練一個模型一樣快。

其中hat{y_i}表示第i個擬合值,而h_i則表示leverage。關於h_i的計算方法詳見線性回歸的部分(以後會涉及)。

2.2 K-fold Cross Validation

另外一種折中的辦法叫做K折交叉驗證,和LOOCV的不同在於,我們每次的測試集將不再只包含一個數據,而是多個,具體數目將根據K的選取決定。比如,如果K=5,那麼我們利用五折交叉驗證的步驟就是:

1.將所有數據集分成5份

2.不重複地每次取其中一份做測試集,用其他四份做訓練集訓練模型,之後計算該模型在測試集上的MSE_i

3.將5次的MSE_i取平均得到最後的MSE

不難理解,其實LOOCV是一種特殊的K-fold Cross Validation(K=N)。再來看一組圖:

每一幅圖種藍色表示的真實的test MSE,而黑色虛線和橙線則分貝表示的是LOOCV方法和10-fold CV方法得到的test MSE。我們可以看到事實上LOOCV和10-fold CV對test MSE的估計是很相似的,但是相比LOOCV,10-fold CV的計算成本卻小了很多,耗時更少。

2.3 Bias-Variance Trade-Off for k-Fold Cross-Validation

最後,我們要說說K的選取。事實上,和開頭給出的文章里的部分內容一樣,K的選取是一個Bias和Variance的trade-off。

K越大,每次投入的訓練集的數據越多,模型的Bias越小。但是K越大,又意味著每一次選取的訓練集之前的相關性越大(考慮最極端的例子,當k=N,也就是在LOOCV里,每次都訓練數據幾乎是一樣的)。而這種大相關性會導致最終的test error具有更大的Variance。

一般來說,根據經驗我們一般選擇k=5或10。

2.4 Cross-Validation on Classification Problems

上面我們講的都是回歸問題,所以用MSE來衡量test error。如果是分類問題,那麼我們可以用以下式子來衡量Cross-Validation的test error:

其中Erri表示的是第i個模型在第i組測試集上的分類錯誤的個數。

圖片來源:《An Introduction to Statistical Learning with Applications in R》

說在後面

關於機器學習的內容還未結束,請持續關注該專欄的後續文章。

更多內容請關注我的專欄:R Language and Data Mining

或者關注我的知乎賬號:溫如

推薦閱讀:

如何流程挖掘/模擬?
Deep Matrix factorization Models for Recommender System
2017摩拜杯演算法挑戰賽 第三名團隊解決方案
誰來用最通俗易懂的語言跟我講一下k平均演算法(k means clustering)??
R實戰案例:利用演算法識別糖尿病患者(R語言實現)

TAG:机器学习 | 数据挖掘 | 数据分析 |