怎麼理解tensorflow中tf.train.shuffle_batch()函數?
怎麼理解tf.train.shuffle_batch()這個函數?
感謝邀請。batch英語翻譯就是批次,一批的意思。
至於為什麼shuffle,道理和隨機採樣調查差不多。
官方API doc說的比較清楚了,一句話:Creates batches by randomly shuffling tensors.@東方不快 已經把文檔翻譯了,可以對照著看一下。深度學習的優化演算法,說白了就是梯度下降。每次的參數更新有兩種方式。
第一種,遍歷全部數據集算一次損失函數,然後算函數對各個參數的梯度,更新梯度。這種方法每更新一次參數都要把數據集里的所有樣本都看一遍,計算量開銷大,計算速度慢,不支持在線學習,這稱為Batch gradient descent,批梯度下降。
另一種,每看一個數據就算一下損失函數,然後求梯度更新參數,這個稱為隨機梯度下降,stochastic gradient descent。這個方法速度比較快,但是收斂性能不太好,可能在最優點附近晃來晃去,hit不到最優點。兩次參數的更新也有可能互相抵消掉,造成目標函數震蕩的比較劇烈。
為了克服兩種方法的缺點,現在一般採用的是一種折中手段,mini-batch gradient decent,小批的梯度下降,這種方法把數據分為若干個批,按批來更新參數,這樣,一個批中的一組數據共同決定了本次梯度的方向,下降起來就不容易跑偏,減少了隨機性。另一方面因為批的樣本數與整個數據集相比小了很多,計算量也不是很大。
Michael Nielsen在這一章節也有解釋,mini-batch是什麼,為什麼有這個東西。
Deep Learning的這一章節的5.9小節也有解釋,還給出了mini-batch的典型值。結合上面給出的中文解釋,再看這兩個小節,應該會對mini-batch有所理解。
瀉藥,tensorflow的官方文檔中對這個函數有如下的描寫
這個函數的功能是:Creates batches by randomly shuffling tensors.
但需要注意的是它是一種圖運算,要跑在sess.run()里
This function adds the following to the current Graph:
在運行這個函數時它會在當前圖上創建如下的東西:
- A shuffling queue into which tensors from tensors are enqueued.
- 一個亂序的隊列,進隊的正是傳入的tensors
- A dequeue_many operation to create batches from the queue.
- 一個dequeue_many的操作從隊列中推出成batch的tensor
- A QueueRunner to QUEUE_RUNNER collection, to enqueue the tensors from tensors.
- 一個QueueRunner的線程,正是這個線程將傳入的數據推進隊列中.
不鼓勵在知乎中問代碼問題,下次請整理好代碼和日誌在Stackoverflow中提問。
根據錯誤「has insufficient elements(requested128,current size 64)」,猜測原因是TFRecrods中數據不夠,每次batch要求128,但隊列中只有64,建議把batch size調小,或者把filename queue的參數num_epochs設為None試試。
如有問題請參考這份可運行的代碼 GitHub - tobegit3hub/deep_recommend_system: Deep learning recommend system with TensorFlow推薦閱讀:
※tensorflow的自動求導具體是在哪部分代碼里實現的?
※如何看待Theano宣布終止開發 ?
※如何評價 Google 發布的 Tensor Processing Unit?
※誰能詳細講解一下TensorFlow Playground所展示的神經網路的概念?
TAG:TensorFlow |