標籤:

怎麼理解tensorflow中tf.train.shuffle_batch()函數?

怎麼理解tf.train.shuffle_batch()這個函數?


感謝邀請。

batch英語翻譯就是批次,一批的意思。

至於為什麼shuffle,道理和隨機採樣調查差不多。

官方API doc說的比較清楚了,一句話:

Creates batches by randomly shuffling tensors.

@東方不快 已經把文檔翻譯了,可以對照著看一下。

題外話:

我覺得題主問這個問題可能是和隨機梯度法里的mini-batch那個batch搞混了。

我的blog中也剛好闡述過什麼是mini-batch。

如何理解TensorFlow中的batch和minibatch

這是一個在TF中,或者說很多DL的框架中很常見的詞。

這個解釋我覺得比較貼切也比較容易理解。引用如下:

深度學習的優化演算法,說白了就是梯度下降。每次的參數更新有兩種方式。

第一種,遍歷全部數據集算一次損失函數,然後算函數對各個參數的梯度,更新梯度。這種方法每更新一次參數都要把數據集里的所有樣本都看一遍,計算量開銷大,計算速度慢,不支持在線學習,這稱為Batch gradient descent,批梯度下降。

另一種,每看一個數據就算一下損失函數,然後求梯度更新參數,這個稱為隨機梯度下降,stochastic gradient descent。這個方法速度比較快,但是收斂性能不太好,可能在最優點附近晃來晃去,hit不到最優點。兩次參數的更新也有可能互相抵消掉,造成目標函數震蕩的比較劇烈。

為了克服兩種方法的缺點,現在一般採用的是一種折中手段,mini-batch gradient decent,小批的梯度下降,這種方法把數據分為若干個批,按批來更新參數,這樣,一個批中的一組數據共同決定了本次梯度的方向,下降起來就不容易跑偏,減少了隨機性。另一方面因為批的樣本數與整個數據集相比小了很多,計算量也不是很大。

Michael Nielsen在這一章節也有解釋,mini-batch是什麼,為什麼有這個東西。

Deep Learning的這一章節的5.9小節也有解釋,還給出了mini-batch的典型值。

結合上面給出的中文解釋,再看這兩個小節,應該會對mini-batch有所理解。


瀉藥,tensorflow的官方文檔中對這個函數有如下的描寫

這個函數的功能是:Creates batches by randomly shuffling tensors.

但需要注意的是它是一種圖運算,要跑在sess.run()里

This function adds the following to the current Graph:

在運行這個函數時它會在當前圖上創建如下的東西:

  • A shuffling queue into which tensors from tensors are enqueued.
  • 一個亂序的隊列,進隊的正是傳入的tensors
  • A dequeue_many operation to create batches from the queue.
  • 一個dequeue_many的操作從隊列中推出成batch的tensor
  • A QueueRunner to QUEUE_RUNNER collection, to enqueue the tensors from tensors.
  • 一個QueueRunner的線程,正是這個線程將傳入的數據推進隊列中.

把數據放在隊列里有很多好處,可以完成訓練數據和測試數據的解耦,同時有利於寫成分散式訓練(個人理解),但需要注意的是在取數據的時候,容易造成堵塞的情況.

這時候,應該需要截獲超時異常來強制停止線程.


不鼓勵在知乎中問代碼問題,下次請整理好代碼和日誌在Stackoverflow中提問。

根據錯誤「has insufficient elements(requested128,current size 64)」,猜測原因是TFRecrods中數據不夠,每次batch要求128,但隊列中只有64,建議把batch size調小,或者把filename queue的參數num_epochs設為None試試。

如有問題請參考這份可運行的代碼 GitHub - tobegit3hub/deep_recommend_system: Deep learning recommend system with TensorFlow


推薦閱讀:

tensorflow的自動求導具體是在哪部分代碼里實現的?
如何看待Theano宣布終止開發 ?
如何評價 Google 發布的 Tensor Processing Unit?
誰能詳細講解一下TensorFlow Playground所展示的神經網路的概念?

TAG:TensorFlow |