標籤:

用TensorFlow做Kaggle「手寫識別」達到98%準確率-詳解

歡迎關注我們的微信公眾號「人工智慧LeadAI」(ID:atleadai)

這是一個TensorFlow的系列文章,本文是第三篇,在這個系列中,你講了解到機器學習的一些基本概念、TensorFlow的使用,並能實際完成手寫數字識別、圖像分類、風格遷移等實戰項目。

文章將盡量用平實的語言描述、少用公式、多用代碼截圖,總之這將是一份很贊的入門指南。歡迎分享/關注。

上一期,我們用Tensorflow實現了Kaggle的手寫識別項目,但準確率比較低,只有92%,這次我們打算把識別的準確率提升到98%以上。

為什麼不是上次說的提升到99%以上呢?因為92%到98%是比較容易的,而再從98%到99%是要費不少功夫的,一篇文章難以承載這麼多內容,所以將會分成兩篇文章,首先是從92%到98%,下一次是從98%到99%。

不要小看提升1%,越往後面,難度就越大。如果我們做到99%準確率,在Kaggle的手寫識別這個項目上,也就進入了前25%了,可以說入門了。

回顧上期

上期我們學習了梯度下降、神經網路、損失函數、交叉熵等概念,然後用42000張圖片數據訓練了一個簡單的神經網路,準確度92%。可以說,這只是一個Hello World。

Hello World

如何進行改進

首先,這次我們將使用卷積神經網路來進行圖片識別。眾所周知,卷積神經網路對於圖片識別是非常有效的。

這裡我打算這樣來構建這個卷積神經網路:

卷積層1+池化層1+卷積層2+池化層2+全連接1+Dropout層+輸出層

然而,什麼是卷積神經網路?什麼是卷積層、池化層、全連接層?Dropout又是什麼鬼?

1、什麼是卷積神經網路?

我們人看到一幅圖像,眨眼之間就知道圖像中有什麼,圖像中的主體在幹什麼。但計算機不同,計算機看到的每一副圖像都是一個數字矩陣。那我們怎麼讓計算機從一個個數字矩陣中得到有用的信息呢,比如邊緣,角點?更甚一點,怎麼讓計算機理解圖像呢?

對圖像進行卷積,就是接近目標的第一步。

圖像在計算機里的表示可能是這樣的:

對圖像卷積,就是求卷積核作用在圖像後,得到的圖像對於該卷積核的累加數值。這些累加的數值可以代表這個圖片的一些特徵。

如果是針對貓進行識別,人可能知道貓頭,貓尾巴等特徵。CNN對圖片進行處理後,也會學習到一些特徵,它可能不知道貓頭、貓尾巴這些特徵,但也會識別出一些我們可能看不出來的特徵,CNN通過這些學習到的特徵去做判斷。

2、什麼是卷積層

卷積層的作用是指對圖片的矩陣進行卷積運算,得到一些數值,作為圖片的某些特徵

3、什麼是池化層

池化曾的作用是對上層的數據進行採樣,也就是只留下一部分,這樣的作用是可以縮小數據量和模糊特徵。

4、什麼是全連接層

全連接層就是連在最後的分類器。前面卷積層和池化層進行處理後,得到了很多的特徵,全連接層使用這些特徵進行分類。比如識別數字,那就是對0~9的十個類別進行分類。

5、Dropout是什麼?

Dropout層是為了防止CNN對訓練樣本過擬合,而導致處理新樣本的時候效果不好,採取的丟棄部分激活參數的處理方式。

這裡對這些概念的解釋都是比較簡單的,如果希望詳細了解,可以看知乎的這個鏈接:

CNN卷積神經網路是什麼?

代碼實現

1 標籤的處理

2 把數據分為訓練集和驗證集

3 定義處理數據的函數

4 定義網路的結構

5 定義各類參數

6 進行訓練

生成結果

這裡迭代20個周期:

7 驗證集上的準確度

然後我們使用這個模型對Kaggle的測試集進行預測,並生成cvs格式的結果

8 生成結果

這裡建議跑30輪以上,因為在驗證集上有98.35%準確率,上傳到Kaggle往往就只有百分之九十七點幾的準確率了。

推薦閱讀:

泰坦尼克號倖存預測n ——Kaggle排名321名(前4%)
Kaggle入門系列:(二)Kaggle簡介
遺憾未進前10%, Kaggle&Quora競賽賽後總結

TAG:TensorFlow | Kaggle |