乾脆面君,你給我站住!你已經被TensorFlow盯上了
大數據文摘作品,轉載要求見文末
作者 | Dat Tran
編譯 | 康璐、元元、寧雲州
誰動了我最愛的乾脆面?!
美好的周五,大數據文摘的辦公室居然出現了一起偷竊事件。查看監控後,偉大的文摘菌很快用TensorFlow抓住了兇手,TA就是——一隻蠢萌的小浣熊!
來,一起聽文摘菌講講,這一簡易浣熊識別器是如何實現的吧~
文摘菌的這個浣熊識別器到底長啥樣呢?先給你看看最終效果~
小偷浣熊獨白:文摘菌,我不是故意要吃你的乾脆面的 ><
想知道這是如何實現的?在這篇文章中,我會詳細說明製作這個浣熊識別器的所有步驟。
為什麼要選擇浣熊???
不為什麼,就是可愛!!!
戳鏈接可觀看視頻:https://youtu.be/Bl-QY84hojs
建立數據集
讓我們開始吧!我們需要做的第一件事是建立自己的數據集:
- TensorFlow物體識別器API使用TFRecord文件格式,所以我們需要把最終數據集轉化成這種文件格式。
- 有幾種方法可以生成TFRecord文件。如果你的數據與PASCAL VOC數據集或者Oxford Pet數據集結構類似,可以利用現成的腳本(參考create_pascal_tf_record.py和create_pet_tf_record.py)。如果你的數據集不是上述的數據結構,你需要自己寫一個腳本來生成TFRecords(官網上有此做法的解釋)。我就是這麼做的。
- 為了準備API的輸入文件,你需要解決兩個問題。第一,你需要用jpeg或者png編碼的RGB的圖片,第二,你需要一個圖片的邊界框(xmin, ymin, xmax, ymax)並標識物體類別。我的所有圖片都只有一個類別,所有對我而言,這很簡單。
- 我從Google Images和Pixabay爬取了200張浣熊的圖片(主要是jpeg格式,也有個別是png格式),並且確保了圖片在大小、姿勢和光線方面有所區別。下面是我收集的一部分圖片。
- 然後,我用LabelImg手動給圖片打上了標籤。LabelImg是一個用Python編寫和用Qt做圖形交互的圖像標註工具。它可以支持Python2和3,但是我使用的是Python2和Qt4來從頭編譯,因為我用不了Python3和Qt5 。LabelImg非常好用,標註可以保存為PASCAL VOC格式的 XML文件。雖然我可以用create_pascal_tf_record.py腳本生成TFRecord文件,但我還是想自己編寫腳本。
- 不知道為什麼,在MAC OSX系統上LabelImg無法打開jpeg格式的圖片,所以我不得不把他們轉化成png格式然後再轉化回jpeg格式。實際上,因為API也能支持png格式,我不需要再轉化為jpeg格式,但是當時我還不知道這一點。下次我會直接使用png格式圖片。
- 最終,在標識了這些圖片後,我寫了一個腳本把XML文件轉化成csv格式並建立了TFRecord。我使用160張圖片來訓練(train.records),40張圖片來測試(test.records)。
注意:
- 我發現另外一個很好用的標註工具叫做FIAT (Fast Image Data Annotation Tool)。以後我可能也會試試它。
- ImageMagick可以在命令行上進行圖片處理,例如圖片格式轉換。假如你從未使用過,這個軟體值得一試。
- 通常來說,建立數據集是最費事的部分。我用了整整兩個小時來分類和標註圖片,這還是在我只需要分出一個類的前提下。
- 確保圖片是中型號的(參考google圖片來看什麼是中型圖片)。如果圖片太大了,你又沒有更改默認的批量大小設置,很可能會在訓練時因內存不足而報錯。
訓練模型
在建立好符合要求的API輸入文件後,就可以訓練模型了。
在訓練中,你需要下述部分:
- 一個物體識別訓練管道。Tensorflow官網上提供配置文件示例。我在訓練過程中使用ssd_mobilenet_v1_pets.config作為基礎配置。我需要把num_classed參數調整為1,並且為模型檢查點、訓練和測試文件、標籤映射設置路徑(PATH_TO_BE_CONFIGURED)。對於其他的配置,比如學習率、樣本量等等,我都使用默認設置。
注意:如果你的數據集多樣性不足,如比例、姿態等沒有太多變化,data_augmentation_option的設置值得選擇。完整的選線清單可以在這裡找到(參考PREPROCESSING_FUNCTION_MAP)。
- 數據集(TFRecord文件)和相對應的標籤映射。建立標籤映射的例子可以在下面看到,因為我只有一個類所以非常簡單。
注意:所有id編號都要從1開始,這是很重要的。0是一個佔位索引。
- (可選)訓練前的模型檢查點。推薦使用檢查點,因為從零開始訓練模型可能需要幾天才能得到好結果,所以最好能從之前訓練過的模型開始。官網上提供了幾個模型檢查點。在我的識別器中,我根據ssd_mobilenet_v1_coco模型開始訓練,因為模型訓練速度對我來說比準確度更重要。
開始訓練!
- 訓練可以在本地或者在雲端完成(AWS,Google雲等等)。如果你家有GPU(至少大於2GB),那你可以在本地完成工作,否則我建議使用雲端。我這次用的是Google雲,基本上是按照說明文檔一步步完成的。
- 對於Google雲,你需要定義一個YAML配置文件。官網提供有樣例文件,而且我基本上使用了默認配置。
- 我也建議在訓練時就開始評估工作。這樣可以監控整個流程,並且通過在本地運行TensorBoard來評估你的工作。
設置TensorBoard路徑: tensorboard — logdir=gs://${YOUR_CLOUD_BUCKET}
下面是我的訓練和評估工作結果。總體來說,我以批量大小24運行了一個小時,約22000步。在大概40分鐘時我已經得到了很好的結果。
因為是從預訓練模型開始訓練的,總誤差下降的很快。
因為我只有一個類,只需要看總體平均準確率就足夠了。
平均準確率在20000步的時候就達到了0.8,這個結果很不錯。
下面是在訓練模型的過程中,一個圖像評估的例子。
導出模型
- 在訓練完成之後,我把模型導出到一個文件中(Tensorflow graph proto),便於我用這個模型進行推論。
- 在我的課題中,我只能從Google雲中把模型檢查點拷貝到本地,然後用官網提供的腳本來導出模型。
我把訓練後的模型用在了我在youtube找的視頻上
- 看過這段視頻後你會發現有一些浣熊被漏掉了,也有一些誤判。這是合理的,因為我們只在一個小數據集上訓練了模型。如果要建立一個通用且穩定的識別器,(比如你需要它能識別最有名浣熊——銀河護衛隊裡面的火箭浣熊),我們需要的只是更多數據。這也是AI現在的局限性之一。
結論
在本文中,我只使用了一個類,因為我懶得標註更多數據。有很多公司比如CrowdFlower、 CrowdAI和Amazon』s Mechanical Turk均提供標註服務,但是本文還用不到這樣的服務。
我用了很短的訓練時間就得到了相當不錯的結果,這也是由於識別器只需要訓練一個類。對於多類別的情況,總平均準確率就不會這麼高了,也需要更長的訓練時間來獲得好的結果。實際上,我也在Udacity提供的帶標註的駕駛數據集上訓練了識別器。訓練一個能識別小汽車、卡車和行人的識別器花了很長時間。很多其他類似的案例中可能需要使用更複雜的模型。我們還要考慮在模型速度和模型準確度之間尋找平衡。
原文鏈接:https://medium.com/towards-data-science/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
推薦閱讀:
※【博客存檔】TensorFlow之深入理解Neural Style
※TensorFlow小試牛刀(1):CNN圖像分類
※2.2 RNN入門
※2.1 TensorFlow實踐-入門與數字識別示例解析(1)
TAG:TensorFlow | 大数据 |