Tensorflow入門教程(6)
上次文章:
閆17:Tensorflow入門教程(5)本次內容相關代碼:
https://github.com/SaoYan/LearningTensorflow/blob/master/exp11_user_dataset_low_API_1.py同步連載於個人公眾號「SaoYan」
Overview
之前幾次的全部常式,使用的都是tensorflow預處理過的數據集,直接載入即可。例如:
然而實際中我們使用的通常不會是這種超級經典的數據集,如果我們有一組圖像存儲在磁碟上面,如何以mini-batch的形式把它們讀取進來然後高效的送進網路訓練?
這次我們首先用tensorflow最底層的API處理這個問題,後面介紹高層API。高層API是對底層的進一步封裝,用戶可以不必關心過多細節。不過了解一下比較底層的API還是有好處的。
當你有一組自己的數據的時候,你需要經過以下兩個步驟:
(1)將全部數據寫入一個後綴 .tfredords 的文件。
這個步驟涉及讀入->預處理->寫入tfrecords,對你的數據是什麼格式沒有要求。例如,如果你手中是圖像數據,那用opencv/PIL等介面讀入;如果是matlab數據(mat文件),那可以用h5py協議讀入,等等。不管如何讀入,最終都要寫入到統一的tfrecords文件中,以便用tensorflow提供的介面高效讀取。
(2)以mini-batch的形式從tfrecords中讀取數據,送到模型的placeholder中支持網路訓練。
實驗設置
代碼中使用的數據是存在磁碟中的400張png圖像,也傳到了github上面,存在my_data路徑下面。部分如下:
代碼實現以下功能:製備tfrecords形式的數據集,然後再以mini-batch讀入,為了測試讀入是否成功,把讀入的數據顯示在tensorboard上面。
製備tfrecords數據集
在教程5中(Tensorboard),大部分代碼都是遵循API介面的固定「模式」寫就可以,這次也主要以這種方式進行,而不過多討論背後的理論細節。
- 兩個輔助函數
定義這倆輔助函數的目的完全是不想讓後面的代碼太冗長
- 讀取圖像&寫入tfrecords文件
幾點說明
(1)讀取圖像文件的時候用到了glob和opencv兩個包。glob是將路徑下全部文件名一次性存到一個list中,方面後面逐個讀取;opencv則只是利用imread介面讀取圖像文件的。
(2)和一切文件操作一樣,向tfrecords文件中寫入內容也需要建立一個writer對象,創建這個對象的是函數 tf.python_io.TFRecordWriter
(3)feature是我們創建的一個字典對象,這裡面可以包含你想記錄的任何信息。在這裡我們存入了三對鍵值(key: value):image_raw(圖像數據,這個是核心內容),heigh(高),width(寬)。你也可以加入更多的信息,例如,通道數目,文件名等等。這些信息在後面讀取數據的時候都可以一併讀取出來。比如:在主程序中,你需要用到圖像的尺寸參數,那麼你可以將圖像和尺寸參數一起讀出。
(4)注意數據格式。圖像數據本身是8bit的,因此我們用前面定義的輔助函數 _bytes_feature_ 把數據轉化成tensorflow要求的tf.train.BytesList格式存入。實際中還會碰到圖像本身是以float形式存儲的,代碼就需要相應的變動,這個下次再說。
從tfrecords中載入nimi-batch
- 定義函數:讀取一個樣本
幾點說明:
(1)整個代碼過程很煩雜,因為是調用的底層API,不過都是固定寫法,其中的內部原理主頁菌一知半解,不敢在這裡隨便講
(2)特別注意這裡這個字典對象的定義方式
首先,這裡的三個key要和前面製備tfrecords時候一致;其次,注意數據格式,image_raw是8bit存儲的,所以讀取的時候限定tf.string類型,同理,height和width要限定tf.int64
(3)如前文所說,字典中存入的信息都可以通過key來讀取,上面的代碼只讀取了圖像信息,如果想獲取height的值,可以補充這樣一句代碼:
height = tf.decode_raw(features[height], tf.int64)
然後在函數返回值中把height也返回即可
(4)每一個樣本是以一維的形式從數據流中抽取出來的,所以需要reshape成原始尺寸
- 定義mini-batch
用前面定義的read_record獲取一個樣本,然後用tf.train.shuffle_batch來封裝一個mini-batch。
tf.train.shuffle_batch會多次通過read_record抽取樣本,並且開闢一塊內存空間建立隊列(queue),將樣本洗牌打亂,空間開闢越大,數據混亂度會越高。控制洗牌的參數是capacity和min_after_dequeue,官網文檔中給出了這倆參數的取值建議,我粘貼到了代碼注釋中。
注意:
從最開始介紹tensorflow的時候主頁菌就在強調一個事情:任何東西在用Session運行之前都是沒有實際值的。這裡也不例外。在主程序部分,每一個step都要這麼一句代碼:
batch = sess.run(data_batch)
這個batch才是實際的數據,是可以feed給placeholder的
- 主程序部分
我們的主程序是讀取mini-batch然後用tensorboard顯示。
說明:
有四行代碼必不可少session開頭的兩行:
coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)
session結尾的兩行
coord.request_stop()coord.join(threads)
至於內部機理,文檔寫的太模糊,主頁菌缺少計算機基礎理論知識,並沒有看懂
我相信你可能已經看暈了......這部分太過瑣碎,細節很多,官方文檔裡面寫的也很模糊,對內部機理解釋的不到位。
面對這種情況,主頁菌最初選擇的方法就是,親自嘗試,用幾乎一整天的時間摸索出了這一套代碼的套路。雖然對機理還是一知半解,但是對代碼思路十分清晰了,在自己的項目中能夠迅速擼出一套數據預處理的代碼。
所以,主頁菌的建議就是,親自調通一套demo!
預告
這次教程使用的數據是8bit的,然而如果我想用float格式存儲怎麼辦?(或者原始數據就是float格式的,總不能截斷成8bit來存儲吧......)
雖然這部分內容不多,但是由於這次信息量夠大了,還是放到下次單獨說吧。
推薦閱讀:
TAG:深度學習DeepLearning | 機器學習 | TensorFlow |