高級API用法示例
tf.contrib.learn Quickstart
TensorFlow的機器學習高級API(tf.contrib.learn)使配置、訓練、評估不同的學習模型變得更加容易。在這個教程里,你將使用tf.contrib.learn在Iris data set上構建一個神經網路分類器。代碼有一下5個步驟:
- 在TensorFlow數據集上載入Iris
- 構建神經網路
- 用訓練數據擬合
- 評估模型的準確性
- 在新樣本上分類
Complete Neural Network Source Code
這裡是神經網路的源代碼:
from __future__ import absolute_import
from __future__ import divisionfrom __future__ import print_functionimport os
import urllibimport numpy as npimport tensorflow as tf# Data setsIRIS_TRAINING = "iris_training.csv"IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"IRIS_TEST = "iris_test.csv"IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"def main():# If the training and test sets arent stored locally, download them.
if not os.path.exists(IRIS_TRAINING): raw = urllib.urlopen(IRIS_TRAINING_URL).read() with open(IRIS_TRAINING, "w") as f: f.write(raw) if not os.path.exists(IRIS_TEST): raw = urllib.urlopen(IRIS_TEST_URL).read() with open(IRIS_TEST, "w") as f: f.write(raw) # Load datasets.training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int,
features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32) # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model")# Define the training inputs def get_train_inputs():
x = tf.constant(training_set.data) y = tf.constant(training_set.target) return x, y # Fit model. classifier.fit(input_fn=get_train_inputs, steps=2000) # Define the test inputs def get_test_inputs(): x = tf.constant(test_set.data) y = tf.constant(test_set.target) return x, y # Evaluate accuracy.accuracy_score = classifier.evaluate(input_fn=get_test_inputs, steps=1)["accuracy"]
print("nTest Accuracy: {0:f}n".format(accuracy_score)) # Classify two new flower samples. def new_samples():return np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32) predictions = list(classifier.predict(input_fn=new_samples)) print( "New Samples, Class Predictions: {}n" .format(predictions))if __name__ == "__main__":main()
Load the Iris CSV data to TensorFlow
Iris data set包含了150行數據,3個種類:Iris setosa, Iris virginica, and Iris versicolor.
每一行包括了以下的數據:花萼的寬度,長度,花瓣的寬度,花的種類。花的種類有整數表示,0表示Iris setosa, 1表示Iris virginica, 2表示Iris versicolor.
這裡,Iris數據隨機分割成了兩組不同的CSV文件:
- 120個樣本的訓練數據(iris_training.csv)
- 30個樣本的測試數據(iris_test.csv).
開始時,首先引進所有必要的模塊,然後定義下載存儲數據集的路徑:
from __future__ import absolute_
import from __future__ import division from __future__ import print_functionimport os import urllib
import tensorflow as tfimport numpy as np IRIS_TRAINING = "iris_training.csv" IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv" IRIS_TEST = "iris_test.csv" IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"然後,如果訓練和測試集沒有在本地存儲,下載:
if not os.path.exists(IRIS_TRAINING):
raw = urllib.urlopen(IRIS_TRAINING_URL).read()with open(IRIS_TRAINING,w) as f:
f.write(raw) if not os.path.exists(IRIS_TEST): raw = urllib.urlopen(IRIS_TEST_URL).read() with open(IRIS_TEST,w) as f: f.write(raw)然後,用learn.datasets.base的load_csv_with_header()方法載入訓練集和測試集成Dataset S,load_csv_with_header()包涵一下三個參數:
- filename,CSV文件的路徑
- target_dtype,數據集目標值的numpy數據類型
- features_dtype,數據集特徵值的numpy數據類型
這裡,目標是花的種類,是0-2的整數,所以數據類型是np.int:
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING,
target_dtype=np.int, features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)tf.contrib.learn中的Dataset S是tuple,你可以通過data,target來訪問特徵值和目標值,比如,training_set.data,training_set.target
Construct a Deep Neural Network Classifier
tf.contrib.learn提供了多種預定義的模型,稱為 Estimator S,你可以用「黑盒子」在你的數據上來訓練和評估節點。這裡,你講配置深度神經網路分類器來擬合Iris數據,你可以用tf.contrib.learn.DNNClassifier作為示例:
# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively.classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model")首先定義特徵所在的列,有4個特徵,所以dimension設定為4.
然後,構建了DNNClassifier,包含以下參數:
- feature_columns=feature_columns.上面定義的特徵的列
- hidden_units=[10, 20, 10]. 三個隱層,分別包含10,20,10個神經元
- n_classes=3.三個目標
- model_dir=/tmp/iris_model.訓練模型時保存的斷點數據
Describe the training input pipeline
tf.contrib.learn API使用輸入函數,創建TensorFlow節點來生成模型數據。這裡,數據比較小,可以放在tf.constant。
# Define the test inputs def get_train_inputs():
x = tf.constant(training_set.data)
y = tf.constant(training_set.target) return x, yFit the DNNClassifier to the Iris Training Data
配置了DNN分類器,你可以用fit方法來擬合數據,傳遞get_train_inputs到input_fn參數中,循環訓練2000次:
# Fit model. classifier.fit(input_fn=get_train_inputs, steps=2000)
等效於:
classifier.fit(x=training_set.data, y=training_set.target, steps=1000) classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
如果你想追蹤訓練模型,你可以用TensorFlow monitor來執行節點的日誌。
「Logging and Monitoring Basics with tf.contrib.learn」
Evaluate Model Accuracy
你已經用訓練數據擬合了模型,現在,你可以用evaluate方法在測試集上評估準確性。像fit一樣,evaluate也需要一個輸入函數來構建輸入的通道,並返回評估結果的字典。
# Define the test inputs def get_test_inputs():
x = tf.constant(test_set.data) y = tf.constant(test_set.target) return x, y # Evaluate accuracy.accuracy_score = classifier.evaluate(input_fn=get_test_inputs, steps=1)["accuracy"] print("nTest Accuracy: {0:f}n".format(accuracy_score))運行整個腳本,列印:
Test Accuracy: 0.966667
Classify New Samples
用predict()方法來分類新的樣本,比如,你有下面的兩個新樣本:
predict方法返回一個generator,可以轉換成list
# Classify two new flower samples.
def new_samples():
return np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]],dtype=np.float32) predictions = list(classifier.predict(input_fn=new_samples))print(
"New Samples, Class Predictions: {}n" .format(predictions))結果大致如下:
New Samples, Class Predictions: [1 2]
原文鏈接:http://www.jianshu.com/p/daf5b68d8bc8
推薦閱讀: