在TensorFlow中使用pipeline載入數據
作者:hzyido
原文鏈接:https://www.jianshu.com/p/12b52e54a63c查看更多的專業文章、課程信息和產品信息,請移步至「人工智慧LeadAI」公眾號;官網:www.leadai.org.
正文共2028個字,6張圖,預計閱讀時間6分鐘。
前面對TensorFlow的多線程做了測試,接下來就利用多線程和Queue pipeline地載入數據。數據流如下圖所示:
首先,A、B、C三個文件通過RandomShuffle進程被隨機載入到FilenameQueue里,然後Reader1和Reader2進程同FilenameQueue里取文件名讀取文件,讀取的內容再被放到ExampleQueue里。最後,計算進程會從ExampleQueue里取數據。各個進程獨立操作,互不影響,這樣可以加快程序速度。
我們簡單地生成3個樣本文件。
#生成三個樣本文件,每個文件包含5列,假設前4列為特徵,最後1列為標籤data = np.zeros([20,5])np.savetxt(file0.csv, data, fmt=%d, delimiter=,)data += 1np.savetxt(file1.csv, data, fmt=%d, delimiter=,)data += 1np.savetxt(file2.csv, data, fmt=%d, delimiter=,)
然後,創建pipeline數據流。
#定義FilenameQueuefilename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])#定義ExampleQueueexample_queue = tf.RandomShuffleQueue( capacity=1000, min_after_dequeue=0, dtypes=[tf.int32,tf.int32], shapes=[[4],[1]])#讀取CSV文件,每次讀一行reader = tf.TextLineReader()key, value = reader.read(filename_queue)#對一行數據進行解碼record_defaults = [[1], [1], [1], [1], [1]]col1, col2, col3, col4, col5 = tf.decode_csv( value, record_defaults=record_defaults)features = tf.stack([col1, col2, col3, col4])#將特徵和標籤push進ExampleQueueenq_op = example_queue.enqueue([features, [col5]])#使用QueueRunner創建兩個進程載入數據到ExampleQueueqr = tf.train.QueueRunner(example_queue, [enq_op]*2)#使用此方法方便後面tf.train.start_queue_runner統一開始進程tf.train.add_queue_runner(qr)xs = example_queue.dequeue()with tf.Session() as sess: coord = tf.train.Coordinator()#開始所有進程 threads = tf.train.start_queue_runners(coord=coord) for i in range(200): x = sess.run(xs) print(x) coord.request_stop() coord.join(threads)
以上我們採用for循環step_num次來控制訓練迭代次數。我們也可以通過tf.train.string_input_producer的num_epochs參數來設置FilenameQueue循環次數來控制訓練,當達到num_epochs時,TensorFlow會拋出OutOfRangeError異常,通過捕獲該異常,停止訓練。
filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6)...with tf.Session() as sess: sess.run(tf.initialize_local_variables()) #必須加上這句話,否則報錯! coord = tf.train.Coordinator()#開始所有進程 threads = tf.train.start_queue_runners(coord=coord) try: while not coord.should_stop(): x = sess.run(xs) print(x) except tf.errors.OutOfRangeError: print(Done training -- epch limit reached) finally: coord.request_stop()
捕獲到異常時,請求結束所有進程。
原文: 在TensorFlow中使用pipeline載入數據(https://goo.gl/jbVPjM)
推薦閱讀:
※tf 三
※學習筆記TF003:數據流圖定義、執行、可視化
※把故宮裝進手機,讓柯潔走向最強!谷歌在北京秀出四個AI商用大招
※tf.reset_default_graph
※Tensorflow學習(一):模型保存與恢復
TAG:TensorFlow |