使用TensorFlow時輸入數據的三個姿勢

本文有一半安利一半吐槽,因此沒有什麼特別的代碼級的乾貨,一些例子代碼都在這裡。希望大家看了以後,在自己探索TensorFlow輸入數據的API的時候可以少被坑。

本來是準備發一個各大雲服務GPU實例的橫向評測的,因為Google ML Engine非常特立獨行的把輸入輸出都搞到他的Google Storage上,因此需要改動自己的代碼來對存儲在GS的數據進行讀入。

首先當然是跟隨大Google的Doc來耍一波他和TensorFlow深度集成的服務,文檔可以翻牆看這裡。在看的時候請先跟著這裡設置下命令行環境。跟著這個文檔,我們可以在上面運行一個簡單的分類器。

跟著他的例子跑的時候,我發現本地訓練和遠程訓練(使用Google ML Engine服務)訓練數據指定路徑的時候參數名稱都一樣!因為在上次使用Tensorflow做Kaggle入門題之數字識別 - 知乎專欄的時候,我嘗試用了一下TensorFlow讀命令行參數的姿勢,這讓我非常懷疑程序內部,本地文件和遠程文件的路徑是用同一個變數表示的。於是我就打開了他的代碼,準備學習一個判定文件是否是本地,以及讀取gs://xxx文件的正確姿勢。

然後我就看到

train_input = model.generate_input_fn(n train_files,n num_epochs=num_epochs,n batch_size=train_batch_size,n)n

這怎麼看都不像是要做什麼特殊處理啊!後面都開始讀數據了!!

而且在閱讀代碼的過程中,我隱隱約約發現了一套似乎非常爽的讀取數據的姿勢,比如

reader = tf.TextLineReader(skip_header_lines=skip_header_lines)n

跳過首行的文本讀行器有沒有爽!

以及各種文本值轉embeding的便捷方法,強烈安利看一下並且實踐一波,本文是講輸入的方法的就不展開了。原代碼較大,我把相關的兩個文件貼到gist上如下。

model.py

task.py

好,回到讀GoogleStorage文件上。從他的例子代碼來看,tensorflow支持一套比較方便的讀入訓練數據的方法,相比自己讀入,一個是支持的URI比較多(比如gs://),因此方便一套代碼通吃多種平台和輸入。另一個是對數據很多需要並行多次讀取的情況,可以有效避免自己做各種隨機讀取和管理IO的工作。

藉此機會,看了下對讀入數據相關的API,TensorFlow的文檔和極客學院的中文翻譯。本文不再重複一些很基礎的概念,只對實踐過程中,文檔和翻譯沒有提及,或者google的例子代碼沒有提及的一些比較坑的點進行說明,希望大家試驗的時候少走彎路T-T。

數據輸入一共三個方式,

  1. 定義佔位符然後用feed_dict給佔位符填數據

  2. 文件讀取
  3. 定義一個常量(然後初始化它),或者定義一個變數,然後定義他的initializer為一個placeholder,然後用feed_dict給佔位符填數據

1和3簡直不能再常見,我們之前的兩個在Kaggle上做題的文章

使用Tensorflow做Kaggle入門題之數字識別 - 知乎專欄

使用Tensorflow做Kaggle入門題之泰坦尼克 - 知乎專欄

都用的方法1,方法3因為目前沒有用到BP時需要修改輸入數據的情況,暫時沒有用它。

而2到本文之前還沒有嘗試過。

2的搞法我的理解就是開闢一個線程往一個文件名隊列裡面填文件名,然後開啟一個讀取器從這個隊列裡面拿文件,然後從文件中讀出記錄並轉化為tensor。API比較少,命名也比較直觀,但是有一些坑需要注意。

首先,如果你用了我前言裡面的代碼,請注意95到97行這三行初始化代碼。

sess.run(tf.local_variables_initializer())n sess.run(tf.global_variables_initializer())n sess.run(tf.tables_initializer())n

這三行代碼我到目前只見過tf.global_variables_initializer(),但是實際上這套代碼裡面的parse_label_column要求你使用tf.tables_initializer()初始化table;其中讀文件的部分實際還要求tf.local_variables_initializer(),但是文檔中卻沒有提到任何一個initializer,翻譯的就更不用說了,因為這個有個人還在github上憤怒的提了個問。

第二,如果你只看了model.py和task.py這兩個Google ML Engine的例子代碼,會發現為什麼拷了一段代碼過來,程序會運行到一半卡住,ctrl+c都搞不定,必須kill。原因可以從文檔里發現,請注意我的代碼中99-106行代碼,它們就是參考文檔加的:

coord = tf.train.Coordinator()nthreads = tf.train.start_queue_runners(coord=coord)nx = sess.run(X)nprint(x)n# y = sess.run(Y)n# print(y)ncoord.request_stop()ncoord.join(threads)n

注意其中的coord和tf.train.start_queue_runners。

如果我不執行tf.train.start_queue_runners,實際上在代碼60行的filename_queue中會永遠不填充名字,因此會一直卡在64行read_up_to上。之所以例子代碼沒有問題,是由於它調用的 tf.contrib.learn.Experiment把這個細節藏起來了。而coord如文檔所言,就是為了實現類似Java中Thread.join(xxxx)這種多woker同步的工具。

最後,如果你加了from __future__ import print_function,那麼記得print就是函數了,直接print xxxxx會報語法錯誤。


推薦閱讀:

如何理解卷積神經網路中的權值共享?
如何評價文章「谷歌太可怕?專家:中國智能晶元引領世界」?
如何評價谷歌最近在人臉數據集上取得驚人效果的BEGAN模型?
caffe里的clip gradient是什麼意思?
如何評價MXNet發布的1.0版本?

TAG:TensorFlow | 深度学习DeepLearning |