如何使用TensorFlow中的高級API:Estimator、Experiment和Dataset

近日,背景調查公司 Onfido 研究主管 Peter Roelants 在 Medium 上發表了一篇題為《Higher-Level APIs in TensorFlow》的文章,通過實例詳細介紹了如何使用 TensorFlow 中的高級 API(Estimator、Experiment 和 Dataset)訓練模型。值得一提的是 Experiment 和 Dataset 可以獨立使用。這些高級 API 已被最新發布的 TensorFlow1.3 版收錄。

TensorFlow 中有許多流行的庫,如 Keras、TFLearn 和 Sonnet,它們可以讓你輕鬆訓練模型,而無需接觸哪些低級別函數。目前,Keras API 正傾向於直接在 TensorFlow 中實現,TensorFlow 也在提供越來越多的高級構造,其中的一些已經被最新發布的 TensorFlow1.3 版收錄。

在本文中,我們將通過一個例子來學習如何使用一些高級構造,其中包括 Estimator、Experiment 和 Dataset。閱讀本文需要預先了解有關 TensorFlow 的基本知識。

Experiment、Estimator 和 DataSet 框架和它們的相互作用(以下將對這些組件進行說明)

在本文中,我們使用 MNIST 作為數據集。它是一個易於使用的數據集,可以通過 TensorFlow 訪問。你可以在這個 gist 中找到完整的示例代碼。使用這些框架的一個好處是我們不需要直接處理圖形和會話

Estimator

Estimator(評估器)類代表一個模型,以及這些模型被訓練和評估的方式。我們可以這樣構建一個評估器:

為了構建一個 Estimator,我們需要傳遞一個模型函數,一個參數集合以及一些配置。

  • 參數應該是模型超參數的集合,它可以是一個字典,但我們將在本示例中將其表示為 HParams 對象,用作 namedtuple。
  • 該配置指定如何運行訓練和評估,以及如何存出結果。這些配置通過 RunConfig 對象表示,該對象傳達 Estimator 需要了解的關於運行模型的環境的所有內容。
  • 模型函數是一個 Python 函數,它構建了給定輸入的模型(見後文)。

模型函數

模型函數是一個 Python 函數,它作為第一級函數傳遞給 Estimator。稍後我們就會看到,TensorFlow 也會在其他地方使用第一級函數。模型表示為函數的好處在於模型可以通過實例化函數不斷重新構建。該模型可以在訓練過程中被不同的輸入不斷創建,例如:在訓練期間運行驗證測試。

模型函數將輸入特徵作為參數,相應標籤作為張量。它還有一種模式來標記模型是否正在訓練、評估或執行推理。模型函數的最後一個參數是超參數的集合,它們與傳遞給 Estimator 的內容相同。模型函數需要返回一個 EstimatorSpec 對象——它會定義完整的模型。

EstimatorSpec 接受預測,損失,訓練和評估幾種操作,因此它定義了用於訓練,評估和推理的完整模型圖。由於 EstimatorSpec 採用常規 TensorFlow Operations,因此我們可以使用像 TF-Slim 這樣的框架來定義自己的模型。

Experiment

Experiment(實驗)類是定義如何訓練模型,並將其與 Estimator 進行集成的方式。我們可以這樣創建一個實驗類:

Experiment 作為輸入:

  • 一個 Estimator(例如上面定義的那個)。
  • 訓練和評估數據作為第一級函數。這裡用到了和前述模型函數相同的概念,通過傳遞函數而非操作,如有需要,輸入圖可以被重建。我們會在後面繼續討論這個概念。
  • 訓練和評估鉤子(hooks)。這些鉤子可以用於監視或保存特定內容,或在圖形和會話中進行一些操作。例如,我們將通過操作來幫助初始化數據載入器。
  • 不同參數解釋了訓練時間和評估時間。
  • 一旦我們定義了 experiment,我們就可以通過 learn_runner.run 運行它來訓練和評估模型:

與模型函數和數據函數一樣,函數中的學習運算符將創建 experiment 作為參數。

Dataset

我們將使用 Dataset 類和相應的 Iterator 來表示我們的訓練和評估數據,並創建在訓練期間迭代數據的數據饋送器。在本示例中,我們將使用 TensorFlow 中可用的 MNIST 數據,並在其周圍構建一個 Dataset 包裝器。例如,我們把訓練的輸入數據表示為:

調用這個 get_train_inputs 會返回一個一級函數,它在 TensorFlow 圖中創建數據載入操作,以及一個 Hook 初始化迭代器。

本示例中,我們使用的 MNIST 數據最初表示為 Numpy 數組。我們創建一個佔位符張量來獲取數據,再使用佔位符來避免數據被複制。接下來,我們在 from_tensor_slices 的幫助下創建一個切片數據集。我們將確保該數據集運行無限長時間(experiment 可以考慮 epoch 的數量),讓數據得到清晰,並分成所需的尺寸。

為了迭代數據,我們需要在數據集的基礎上創建迭代器。因為我們正在使用佔位符,所以我們需要在 NumPy 數據的相關會話中初始化佔位符。我們可以通過創建一個可初始化的迭代器來實現。創建圖形時,我們將創建一個自定義的 IteratorInitializerHook 對象來初始化迭代器:

IteratorInitializerHook 繼承自 SessionRunHook。一旦創建了相關會話,這個鉤子就會調用 call after_create_session,並用正確的數據初始化佔位符。這個鉤子會通過 get_train_inputs 函數返回,並在創建時傳遞給 Experiment 對象。

train_inputs 函數返回的數據載入操作是 TensorFlow 操作,每次評估時都會返回一個新的批處理。

運行代碼

現在我們已經定義了所有的東西,我們可以用以下命令運行代碼:

如果你不傳遞參數,它將使用文件頂部的默認標誌來確定保存數據和模型的位置。訓練將在終端輸出全局步長、損失、精度等信息。除此之外,實驗和估算器框架將記錄 TensorBoard 可以顯示的某些統計信息。如果我們運行:

我們就可以看到所有訓練統計數據,如訓練損失、評估準確性、每步時間和模型圖。

評估精度在 TensorBoard 中的可視化

在 TensorFlow 中,有關 Estimator、Experiment 和 Dataset 框架的示例很少,這也是本文存在的原因。希望這篇文章可以向大家介紹這些架構工作的原理,它們應該採用哪些抽象方法,以及如何使用它們。如果你對它們很感興趣,以下是其他相關文檔。

關於 Estimator、Experiment 和 Dataset 的注釋

  • 論文《TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks》:terrytangyuan.github.io
  • Using the Dataset API for TensorFlow Input Pipelines:tensorflow.org/versions
  • tf.estimator.Estimator:tensorflow.org/api_docs
  • tf.contrib.learn.RunConfig:tensorflow.org/api_docs
  • tf.estimator.DNNClassifier:tensorflow.org/api_docs
  • tf.estimator.DNNRegressor:tensorflow.org/api_docs
  • Creating Estimators in tf.estimator:tensorflow.org/extend/e
  • tf.contrib.learn.Head:tensorflow.org/api_docs
  • 本文用到的 Slim 框架:github.com/tensorflow/m

完整示例

推理訓練模式

在訓練模型後,我們可以運行 estimateator.predict 來預測給定圖像的類別。可使用以下代碼示例。

原文鏈接:medium.com/onfido-tech/

選自Medium

作者:Peter Roelants

機器之心編譯

參與:李澤南、黃小天

本文為機器之心編譯,轉載請聯繫本公眾號獲得授權。

推薦閱讀:

TensorFlow 教程 #01 - 簡單線性模型
深入淺出TensorFlow(六)TensorFlow高層封裝
TF使用例子-情感分類
tf.set_random_seed

TAG:TensorFlow | 人工智能 | API |