TensorFlow RNN Tutorial and Code

Analysis:

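I have been studying TensorFlow for a while, so I am working through the tutorials on GitHub, typing each one out myself and organizing my understanding along the way.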

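Blog: TensorFlow installation and TensorFlow tutorials. TensorFlowNews shares original projects in artificial intelligence, machine learning, deep learning, neural networks, computer vision, and natural language processing.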

The RNN part is organized in three steps:

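Define the parameters, both data-related (n_input, n_steps, n_classes) and training-related (learning rate, training_iters, batch_size, display_step).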
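To make the data-related parameters concrete, here is a small NumPy sketch (NumPy is used only for illustration; the tutorial itself uses TensorFlow) of how one flattened 784-pixel MNIST batch becomes 28 time steps of 28 features each, and of how many optimizer updates the training-related parameters imply. The random batch is a placeholder standing in for real MNIST data.

```python
import numpy as np

# Data-related parameters, copied from the tutorial code
n_input = 28    # pixels per image row = features fed in at each time step
n_steps = 28    # image rows = number of time steps
n_classes = 10  # digits 0-9

# Training-related parameters, copied from the tutorial code
batch_size = 128
training_iters = 100000

# Placeholder batch standing in for mnist.train.next_batch(batch_size):
# MNIST images arrive flattened to 784 = 28 * 28 values each.
batch_x = np.random.rand(batch_size, n_steps * n_input).astype(np.float32)

# Reshape to [batch, time steps, features per step], as the training loop does
seq = batch_x.reshape(batch_size, n_steps, n_input)

# tf.unstack(x, n_steps, 1) splits along axis 1; this list is the NumPy equivalent
steps = [seq[:, t, :] for t in range(n_steps)]

# Count the updates performed by `while step * batch_size < training_iters`, starting at step = 1
step, num_updates = 1, 0
while step * batch_size < training_iters:
    num_updates += 1
    step += 1

print(len(steps), steps[0].shape, num_updates)
```

With training_iters = 100000 and batch_size = 128 the loop allows 781 updates, because 781 × 128 = 99968 < 100000 while 782 × 128 = 100096. Time step t is simply row t of every image in the batch.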

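Define the model (an LSTM that reads each image row by row), the loss function (softmax cross-entropy), and the optimizer (Adam).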
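The model reads the image one row per time step and classifies from the output of the last step. The NumPy sketch below illustrates that forward pass and the loss under stated assumptions; it is not TensorFlow's code. `lstm_step` follows the standard LSTM equations with `forget_bias` added to the forget gate before the sigmoid (which, as far as I know, matches the formulation BasicLSTMCell uses), and all weights and labels are random placeholders. The Adam update is left out, since in the tutorial `minimize(cost)` derives the gradients automatically.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h, c, W, b, forget_bias=1.0):
    """One LSTM step; W has shape [n_input + n_hidden, 4 * n_hidden]."""
    z = np.concatenate([x_t, h], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)  # input gate, candidate, forget gate, output gate
    c_new = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    h_new = np.tanh(c_new) * sigmoid(o)
    return h_new, c_new

def rnn_logits(steps, W, b, w_out, b_out):
    """Run the cell over all time steps and project the LAST output to class scores."""
    batch = steps[0].shape[0]
    n_hidden = w_out.shape[0]
    h = np.zeros((batch, n_hidden))
    c = np.zeros((batch, n_hidden))
    for x_t in steps:
        h, c = lstm_step(x_t, h, c, W, b)
    return h @ w_out + b_out  # plays the role of matmul(outputs[-1], weights['out']) + biases['out']

def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy over the batch (one-hot labels)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(labels * log_probs).sum(axis=1).mean())

# Placeholder sizes and random parameters (batch of 4 for brevity)
rng = np.random.default_rng(0)
n_input, n_steps, n_hidden, n_classes, batch = 28, 28, 128, 10, 4
steps = [rng.random((batch, n_input)) for _ in range(n_steps)]
W = rng.normal(scale=0.1, size=(n_input + n_hidden, 4 * n_hidden))
b = np.zeros(4 * n_hidden)
w_out = rng.normal(size=(n_hidden, n_classes))
b_out = rng.normal(size=n_classes)
labels = np.eye(n_classes)[rng.integers(0, n_classes, size=batch)]

logits = rnn_logits(steps, W, b, w_out, b_out)
loss = softmax_cross_entropy(logits, labels)
print(logits.shape, loss)
```

As a sanity check, all-zero logits over 10 classes give a loss of ln(10) ≈ 2.3026, the value of a uniform guess.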

Train: prepare the data, feed each minibatch into the graph, and print the loss and accuracy periodically.
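The training step follows a fixed pattern: draw a minibatch, reshape it, run one optimizer update, and report progress every `display_step` minibatches. Below is a framework-free sketch of that control flow. `MiniBatcher`, `train`, and the `train_step` callback are hypothetical names; `MiniBatcher` is a simplified stand-in for `mnist.train.next_batch` (it reshuffles each epoch and drops leftover examples), not a copy of its implementation.

```python
import numpy as np

class MiniBatcher:
    """Hypothetical, simplified stand-in for mnist.train.next_batch."""

    def __init__(self, images, labels, seed=0):
        self.images, self.labels = images, labels
        self.rng = np.random.default_rng(seed)
        self.order = self.rng.permutation(len(images))
        self.pos = 0

    def next_batch(self, batch_size):
        if self.pos + batch_size > len(self.order):  # not enough left: start a new epoch
            self.order = self.rng.permutation(len(self.images))
            self.pos = 0
        idx = self.order[self.pos:self.pos + batch_size]
        self.pos += batch_size
        return self.images[idx], self.labels[idx]

def train(batcher, train_step, batch_size=128, training_iters=100000,
          display_step=10, n_steps=28, n_input=28):
    """Same loop structure as the tutorial's session loop; returns the logged iteration counts."""
    logged = []
    step = 1
    while step * batch_size < training_iters:
        batch_x, batch_y = batcher.next_batch(batch_size)
        batch_x = batch_x.reshape(batch_size, n_steps, n_input)  # [batch, time, features]
        loss, acc = train_step(batch_x, batch_y)                  # one optimizer update
        if step % display_step == 0:
            logged.append(step * batch_size)
            print("Iter {}, Minibatch Loss= {:.6f}, Training Accuracy= {:.5f}".format(
                step * batch_size, loss, acc))
        step += 1
    return logged

# Usage with placeholder data and a dummy update that records the batch shapes it sees
rng = np.random.default_rng(1)
images = rng.random((1000, 784))
labels = np.eye(10)[rng.integers(0, 10, size=1000)]
calls = []

def dummy_step(batch_x, batch_y):
    calls.append(batch_x.shape)
    return 0.5, 0.1  # fixed placeholder loss and accuracy

logged = train(MiniBatcher(images, labels), dummy_step, training_iters=3000)
```

With training_iters = 3000 the loop runs 23 updates and prints progress at steps 10 and 20; in the real tutorial, `train_step` corresponds to the `sess.run(optimizer, ...)` call plus the periodic evaluation.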

Code:

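#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Shared by http://www.tensorflownews.com/
# Github:https://github.com/TensorFlowNews
# Requires TensorFlow 1.x: tf.contrib and tensorflow.examples.tutorials were removed in TensorFlow 2.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn

# Download (if needed) and load MNIST into ./data, with one-hot encoded labels
mnist = input_data.read_data_sets("./data", one_hot=True)

# Training-related parameters
learning_rate = 0.001
training_iters = 100000   # total number of training examples to process
batch_size = 128
display_step = 10         # print progress every display_step minibatches

# Data- and model-related parameters: each 28x28 image is read as 28 rows (time steps) of 28 pixels
n_input = 28     # pixels per row = features per time step
n_steps = 28     # rows per image = number of time steps
n_hidden = 128   # LSTM hidden-state size
n_classes = 10   # digits 0-9

# Inputs: [batch, time steps, features per step] and one-hot labels
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Output layer mapping the last hidden state to class scores (the dict keys must be strings)
weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {'out': tf.Variable(tf.random_normal([n_classes]))}


def RNN(x, weights, biases):
    # Split [batch, n_steps, n_input] into a list of n_steps tensors of shape [batch, n_input]
    x = tf.unstack(x, n_steps, 1)
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # outputs holds one [batch, n_hidden] tensor per time step
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    # Classify using only the output of the last time step
    return tf.matmul(outputs[-1], weights['out']) + biases['out']


pred = RNN(x, weights, biases)   # unnormalized logits
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Accuracy: fraction of examples whose highest-scoring class matches the label
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    step = 1
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # next_batch returns flat 784-pixel images; reshape them to [batch, 28, 28]
        batch_x = batch_x.reshape(batch_size, n_steps, n_input)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Evaluate accuracy and loss on the current minibatch in a single run
            acc, loss = sess.run([accuracy, cost], feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1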
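The `accuracy` op at the end of the graph compares the predicted class index with the true one. Because softmax is monotonic, taking `argmax` of the raw logits `pred` picks the same class as taking it of the softmax probabilities, so no softmax is needed there. Here is a NumPy equivalent of `correct_pred` and `accuracy`, run on made-up logits for illustration:

```python
import numpy as np

def accuracy(logits, one_hot_labels):
    """NumPy version of reduce_mean(cast(equal(argmax(pred, 1), argmax(y, 1)), float32))."""
    correct = np.argmax(logits, axis=1) == np.argmax(one_hot_labels, axis=1)
    return float(correct.astype(np.float32).mean())

# Made-up scores for 4 examples over 3 classes
logits = np.array([[2.0, 0.1, -1.0],
                   [0.3, 0.2, 0.9],
                   [1.0, 3.0, 0.0],
                   [0.0, 0.0, 5.0]])
labels = np.eye(3)[[0, 1, 1, 2]]  # true classes 0, 1, 1, 2
print(accuracy(logits, labels))   # → 0.75
```

The predicted classes are 0, 2, 1, 2, so three of the four match and the accuracy is 0.75.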

Tags: TensorFlow