機器學習入門札記(一)用Naive Bayes識別手寫數字
一、問題描述
分類(classification)問題,是人們在日常生活中經常遇到的問題,也是機器學習研究的重要方向。我們經常見到的垃圾郵件識別系統,可能是最早的商用機器分類應用,其通過關鍵字匹配,達到判斷垃圾郵件的目的。近年來,隨之機器學習研究的深入,越來越多的分類應用深入到了我們的生活。銀行通過申請人的個人信息,判斷信用卡持有人的風險級別;醫療系統,通過各項生理指標,判斷病人的病症;無人車通過圖像信息,識別前進方向的物體。
二、演算法解析
一般地,對於數據集 , 代表第 i 項數據的特徵,它是一個多維度向量,每一個維度代表該項數據的一個值(features),可以是諸如籍貫、性別等分類型數據(categorical data),也可以是身高、血糖濃度等數值型數據(numerical data)。 是對應第 i 項數據的分類(label),比如,在判斷病人是否為糖尿病患者的應用中, 的可能值就是 1 (是糖尿病患者)或者0 (不是糖尿病患者)。在機器學習中,我們有時把數據集分成訓練集(training set)和測試集(test set)兩部分,通過訓練集的「訓練」,找到合適分類方法(classifier),再通過測試集去計算此種分類方法的準確率(accuracy)。
樸素貝葉斯(Naive Bayes)是一種簡單又高效的分類演算法。其基本思路是,使用貝葉斯法則計算 k = 0,1,...的值,哪個結果所對應的概率越高,我們就把這項數據的類型定為對應的 。上述的模型用數學語言表示為
(1)
我們的目標是,找到能使上式最大的 。 注意到,(1)式假設不同的features之間的分布是獨立的,這當然是一個值得懷疑的假設。但是從實驗看來,即是這個假設有問題,我們依然能得到較為準確的結論。
三、數據準備
MNIST 資料庫是一個手寫數學圖片庫,是美國普查員和高中學生的手寫數字。其包含60000張訓練圖片和10000張測試圖片,是常用的分類問題資料庫。
注意到該數據在官網以binary data的形式保存,如何從binary data轉化為常見的csv格式不是本文的主題,有興趣的朋友可以Google MNIST csv了解細節。或者直接從Digit Recognizer | Kaggle 網上獲取資料庫的csv版本。以下是R讀取csv數據的R代碼。
#load_data.rtrain = read.csv("mnist_train.csv", header=TRUE)test = read.csv("mnist_test.csv", header=TRUE)#把label定為分類型數據train[, 785] <- as.factor(train[,785])test[, 785] <- as.factor(test[,785])#grey the datatrain[,1:784][train[,1:784] <= 128] <- 0train[,1:784][train[,1:784] > 128] <- 1test[,1:784][test[,1:784] <= 128] <- 0test[,1:784][test[,1:784] > 128] <- 1
注意到train set的每一行代表一張黑白圖片,一行共有785個數字,其中前面784個數字代表28見方的圖像,值取0到255之間,0代表白,255代表黑色。為了簡化運算,我們把0-128之間的顏色,強行轉化為白色,反正為黑色。每行最後一個數字是該行圖片的分類,即0到9其一。
#show_data.rlibrary(imager)DIM = 28 show_image <- function(train_item){ DIM <- dim(train_item)[2]^0.5 m <- matrix(unlist(train_item),DIM,DIM) mi <- as.cimg(m,dim=c(DIM,DIM,1,1)) par(mar=c(1,1,1,1)) plot(mi)}
為了方便,我們寫了一個show_image函數呈現圖片。
第一行的圖片顯示為5,和實際分類一致。
四、代碼實現
注意到(1)式中,對於每一個label,分母都一樣,所以我們只需要比較分子的大小。由於分子出現了700餘項概率相乘,可見是一個非常小的數字。我們對分子施加自然對數,以免指數下溢。由此,我們可以把(1)式改為下式。
統計學裡,把後項稱為prior,把前面j項和稱為likelihood,把兩者之和稱為posterior。 這裡對 的計算稍作解釋。對於一個給定的類(如y=1),我們需要找到train data中所有y=1的數據。再計算在 j 的位置上,上述數據集中對應項等於 的數據比例,這個比例即是 。
# naive_bayes_classifier.r# find the subset in the training set and the priors P(yk)train_class <- list()priors <- rep(0,10)means <- list()for (i in seq(0,9)) { train_class[[i+1]] <- train[which(train$class == i),] priors[i+1] <- dim(train_class[[i+1]])[1]/60000 means[[i+1]] <- apply(train_class[[i+1]][,1:784],2,mean) # means save all the percentages of 1 in each pixel. }classify <- function (test_item) { posteriors <- rep(0,10) for (i in seq(0,9)) { #calculate likelihood P(x|yi) likelihood <- 0 for (j in seq(1,784)) { # iterate 784 pixels. # if it is 1, P(xj|yi) is the mean; otherwise P(xj|yi) is 1 - mean likelihood <- likelihood + ifelse(test_item[j],means[[i+1]][j],1-means[[i+1]][j]) } posteriors[i+1] <- priors[i+1] + likelihood } return (which.max(posteriors))}
為了檢驗演算法,我們將所有測試數據帶入,計算預測分類,並核對正確率。
#build the output summary tablesummary <- matrix(rep(0,100),ncol=10,byrow=TRUE)colnames(summary) <- seq(0,9)rownames(summary) <- seq(0,9)#check each individual in test setfor (k in 1:dim(test)[1]) { predict = classify(test[k,1:784]) label = as.integer(test[k,785]) summary[predict,as.integer(label)] = summary[predict,label] + 1}summary <- as.table(summary)print.table(summary)cat("Accuracy: ",(sum(diag(summary)))/sum(summary))
注意到對10000個數據,程序到運行時間較長,會超過3小時,這可能是由classifier的for loop引起的。因此改寫該函數,用matrix operation代替loop。
classify2 <- function (test_item) { posteriors <- rep(0,10) test_item_num_zeros = 784 - sum(test_item) test_item[test_item==0] <- -1 for (i in seq(0,9)) { #calculate likelihood P(x|yi) likelihood <- matrix(unlist(test_item),nrow=1) %*% matrix(means[[i+1]],ncol=1) + test_item_num_zeros posteriors[i+1] <- priors[i+1] + likelihood } return (which.max(posteriors))}
新的函數,程序運行不到10分鐘即得到結論:整體的正確率為58%,注意到這個演算法把很多不是1的圖片識別為1。
可以從上圖計算出這個演算法的precision和sensitivity,可見目前該演算法對0、4、6、7、9的識別都不錯。
五、項目改進
注意到每張圖片里數學的大小不一,位置各異,可以考慮對各圖片進行預處理,將照片居中放大顯示,相信能得到更好的正確率(約80%)。
進一步也可以嘗試不再灰度數據,用高斯分布計算 。
這裡,我們就不詳細描述了。
Reference:
[1] D.A. Forsyth, Applied Machine Learning
[2] Statistical classification
推薦閱讀:
※Learning Explanatory Rules from Noisy Data 閱讀筆記2
※【機器學習Machine Learning】資料大全
※【翻譯】Brian2高級指導_Brian如何工作
※學習筆記CB001:NLTK庫、語料庫、詞概率、雙連詞、詞典
TAG:機器學習 |