網格搜索演算法與K折交叉驗證

網格搜索演算法和K折交叉驗證法是機器學習入門的時候遇到的重要的概念。

網格搜索演算法是一種通過遍歷給定的參數組合來優化模型表現的方法。

以決策樹為例,當我們確定了要使用決策樹演算法的時候,為了能夠更好地擬合和預測,我們需要調整它的參數。在決策樹演算法中,我們通常選擇的參數是決策樹的最大深度

於是我們會給出一系列的最大深度的值,比如 {max_depth: [1,2,3,4,5]},我們會儘可能包含最優最大深度。

不過,我們如何知道哪一個最大深度的模型是最好的呢?我們需要一種可靠的評分方法,對每個最大深度的決策樹模型都進行評分,這其中非常經典的一種方法就是交叉驗證,下面我們就以K折交叉驗證為例,詳細介紹它的演算法過程。

首先我們先看一下數據集是如何分割的。我們拿到的原始數據集首先會按照一定的比例劃分成訓練集和測試集。比如下圖,以8:2分割的數據集:

訓練集用來訓練我們的模型,它的作用就像我們平時做的練習題;測試集用來評估我們訓練好的模型表現如何,它的作用像我們做的高考題,這是要絕對保密不能提前被模型看到的。

因此,在K折交叉驗證中,我們用到的數據是訓練集中的所有數據。我們將訓練集的所有數據平均劃分成K份(通常選擇K=10),取第K份作為驗證集,它的作用就像我們用來估計高考分數的模擬題,餘下的K-1份作為交叉驗證的訓練集。

對於我們最開始選擇的決策樹的5個最大深度 ,以 max_depth=1 為例,我們先用第2-10份數據作為訓練集訓練模型,用第1份數據作為驗證集對這次訓練的模型進行評分,得到第一個分數;然後重新構建一個 max_depth=1 的決策樹,用第1和3-10份數據作為訓練集訓練模型,用第2份數據作為驗證集對這次訓練的模型進行評分,得到第二個分數……以此類推,最後構建一個 max_depth=1 的決策樹用第1-9份數據作為訓練集訓練模型,用第10份數據作為驗證集對這次訓練的模型進行評分,得到第十個分數。於是對於 max_depth=1 的決策樹模型,我們訓練了10次,驗證了10次,得到了10個驗證分數,然後計算這10個驗證分數的平均分數,就是 max_depth=1 的決策樹模型的最終驗證分數。

對於 max_depth = 2,3,4,5 時,分別進行和 max_depth=1 相同的交叉驗證過程,得到它們的最終驗證分數。然後我們就可以對這5個最大深度的決策樹的最終驗證分數進行比較,分數最高的那一個就是最優最大深度,我們利用最優參數在全部訓練集上訓練一個新的模型,整個模型就是最優模型

下面提供一個簡單的利用決策樹預測乳腺癌的例子:

from sklearn.model_selection import GridSearchCV, KFold, train_test_splitnfrom sklearn.metrics import make_scorer, accuracy_scorenfrom sklearn.tree import DecisionTreeClassifiernfrom sklearn.datasets import load_breast_cancernndata = load_breast_cancer()nnX_train, X_test, y_train, y_test = train_test_split(n data[data], data[target], train_size=0.8, random_state=0)nnregressor = DecisionTreeClassifier(random_state=0)nparameters = {max_depth: range(1, 6)}nscoring_fnc = make_scorer(accuracy_score)nkfold = KFold(n_splits=10)nngrid = GridSearchCV(regressor, parameters, scoring_fnc, cv=kfold)ngrid = grid.fit(X_train, y_train)nreg = grid.best_estimator_nnprint(best score: %f%grid.best_score_)nprint(best parameters:)nfor key in parameters.keys():n print(%s: %d%(key, reg.get_params()[key]))nnprint(test score: %f%reg.score(X_test, y_test))nnimport pandas as pdnpd.DataFrame(grid.cv_results_).Tn

直接用決策樹得到的分數大約是92%,經過網格搜索優化以後,我們可以在測試集得到95.6%的準確率:

best score: 0.938462nbest parameters:nmax_depth: 4ntest score: 0.956140n

推薦閱讀:

Scikit-learn(sklearn)0.19 中文文檔翻譯計劃
sklearn中各種分類器回歸器都適用於什麼樣的數據呢?
scikit-learn中如何保存模型?

TAG:机器学习 | sklearn | 算法 |