標籤:

重新梳理tensorflow

終於換了一個高檔顯卡,所以對tf重新進行梳理,趕緊追上時代潮流

1)使用MNIST數據集

import input_data

mnist = input_data.read_data_sets(MNIST_data/,one_hot = True)

2)使用tf定義placeholder和Variable

x = tf.placeholder(float,[None,784])

W = tf.Variable(tf.zeros([784,10]))

b = tf.Variable(tf.zeros([10]))

3)使用定義的placeholder和Variable計算y值

y = tf.nn.softmax(tf.matmult(x,W)+b)

4)定義誤差函數

cross_entropy = tf.reduce_sum(y_*tf.log(y))

5)定義訓練的operator

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

6)初始化所有variable

sess.run(tf.global_variables_initializer())

7)對模型進行訓練

batch_xs,batch_ys = mnist.train.next_batch(100)

sess.run(train_step,feed_dict={x:batch_xs,y_batch_ys})

8)對模型進行評估

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_rediction,float))

sess.run(accuracy,feed_dict={})

歸納一下,tf通常套路:

1)定義placeholder作為輸入和真實結果

2)定義Variable作為參數

3)中間的值可以直接使用tf.nn庫中激活函數形式加以定義

4)定義loss函數

5)定義train_op,使用的是tf.train庫中優化方法,優化目標為定義的loss函數

6)在tf.Session中首先全部初始化Variable(global_variables_initializer),注意sess均使用的是run函數

7)定義循環次數,在for循環內部使用sess.run執行train_op,完成模型訓練

8)對於模型評估,直接使用真實結果placeholder與定義的模型輸出結果進行比較即可


推薦閱讀:

NLP(2) Tensorflow 文本- 價格建模 Part2
五分鐘喝不完一杯咖啡,但五分鐘可以帶你入門TensorFlow
TensorFlow 訓練好模型參數的保存和恢復代碼
TF使用例子-情感分類
TensorFlow 教程 #04 - 保存 & 恢復

TAG:TensorFlow |