重新梳理tensorflow
終於換了一個高檔顯卡,所以對tf重新進行梳理,趕緊追上時代潮流
1)使用MNIST數據集
import input_data
mnist = input_data.read_data_sets(MNIST_data/,one_hot = True)
2)使用tf定義placeholder和Variable
x = tf.placeholder(float,[None,784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
3)使用定義的placeholder和Variable計算y值
y = tf.nn.softmax(tf.matmult(x,W)+b)
4)定義誤差函數
cross_entropy = tf.reduce_sum(y_*tf.log(y))
5)定義訓練的operator
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
6)初始化所有variable
sess.run(tf.global_variables_initializer())
7)對模型進行訓練
batch_xs,batch_ys = mnist.train.next_batch(100)
sess.run(train_step,feed_dict={x:batch_xs,y_batch_ys})
8)對模型進行評估
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_rediction,float))
sess.run(accuracy,feed_dict={})
歸納一下,tf通常套路:
1)定義placeholder作為輸入和真實結果
2)定義Variable作為參數
3)中間的值可以直接使用tf.nn庫中激活函數形式加以定義
4)定義loss函數
5)定義train_op,使用的是tf.train庫中優化方法,優化目標為定義的loss函數
6)在tf.Session中首先全部初始化Variable(global_variables_initializer),注意sess均使用的是run函數
7)定義循環次數,在for循環內部使用sess.run執行train_op,完成模型訓練
8)對於模型評估,直接使用真實結果placeholder與定義的模型輸出結果進行比較即可
推薦閱讀:
※NLP(2) Tensorflow 文本- 價格建模 Part2
※五分鐘喝不完一杯咖啡,但五分鐘可以帶你入門TensorFlow
※TensorFlow 訓練好模型參數的保存和恢復代碼
※TF使用例子-情感分類
※TensorFlow 教程 #04 - 保存 & 恢復
TAG:TensorFlow |