Tensorflow入門教程(8)
上次文章:
閆17:Tensorflow入門教程(7)本次對應代碼:
https://github.com/SaoYan/LearningTensorflow/blob/master/exp13_user_dataset_high_API_1.py官方文檔參考閱讀:
https://www.tensorflow.org/programmers_guide/datasets同步連載於個人公眾號「SaoYan」
Overview
前兩次文章解決了使用底層API完成用戶自己的數據集構建和讀取的問題,但是有一個明顯的問題:太繁碎,細節太多
雖然主頁菌自己一直在使用底層API,但是考慮到高層API的簡潔性,還是有必要探索一下的。實際項目中諸位根據自己的喜好自行選擇吧。
Tensorflow封裝了一組API來處理數據的讀入,它們都屬於模塊 tf.data。這個模塊中包含了5個類:4種Dataset和1個迭代器類型。
使用方法很簡單:(1)構建Dataset (2)構造這個Dataset的迭代器 (3)操作迭代器讀出數據。(聽起來是不是很像一個標準的面向對象編程思路?)
本次教程使用最基本的 tf.data.Dataset,使用的數據和 教程6 相同,400張尺寸為180 x 180的灰度圖像。部分截圖如下:
構建Dataset
tf.data.Dataset 主要提供了以下幾種功能:
(1)構建數據集
(2)對數據集進行預處理
(3)將數據集打混(shuffle)和分批(mini-batch)
- 構建數據集
這一步其實就是實例化一個tf.data.Dataset 對象,Dataset的原始數據來源自程序內存。這就涉及一個問題:如果預先把全部圖像數據都裝載進內存,勢必是十分低效的,而且浪費大量資源,所以我們選擇另一種方案:將全部圖像數據的路徑裝載進內存並用其實例化Dataset,然後把實際從磁碟讀入圖像的操作放在「預處理」這個階段。
首先載入圖像路徑和標籤。由於這個demo只是隨便一組圖像,沒有類別標籤,所以我們隨機生成400個標籤,權當模擬。
然後就可以實例化Dataset對象了,Dataset中每一個「元素」是一個元組 (圖像路徑,標籤)
- 預處理
注意這裡的「預處理」是廣義的,可以是對Dataset中的「元素」進行任何操作。例如這裡我們的「預處理」實際上是根據路徑從磁碟中讀取圖片文件。
對Dataset進行預處理需要使用成員函數 map(),傳入參數是某個函數對象,map() 函數將會把這個函數作用在Dataset中每一個元素之上。
在我們這個例子中,有兩個需要著重注意的細節:
(1) 由於這裡我們的Dataset中每個元素指一個二元組,因此對應的預處理函數應該有兩個輸入參數,返回一個二元組。
(2)預處理函數只能包含tensorflow所提供的Tensor操作符,而這裡我們難以避免的要使用opencv/PIL等python原生模塊讀取圖像(前者處理的數據是tf.Tensor類型,後者處理的數據是numpy-ndarray類型,二者不兼容)。因此我們需要使用 tf.py_func 將python函數轉換成tensorflow操作符。
首先定義python函數。注意輸入兩個參數,返回一個二元組。雖然label不需要任何操作原樣返回,但是依然要這樣寫。
然後使用這個函數來「預處理」前面構建的Dataset對象。
注意這一句代碼其實包含了兩層操作:
(1)將python函數「包裝」成tensorflow操作符。tf.py_func有三個輸入參數:python函數,操作符的輸入(實參),返回類型。
(2)將tensorflow操作符封裝成lambda函數,將這個lambda函數對象傳入map()
注意:作用於Dataset的是傳入map()的函數對象(這裡就是lambda函數),不是最初定義的python函數,也不是tensorflow操作符。
- shuffle & batch
訓練機器學習模型的時候,會將數據集遍歷多輪(epoch),每一輪遍歷都以mini-batch的形式送入模型(batch),數據遍歷應該隨機進行(shuffle)。在高層API中,這三個功能各自用一行代碼就能搞定:
到這裡我們就完成了數據集的構建,可以明顯感覺到代碼量減少了,而且只需要調用API和指定參數,沒有諸如數據類型匹配、queue & dequeue之類的細節。
構建迭代器
說明:
(1)老規矩:next_element只是一個「符號」,需要用Session運行才能真正得到數據。
(2)每次用Session運行next_element,活得下一個mini-batch的數據,也就是獲得一個尺寸為 N x 180 x 180 x 1 的Tensor。
主體程序——讀取數據
說明:
當迭代器移動到尾部以後,再次運行next_element會產生OutOfRangeError,因此我們使用了try-catch語句。
其他代碼和前兩次文章沒有差別:將讀取的數據載入Tensorboard可視化。
運行效果
Tensorboard顯示結果:
預告
現在你可能會問一個問題:在使用高層API的時候是不是就不需要和TFrecords文件打交道了?答案是No。
在很多圖像處理任務中,我們不希望直接拿整張圖作為訓練數據,而是需要將圖像切割成若干子圖(patch),然後把全部的子圖混合在一起作為數據集。這個時候如果像現在這樣以整圖為單元讀入,就難以滿足我們在子圖層面充分打混數據的要求。
這種情況下,我們需要事先切割子圖,然後依然將它們統一存放在TFrecords中等待讀入。
tf.data介面也提供了從TFrecords讀入的API,下次就來嘗試一下。
推薦閱讀:
TAG:TensorFlow | 深度學習DeepLearning | 機器學習 |