TensorFlow中的那些高級API
摘要: 在這篇文章中,我們將看到一個使用了最新高級構件的例子,包括Estimator(估算器)、Experiment(實驗)和Dataset(數據集)。值得注意的是,你可以獨立地使用Experiment和Dataset。不妨進來看看作者是如何玩轉這些高級API的。
TensorFlow擁有很多庫,比如Keras、TFLearn和Sonnet,對於模型訓練來說,使用這些庫比使用低級功能更簡單。儘管Keras的API目前正在添加到TensorFlow中去,但TensorFlow本身就提供了一些高級構件,而且最新的1.3版本中也引入了一些新的構件。
在這篇文章中,我們將看到一個使用了這些最新的高級構件的例子,包括Estimator(估算器)、Experiment(實驗)和Dataset(數據集)。值得注意的是,你可以獨立地使用Experiment和Dataset。我在這裡假設你已經了解TensorFlow的基礎知識;如果沒有的話,那麼TensorFlow官網上提供的教程值得學習。
Experiment、Estimator和DataSet框架以及它們之間的交互。
我們在本文中將使用MNIST作為數據集。這是一個使用起來很簡單的數據集,可以從TensorFlow官網獲取到。你可以在這個gist中找到完整的代碼示例。使用這些框架的其中一個好處是,我們不需要直接處理圖和會話。
Estimator(估算器)類
Estimator類代表了一個模型,以及如何對這個模型進行訓練和評估。我們可以像下面這段代碼創建一個Estimator:
要創建Estimator,需要傳入一個模型函數、一組參數和一些配置。
- 傳入的**參數**應該是模型超參數的一個集合。這可以是一個dictionary,但是我們將在這個例子中把它表示成一個HParams對象,就像namedtuple一樣。
- 傳入的**配置**用於指定如何運行訓練和評估,以及在哪裡存儲結果。這個配置是一個RunConfig對象,該對象會把模型運行環境相關的信息告訴Estimator。
- 模型函數是一個Python函數,它根據給定的輸入構建模型。
模型函數
模型函數是一個Python函數,並作為一級函數傳遞給Estimator。稍後我們會看到,TensorFlow在其他地方也使用了一級函數。將模型表示為一個函數的好處是可以通過實例化函數來多次創建模型。模型可以在訓練過程中用不同的輸入重新創建,例如,在訓練過程中運行驗證測試。
模型函數把**輸入特徵**作為參數,將相應的**標籤**作為張量。它也能以某種方式來告知用戶模型是在訓練、評估或是在執行推理。模型函數的最後一個參數是**超參數**集合,它們與傳遞給Estimator的超參數集合相同。模型函數返回一個**EstimatorSpec**對象,該對象定義了一個完整的模型。
EstimatorSpec對象用於對操作進行預測、損失、訓練和評估,因此,它定義了一個用於訓練、評估和推理的完整的模型圖。由於EstimatorSpec只可用於常規的TensorFlow操作,因此,我們可以使用像TF-Slim這樣的框架來定義模型。
Experiment(實驗)類
Experiment類定義了如何訓練模型,它與Estimator完美地集成在一起。我們可以像如下代碼創建一個Experiment對象:
以下幾種情況會把Experiment對象作為輸入:
- 一個**estimator**(例如我們上面定義的)。
- 作為一級函數**訓練和評估數據**。這裡使用了與前面提到的模型函數相同的概念。如果需要的話,通過傳入函數而不是操作,可以重新創建輸入圖。稍後我們還會談到這個。
- 訓練和評估hook(鉤子)。鉤子可用於保存或監視特定的內容,或者在圖或會話中設置某些操作。例如,我們將其傳入到操作中,幫助初始化數據載入器。
- 描述需要訓練多久以及何時評估的各種參數。
一旦定義了experiment,我們就可以像下面這段代碼那樣使用learn_runner.run來運行它訓練和評估模型:
與模型函數和數據函數一樣,learn_runner將一個創建experiment的函數作為參數傳入。
Dataset(數據集)類
我們將使用Dataset類和相應的Iterator來表示數據的訓練和評估,以及創建在訓練過程中迭代數據的數據饋送器。 在本示例中,我們將使用在Tensorflow中可用的MNIST數據,並為其構建一個Dataset包裝。例如,我們將把訓練輸入數據表示為:
調用這個get_train_inputs將返回一個一級函數,用於在TensorFlow圖中創建數據載入操作,以及返回一個用於初始化迭代器的Hook。
本示例中使用的MNIST數據最初是一個Numpy數組。我們創建了一個佔位符張量來獲取數據;使用佔位符的目的是為了避免數據的複製。接下來,我們在from_tensor_slices的幫助下創建一個切片數據集。我們要確保該數據集可以運行無限次數,並且數據被重新洗牌並放入指定大小的批次中。
要迭代數據,就需要從數據集中創建一個迭代器。由於我們正在使用佔位符,因此需要使用NumPy數據在相關會話中對佔位符進行初始化。可以通過創建一個可初始化的迭代器來實現這個。在創建圖的時候,將創建一個自定義的IteratorInitializerHook對象來初始化迭代器:
IteratorInitializerHook繼承自SessionRunHook。這個鉤子將在相關會話創建後立即調用after_create_session,並使用正確的數據初始化佔位符。這個鉤子由我們的get_train_inputs函數返回,並在創建時傳遞給Experiment對象。
train_inputs函數返回的數據載入操作是TensorFlow的操作,該操作每次評估時都會返回一個新的批處理。
運行代碼
現在,我們已經定義了所有內容,可以使用下面這個命令運行代碼了:
如果不傳入參數,它將使用文件開頭的默認標誌來確定數據和模型保存的位置。
在訓練過程中,在終端上會輸出這段時間內的全局步驟、損失和準確性等信息。除此之外,Experiment和Estimator框架將記錄TensorBoard可視化的某些統計信息。如果我們運行這個命令:
在訓練過程中,在終端上會輸出這段時間內的全局步驟、損失和準確性等信息。除此之外,Experiment和Estimator框架將記錄TensorBoard可視化的某些統計信息。如果我們運行這個命令:
TensorBoard可視化中的評估準確度
我寫這篇文章,是因為我在編寫代碼示例時,無法找到有關Tensorflow Estimator 、Experiment和Dataset框架太多的信息和示例。我希望這篇文章能向你簡要介紹一下這些框架是如何工作的,它們採用了什麼樣的抽象方法以及如何使用它們。如果你對使用這些框架感興趣,下面我將介紹一些注意點和其他的文檔。
有關Estimator、Experiment和Dataset框架的注意點
- 有一篇名為《TensorFlow Estimators:掌握高級機器學習框架中的簡單性與靈活性》的文章描述了Estimator框架的高級別設計。
- TensorFlow官網上有更多有關使用Dataset API的文檔。
- 有2個版本的Estimator類。在這個例子中,我們使用的是tf.estimator.Estimator,但在tf.contrib.learn.Estimator中還有一個較老的不穩定版本。
- 也有2個版本的RunConfig類。當我們使用tf.contrib.learn.RunConfig的時候,另外還有一個tf.estimator.RunConfig的版本。我無法讓後者與Experiment框架結合在一起,所以我堅持使用tf.contrib版本。
- 雖然我們在這個例子中沒有使用它們,但是Estimator框架定義了典型模型(如分類器和回歸器)的預定義估算器。這些預定義的估算器使用起來很簡單,並附有詳細的教程。
- TensorFlow還定義了模型「頭」的抽象,這個「頭」是架構的上層,定義了損失、評估和訓練操作。這個「頭」負責定義模型函數和所有必需的操作。你可以在tf.contrib.learn.Head中找到一個版本。在較新的Estimator框架中也有一個原型版本。在這個例子中我們不打算使用,因為它的開發非常不穩定。
- 本文使用了TensorFlow slim框架來定義模型的架構。 Slim是一個用於定義TensorFlow中複雜模型的輕量級庫。它定義了預定義的架構和預先訓練的模型。
- 更複雜的例子,請參見:https://gist.github.com/peterroelants/9956ec93a07ca4e9ba5bc415b014bcca#file-mnist_estimator-py
文章原標題《Higher-Level APIs in TensorFlow》,作者:Peter Roelants,譯者:夏天,審校:主題曲。
文章為簡譯,更為詳細的內容,請查看原文需要爬梯,不方便的同學也可以下載下方的PDF附件,閱讀原文內容。
附件下載: Higher-L...[【方向】].1504516657.pdf
更多技術乾貨敬請關注云棲社區知乎機構號:阿里云云棲社區 - 知乎
推薦閱讀: