python3機器學習經典實例-學習筆記8-分類演算法
可視化混淆矩陣
混淆矩陣是我們用來理解分類模型性能的表格。 這有助於我們理解如何將測試數據分類到不同的類中。 當我們想微調我們的演算法時,我們需要了解在做出這些更改之前數據是如何被錯誤分類的。 有些種類比其他課程更糟糕,混淆矩陣將幫助我們理解這一點。 我們來看看下圖:
在前面的圖表中,我們可以看到我們如何將數據分類到不同的類中。 理想情況下,我們希望所有非對角線元素都為0.這表明完美的分類!讓我們考慮class 0。總體而言,52個項目實際上屬於class 0。如果我們總結第一行中的數字,則得到52。 現在,這些項目中有45項被正確預測,但是分類器說其中4項屬於class 1,3項屬於class 2。我們可以對其餘兩行應用相同的分析。值得注意的是,來自class 1的11個項被錯誤分類為class 0。這構成了該類中約16%的數據點。 這是我們可以用來優化模型的見解。
- 導入必要的資料庫
import numpy as npimport matplotlib.pyplot as pltfrom sklearn.metrics import confusion_matrix
- 生成數據調用confusion_matrix模塊
y_true = [1, 0, 0, 2, 1, 0, 3, 3, 3]y_pred = [1, 1, 0, 2, 1, 0, 1, 3, 3]confusion_mat = confusion_matrix(y_true, y_pred)
- 定義顯示的結
# Show confusion matrixdef plot_confusion_matrix(confusion_mat): plt.imshow(confusion_mat, interpolation=nearest, cmap=plt.cm.gray) plt.title(Confusion matrix) plt.colorbar() tick_marks = np.arange(4) plt.xticks(tick_marks, tick_marks) plt.yticks(tick_marks, tick_marks) plt.ylabel(True label) plt.xlabel(Predicted label) plt.show()
我們使用imshow函數來繪製混淆矩陣。 其他功能都很簡單! 我們只需使用相關功能設置標題,顏色條,標記和標籤。 tick_marks參數的範圍從0到3,因為我們在數據集中有四個不同的標籤。 np.arangefunction給了我們這個numpy數組。
- 進行顯示結果
plot_confusion_matrix(confusion_mat)
輸出結果:
對角線的顏色很強烈,我們希望它們的顏色變得深。 淺黃色表示零。 非對角線空間中有幾個綠色,表示錯誤分類。 例如,當真實標籤為0時,預測標籤為1,如我們在第一行中所看到的。 事實上,所有的錯誤分類屬於第一類,因為第二列包含三個非零的行。 從圖中很容易看到這一點。
- 提取性能報告
# Print classification reportfrom sklearn.metrics import classification_reporttarget_names = [Class-0, Class-1, Class-2, Class-3]print (classification_report(y_true, y_pred, target_names=target_names))
輸出的結果:
precision recall f1-score support Class-0 1.00 0.67 0.80 3 Class-1 0.50 1.00 0.67 2 Class-2 1.00 1.00 1.00 1 Class-3 1.00 0.67 0.80 3avg / total 0.89 0.78 0.79 9
結果分析
補充:
- TP: 預測為1(Positive),實際也為1(Truth-預測對了)
- TN: 預測為0(Negative),實際也為0(Truth-預測對了)
- FP: 預測為1(Positive),實際為0(False-預測錯了)
- FN: 預測為0(Negative),實際為1(False-預測錯了)
Accuracy = (預測正確的樣本數)/(總樣本數)=(TP+TN)/(TP+TN+FP+FN)
Precision = (預測為1且正確預測的樣本數)/(所有預測為1的樣本數) = TP/(TP+FP)
Recall = (預測為1且正確預測的樣本數)/(所有真實情況為1的樣本數) = TP/(TP+FN)
推薦閱讀:
※深度學習與故障診斷的另一次嘗試
※EdX-Columbia機器學習課第5講筆記:貝葉斯線性回歸
※Tensorflow入門教程(7)
※Embedding向量召回在蘑菇街的實踐