如何用TensorFlow在安卓手機上識別皮卡丘?

如何用TensorFlow在安卓手機上識別皮卡丘?

來自專欄景略集智

前段時間我們分享了怎麼用 TensorFlow 和 Swift 做個 iOS 應用識別 Taylor Swift:

景略集智:如何用TensorFlow和Swift寫個App識別霉霉??

zhuanlan.zhihu.com圖標

那麼可能有朋友就會想:在安卓手機上怎麼玩?

今天就分享一下如何藉助 TensorFlow 在安卓手機上識別物體,不過這次不是識別「霉霉」了,而是人見人愛的小機靈鬼——皮卡丘

TensorFlow Object Detection API 有很多用途,從名字上也能看出來,它的用途就是訓練能夠識別出物體的神經網路。

這個庫的用處可以說是幾乎沒有限制,可以拿它訓練 AI 識別照片里的人類、貓、狗、汽車等等。所以我(作者 Juan De Dios Santos——譯者注)就想搞點有意思事情,自己用它訓練一個模型,幹嘛呢?識別皮卡丘!

本文我會分享訓練模型識別皮卡丘的詳細步驟,你也可以自己試著玩一玩。首先,我會簡單講講 TensorFlow Object Detection API,總結一下要點;然後,我會講講怎樣將我收集的皮卡丘照片轉換為正確的格式,創建數據集;接著我會詳細分享模型的訓練過程以及如何用TensorBoard 評估模型;最後,我會展示如何在 Python Notebook 中使用訓練好的模型,以及將模型導出到 Android 的過程。

本項目的代碼地址見文末,文中提及的代碼腳本均來自該項目代碼庫。

在開始前,先展示一下識別皮卡丘的示例,可能大家會有點好奇:

TensorFlow Object Detection API

TensorFlow Object Detection API 是 TensorFlow 中用於解決物體檢測問題的庫,所謂物體檢測就是在一個畫面幀中檢測各種物體(比如皮卡丘)的過程。這個庫比較特殊的一點是,它能根據不同的速度和內存使用情況變換模型的準確度,因此這樣你就可以讓模型適配於你自己的需求和所用平台(比如手機)。這個庫還有一些非常創新的物體檢測架構,比如 SSD、Faster R-CNN、R-FCN 和 MobileNet等。

而且它還提供了一些訓練好的物體檢測模型,支持谷歌雲端運行,以及支持使用TensorBoards 監控模型的訓練。

簡單了解後,下面講講怎樣搭建你自己的自定義模型。

搭建你的自定義模型

安裝

在我們開始前,確保你的電腦上安裝了TensorFlow。如果沒有,可以參考集智主站分享的這篇安裝教程:jizhi.im/blog/post/inst

接著克隆包含 Object Detection API 的代碼庫,鏈接在此:github.com/tensorflow/m

完成克隆後,導航至「Research」目錄,執行如下代碼:

# From tensorflow/models/research/protoc object_detection/protos/*.proto --python_out=.

這會編譯 Protobuf 庫,如果還沒有安裝 Protobuf 庫,可以從這裡安裝:developers.google.com/p

最後,你需要將 Protobuf 庫添加至 PYTHONPATH,執行代碼如下:

# From tensorflow/models/research/export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

更多需要安裝的環境依賴,請看這裡的官方文檔:

github.com/tensorflow/m

現在我們如何用 TensorFlow 將皮卡丘圖像轉換為乾淨整潔的數據集。

創建數據集,處理圖像

創建數據集是成功訓練模型所需的第一步,在這部分,我就談談創建數據集需要的所有步驟。

對於識別皮卡丘的項目,我下載了 230 個中等大小的皮卡丘照片,將保存照片的目錄命名為「images」。為了能有更好的結果,我試著又找了些不同角度和不同形狀的皮卡丘的照片,不過說實話,皮卡丘就是個黃色皮膚、長耳朵的小精靈,並不是自然界真實存在的動物,所以找到大量各種各樣的皮卡丘照片還是有點困難的。

等獲取完圖像後,下一個任務就是為它們打標籤。這是什麼意思?因為我們要做的是物體檢測,所以需要明確物體是什麼。這一步,我們需要圍繞物體畫一個邊界框,告訴系統框里的這個東東實際上就是我們要學習識別的物體。用 RectLabel 可以完成這項任務。

下面就是帶著邊界框的圖像的樣子:

在 RectLabel 里,你需要為圖像的每個邊界框設置一個標籤,在這裡我將標籤設為了「Pikachu」。等為所有圖像標上標籤後,你會發現得到一個名為「annotations」的目錄,內含很多描述了每張圖像的邊界框的 xml 文件。

將數據集分割為訓練集和測試集

等為所有圖像標上標籤後,下一步就是將數據集拆分為訓練集和測試集。在圖像所在的同一目錄里,我創建了一個名為「train」和一個名為「test」的目錄,並向訓練目錄里添加了約70%的圖像及其 xml 文件,向測試目錄里添加了約 30% 的圖像及其 xml 文件。

創建 TFRECORD

在拆分數據集後,我們還需要將圖像及其 xml 文件的格式轉換成 TensorFlow 可讀取的格式,也就是 Tfrecord 格式。為圖像創建這個格式,需要兩個步驟。首先,為了簡單起見,將測試集和訓練集的所有 xml 數據轉換為兩個 CSV 文件,這裡我用的是 GitHub 上一個庫里的 xml_to_csv.py 代碼,該代碼庫地址:gist.github.com/wbrickn 。 接著,根據生成的 CSV 文件,利用腳本 generate_tfecord.py(也是來自上面的那個代碼庫)創建 Tfrecord 數據集。記住在運行腳本前,必須在函數 class_text_to_int 中指明你需要識別的物體的類。

創建標籤映射

「標籤」映射會指明標籤及其需要的索引。如下所示:

item { id: 1 name: Pikachu}

如果你自己實踐其它項目時,用你自己的標籤替換掉「Pikachu」即可,重要的是,要始終從索引1開始,因為0已經保存在內了。我將該文件保存為object-detection.pbtxtin,放在一個名為「training」的新目錄里。

訓練模型

流水線

通過名為「pipeline」的配置文件可以執行完整的訓練過程。流水線分為5個主要部分,分別負責定義模型、訓練、評估過程參數,訓練數據集輸入和評估數據集輸入。整個流水線的骨架如下所示:

model {(... Add model config here...)}train_config : {(... Add train_config here...)}train_input_reader: {(... Add train_input configuration here...)}eval_config: {}eval_input_reader: {(... Add eval_input configuration here...)}

但是!你不必從頭開始自己寫整個流水線。可以使用 TensorFlow 上的訓練工具和一些預訓練模型,因為從頭開始訓練一個新模型會耗費很多時間。這樣,利用 TensorFlow 提供的一些配置文件,只需略作變動,就能將它們應用在新的訓練環境中。我使用的模型是 ssd_mobilenet_v1_coco_11_06_2017,地址在這裡:download.tensorflow.org

對於我的訓練,我使用了ssd_mobilenet_v1_pets.config配置文件作為起點。由於我的類數量只有一個,所以對變數num_classes做了些改動。此外,num_steps用來停止訓練,fine_tune_checkpoint用來指向模型下載的位置,train_input_reader和eval_input_reader的 變數input_path和label_map_path用於指明訓練集,測試集和標籤映射。

這裡簡要說說 TensorFlow 的兩個預訓練模型:SSD 和 MobileNet。

SSD,即單幀檢測器,是一種神經網路架構,其基於單個前饋卷積神經網路。之所以被稱為「單幀」,是因為它是在同一時步中預測圖像的類和表示檢測(即錨點)的邊界框的位置。

MobileNet是一種卷積特徵提取器,設計用途為在移動設備上獲取圖像的高級特徵。

等流水線就緒後,就將其添加至「training」目錄。然後,繼續用如下命令開始訓練:

python object_detection/train.py --logtostderr --train_dir=path/to/training/ --pipeline_config_path=path/to/training/ssd_mobilenet_v1_pets.config

在訓練期間和訓練之後評估模型

TensorFlow Object Detection API 也提供了在訓練模型期間和訓練之後用於評估模型的代碼。每次訓練會產生一個新的檢查點,評估工具會用給定目錄中的圖像進行預測(我是用的測試集中的圖像)。

執行如下代碼,運行評估工具:

python object_detection/eval.py --logtostderr --train_dir=path/to/training/ --pipeline_config_path=path/to/training/ssd_mobilenet_v1_pets.config --checkpoint_dir=path/to/training/ --eval_dir=path/to/training/

TensorBoard

使用TensorFlow的可視化平台TensorBoard可以看到訓練和評估階段的模型性能。我們可以觀察幾個指標,比如訓練時間、總損失、時步數量等。即使在模型訓練期間,也能正常使用 TensorBoard。

執行如下代碼,啟動 TensorBoard:

tensorboard --logdir=path/to/training/

導出模型

等訓練完成後,下一步就是導出模型,供我們使用。Object Detection API 也提供了導出模型的腳本,叫做 export_inference_graph.py 。

在導出模型之前,一定確保你的訓練目錄中有以下文件:

model.ckpt-${CHECKPOINT_NUMBER}.data-00000-of-00001,model.ckpt-${CHECKPOINT_NUMBER}.indexmodel.ckpt-${CHECKPOINT_NUMBER}.meta

可能有些文件有相同的格式,卻有不同數量的檢查點。如果是這樣,只選擇理想的檢查點就行,並執行如下代碼:

python object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path=path/to/training/ssd_mobilenet_v1_pets.config --trained_checkpoint_prefix=path/to/training/model.ckpt-xxxxx --output_directory path/to/output/directory

輸出會是一個包含了模型「凍結」版本的文件,名為 frozen_inference_graph.pb 。

結果

訓練結束時,模型的準確率達到了 87%,總損失為 0.67。然而,在訓練期間,模型的最高準確率達到了 95%。不過,準確率最高時的模型性能並未達到我的預期結果。例如,模型將很多黃顏色的物體(甚至是人類)都給識別成了皮卡丘。另一方面,我還注意到,模型準確為87%時,產生了更少的假陽性,只是會漏掉一些皮卡丘。下圖是 TensorBoard 生成的模型準確度和總損失可視化圖形。

TensorBoard 還能自動評估來自評估數據集中的某些圖像。通過滾動條就能看到模型的預測置信度在不同檢查點位置的變化情況。

圖像檢測程序包中有個 notebook,用於測試TensorFlow提供的預訓練模型。是的,這個 notebook 也能進行改動,以處理我們的自定義模型的「凍結」版本(也就是我們導出的模型)。下面你可以在 notebook 中看到一些檢測結果:

在安卓端識別皮卡丘

到目前為止,我們已經訓練、導出和評估了模型。現在是時候將模型導入到安卓端了,這樣我們就能用手機攝像頭識別皮卡丘了!

從這裡開始內容就有點複雜了,也讓我很頭疼。所以我會盡量詳細的解釋相關步驟。不過不確定大家自己試著做時會不會出現各種各樣的問題。

我們找到 TensorFlow 的安卓部分。首先,你需要下載 Android Studio。下載完後,克隆 TensorFlow 代碼庫(可能你已經克隆完畢)。然後用你剛克隆的 TensorFlow 代碼庫里的目錄在 Android Studio 里導入一個新項目,命名為「Android」。我建議可以先在 READ.ME 文件里熟悉一個這個庫。

READ.ME 建議讓搭建程序的過程儘可能簡單點,並且建議將 Gradle 中的 nativeBuildSystem 變數改為 none。不過,我改成了cmake。

當完成構建後,下一步就是將凍結的模型添加到資源目錄。然後,還是在這個文件夾里,創建一個叫做「labels」的文件,在第一行寫上???(還記得我之前說過第一個類已經保存了嗎?),在第二行寫上你的物體的標籤,在我們這個例子中那就是「Pikachu」啦。

然後,打開位於「Java」目錄中的叫做「DetectionActitivity.java」的文件;這是讓APP執行物體檢測的代碼。看看這裡的TF_OD_API_MODEL_FILE 和TF_OD_API_LABELS_FILE變數。在第一個變數中,將它的值改為位於資源文件夾中的凍結模型的路徑,在第二個變數中,寫出帶有標籤的文件的路徑。你需要掌握的另一個有用的變數是MINIMUM_CONFIDENCE_TF_OD_API ,它是指追蹤檢測結果所需的最低置信度。

現在我們做好準備了!點擊運行按鈕,選擇你的安卓設備,等待十幾秒,直到APP安裝在手機上。注意,會有4個應用下載到手機上,不過是那個叫TF Detect的應用包含了檢測模型。如果一切正常,APP啟動後,你會發現一些物體的照片,看看模型能否識別它們。

就是這樣!

下面是我用自己的手機識別的皮卡丘:

總結回顧

本文我解釋了用 TensorFlow Object Detection API 訓練一個自定義模型識別皮卡丘的所需步驟。開頭先聊了聊該 API 的背景信息及工作原理,然後講了如何處理圖像以創建數據集。

之後,我重點說了說怎樣訓練、評估模型。待模型訓練完畢後,將模型導出到 Python notebook 和 Android中。

這個項目最難的部分是讓模型在Android上運行。不吹不黑,我自己搗鼓了一個月的時間,反覆試錯才徹底弄懂了這部分的知識,中間一度想放棄。建議大家遇到麻煩不要氣餒,不斷嘗試,最終一定能搞出一個在安卓手機上識別物體的應用,亮瞎周圍人的雙眼!

附本項目代碼庫:

github.com/juandes/pika

最後解答一下封面圖的問題:當然還得是皮卡丘。


明天晚上的直播可別錯過,現在就掃碼入群吧。


推薦閱讀:

可愛拉滿!春夏踏青出遊一定要有它!
知乎的皮卡神教是什麼?
口袋妖怪間有沒有生殖隔離?
寵物小精靈是如何存在於精靈球里的?
皮卡丘有多強?

TAG:TensorFlow | 皮卡丘 | 圖像識別 |