機器學習 | 更進一步,用評估器給花卉分類

機器學習總的來說還是非常贊的(當然,除了某些時候你不得不對付複雜惱人的數學運算)。當下機器學習相關工具已經得到很大的改善,並且訓練模型從未像今天這樣簡單。

我們將利用對數據集的認知來編寫洞察數據的模型,而不是基於對原始數學問題的理解。

https://www.zhihu.com/video/909895884572094464

這一期,我們將會用一些趁手的代碼來實操訓練出一個簡單的分類程序,所用到的程序可以在這個 GitHub Gist 當中得到:AIA003_tf_estimators.ipynb hosted with ? by GitHub;

TensorFlow 中專門解決機器學習問題的評估器

要訓練我們的分類模型,我們會使用 Google 的開源機器學習庫— TensorFlow。TensorFlow 對外暴露了非常豐富的 API 介面,但是我們要使用的僅僅是一些高層次 API,也就是評估器(Estimator)。

評估器替我們打包好了訓練的循環過程,所以我們可以通過配置評估器來訓練模型,而不是手動碼出訓練循環。這消耗了大量的樣板,使我們能夠在更高級的抽象中考慮問題。同時這也意味著我們可以享受機器學習中充滿趣味的部分,而不是糾結實現的細節。

既然我們目前僅僅討論過線性模型的設計,那我們就此打住。今後我們會有機會重新來討論這個問題,並提升其識別能力。

花卉分類:是否和酒水判別同樣充滿趣味?

這周我們會構建一個模型來區分三種非常相似的花卉。雖說這次比不上上次那樣同酒水打交道那麼令人嚮往,但是由於花卉更難辨認,所以這次的任務會更具挑戰性。

上期回顧:機器學習「七步走」

特別是我們今天需要區分開不同品種的鳶尾花。目前我不太確定自己能不能從一片玫瑰當中找出一枝鳶尾花,但是我們的模型將會準確認出山鳶尾、變色鳶尾和維吉尼亞鳶尾。

山鳶尾、變色鳶尾和維吉尼亞鳶尾

我們有一個數據集(如下圖)記錄了花卉的花瓣、萼片的寬高數據。表中的四列也就是之前提到的「特徵」。

載入數據

在引入 TensorFlow and NumPy 之後,我們需要使用 TensorFlow 的 load_csv_with_header 函數來載入數據集。這些數據(或者說特徵),都以浮點數的形式呈現,而每一列數據、目標花卉的標記則用 0、1 和 2 來表示,與三種花卉的品類相對應。

我已經將載入數據的結果輸出出來,現在我們可以通過命名屬性來取得訓練數據和相關的標記、目標。

構建模型

下一步我們需要構建模型了。為了完成這一步操作,首先需要設定特徵列。特徵列定義了進入模型當中的數據類型。我們使用四個維度的特徵列來表示數據集中的特徵,並將之稱為「花卉特徵」。

搞定評估器非常簡單。 我們通過傳入特徵列、模型預測的輸出數量(此處為 3)和指定的用於存儲模型訓練進程及輸出結果的目錄來調用 tf.estimator.LinearClassifier 就能實例化模型了。這些操作讓 TensorFlow 從上次中斷的地方繼續進行。

輸入函數

上述的 classifier 對象會為我們記錄狀態,並且此時我們差不多可以開始訓練了。勝利在望,我們的模型距離成功連接訓練數據僅僅只差一個輸入函數。輸入函數的主要工作是創建一個可以為模型生成數據的 TensorFlow 操作。

現在,所以我們已經完成了從處理原始數據,到創建輸入函數(傳入之後會以特徵列來映射的數據)的過程。要注意的是我們使用與元數據中同樣的特徵列名稱作為特徵值的標記,這樣數據與模型訓練才能對應起來。

開始訓練

接下來開展訓練工作。只需將輸入函數傳入 classifier.train() 方法就可以了。這就是我們將數據與模型聯繫起來的操作。

訓練函數會控制在數據集上循環或迭代的過程,同時在每一階中不斷提升自身性能。正如我們所料,下圖中我們已經成功完成了 1000 個階的訓練!我們的數據集並不算大,所以這個過程非常快。

精確度評估

好啦,是時候評估結果了。由於前面的 classifier 對象保存了模型訓練的狀態,所以我們此處仍然使用同一個對象來評估。要評估模型的優劣,我們通過調用 classifier.evaluate() 並傳入測試數據,然後從返回的矩陣當中提取出精確度數據即可。

我們得到的精確度是 96.88%,厲害了我的哥!

評估器:一路向前的流水線

我們這周就先到這裡,並且一起回顧學習評估器的收穫。

評估器 API 為我們提供了一個優秀的流水線用於獲取元數據、傳入輸入函數、設定特徵列和模型結構、運行訓練過程和進行預測。這個便於理解的框架讓我們能夠關注數據和他們的屬性,而不需要一直糾結數學上的問題,實在是贊!

下回預告

這次我們一同用封裝好的評估器來淺嘗了簡單版的 TensorFlow 高級 API。後面我們還會了解怎樣為模型加入更多細節、使用更複雜的數據和加入更多高級功能。

敬請期待!


▏原文出處:Plain and Simple Estimators

▏封面來源:YouTube 視頻預覽圖

▏視頻出處:Plain and Simple Estimators - YouTube

▏字幕翻譯:谷創字幕組

▏文章編輯: @楊棟


推薦閱讀:

機器學習基礎:線性回歸
深度學習框架跑分測驗(TensorFlow/Caffe/MXNet/Keras/PyTorch)
深度卷積GAN之圖像生成
Docker--深度學習環境配置一站式解決方案

TAG:机器学习 | 谷歌Google | TensorFlow |