標籤:

TensorFlow的高階介面Estimator的使用(1)

本文摘自人民郵電出版社非同步社區epubit.com.cn/article/1

日話題討論,贈送非同步新

在《TensorFlow機器學習項目實戰》的4.4節,作者使用了skflow。skflow剛出來的時候火了一陣,但是介面變化非常頻繁,所以後來用的人也越來越少,也導致4.4的程序不能運行了。

但是最近發布的TensorFlow 1.4中,我們發現該模塊已經集成到了核心模塊,意味著介面基本穩定下來,並有推廣使用的趨勢。所以我把4.4的程序重新用Estimator寫了一下,變數名基本保持不變,代碼如下:

# -*- coding: utf-8 -*- import tensorflow as tf from sklearn import datasets, metrics, preprocessing import numpy as np import pandas as pd import os df = pd. read_csv("data/CHD.csv", header=0) print( df.describe()) X=df[age]. astype(float) feature_columns = [tf.contrib.layers.real_valued_column ("X", dimension=1)] classifier = tf.estimator.LinearClassifier (feature_columns=feature_columns, model_dir=os.path.join(".","tmp","logistic")) #classifier = tf.estimator.LinearClassifier(feature_columns=feature_columns) input_fn_train= tf.estimator.inputs.numpy_input_fn( x={"X" : np.array(X)}, y=np.array(df[chd]), batch_size=2, num_epochs=None, shuffle=True) classifier.train(input_fn=input_fn_train,steps=2000) #模型的準確度 score = classifier.evaluate(input_fn=input_fn_train,steps=50) ["accuracy"] print("Accuracy: %f" % score)

註:這段程序可以在Ubuntu和MacOS下面跑,但是Windows下面還不行,是路徑的問題。這應該是Estimator的一個BUG,在contrib.learn下也是一樣的不行。如果想在windows下,一定要用注釋掉的部分。

這裡面最難寫的是input_fn函數,也是最重要的函數,我在這段程序中直接使用了numpy_input_fn來構建。[1]中除了這個方法還給出了從pandas構建的方法,大家可以自己嘗試。

input_fn帶來了一個好處,就是可以按照生產者消費者模式讀取數據,具體的解釋可以參考[2]。簡單的解釋,就是IO一般都比較慢,我們需要在數據處理的過程中進行讀取數據,那樣就可以充分的節省時間,這樣就設計多線程在後台不斷的取數據。

feature_colums的構建需要一定的技巧,這個主要參考[3]

另外的一個變化就是模型的準確度不再是用metric模塊,而是Estimator自帶的模塊。

如果大家有什麼問題歡迎留言

Reference

[1] 為Estimator構建輸入函數

[2] 從零開始山寨Caffe·陸:IO系統(一)

[3] tf.estimator Quickstart

本文摘自人民郵電出版社非同步社區,點擊下方閱讀原文查看更多。

延伸推薦

點擊關鍵詞閱讀更多新書:

Python|機器學習|Kotlin|Java|移動開發|機器人|有獎活動|Web前端|書單

在「非同步圖書」後台回復「關注」,即可免費獲得2000門在線視頻課程;推薦朋友關注根據提示獲取贈書鏈接,免費得非同步圖書一本。趕緊來參加哦!

點擊閱讀原文,查看本書更多信息

掃一掃上方二維碼,回復「關注」參與活動!


推薦閱讀:

YJango的前饋神經網路--代碼LV1
TensorFlow教程 - 預告
隨機森林再複習
TensorFlow 的簡單例子

TAG:TensorFlow |