TensorFlow從0到1 | 第十七章 Step By Step上手TensorBoard
上一篇 16 L2正則化對抗「過擬合 」提到,為了檢測訓練過程中發生的過擬合,需要記錄每次迭代(甚至每次step)模型在訓練集和驗證集上的識別精度。其實,為了能更好的調試和優化TensorFlow程序,日誌(logging)和監測(monitoring)需求遠不止「識別精度」。在訓練過程中不斷變化的「損失」、「更新速率」(step/sec)甚至「概率分布」等信息,都是幫助我們深入理解模型與訓練的關鍵信息。
對此,TensorBoard提供了盡善盡美的支持。它能將所記錄的動/靜態信息,方便的可視化成直觀的圖形,幫助人們更容易的分析並獲得洞察,讓神經網路「黑盒」變得透明。
本篇介紹TensorBoard的基本用法:繪製MNIST案例中計算圖、識別精度和損失。
訓練集和驗證集識別精度
TensorBoard生成圖形的流程框架,簡單概括起來就兩點:
- TensorFlow運行並將log信息記錄到文件;
- TensorBoard讀取文件並繪製圖形。
在代碼實現和組織層面,通常只要在「正常」代碼後集中添加負責logging的代碼即可,兩者能夠很好的區隔,不會發生嚴重的耦合。以下示例代碼基於16 L2正則化對抗「過擬合」,修改了logging的部分。
step 1:構造summary node
TensorBoard以protocol buffer 的方式記錄信息,它是Google開發的一種序列化結構數據的方法。
我們的目標是記錄accuracy和loss,更準確的說是記錄accuracy node和loss node的輸出值,那麼首先需要將數據轉換成protocol buffer object,而負責轉換動作的就是TensorFlow提供的summary節點(summary有匯總和概括的含義,暫不做翻譯)。
accuracy_scalar = tf.summary.scalar(accuracy, accuracy)loss_scalar = tf.summary.scalar(loss, loss)n
上面的tf.summary.scalar方法稱為summary operation。它接受一個要跟蹤的節點,並返回一個scalar summary節點,該節點以protocol buffer的方式表示一個標量值。
summary節點與其他節點一樣,依靠Session運行才會有輸出。如果跟蹤的節點非常多,還可以進行節點合併,Session在運行時會自動遍歷運行所有的summary節點:
merged = tf.summary.merge_all()
step 2:構造summary file writer
構造好summary node後,就要構造summary文件寫入器了,所有跟蹤的信息都依靠它來寫入文件,而TensorBoard繪製的圖形正是基於這些文件的。
train_writer = tf.summary.FileWriter(MNIST/logs/tf17/train)validation_writer = tf.summary.FileWriter(MNIST/logs/tf17/validation)n
tf.summary.FileWriter構造summary文件寫入器,接受一個log的目錄作為保存文件的路徑。log目錄如果不存在,會被程序自動創建。通常訓練集日誌和驗證集日誌分開存放,分別構造各自的summary文件寫入器即可。
step 3:運行summary節點
在運行summary節點時,出於性能考慮(儘可能少的運行計算圖),會與使用相同輸入數據的「正常」節點一起執行,下面代碼基於訓練數據,使用了合併的summary節點:
summary, accuracy_currut_train = sess.run(n [merged, accuracy], feed_dict={x: mnist.train.images, y_: mnist.train.labels})n
在summary節點不多時,當然也可以分別運行節點,下面代碼基於驗證數據,使用了單獨的summary節點:
(sum_accuracy_validation,n sum_loss_validation,n accuracy_currut_validation) = sess.run(n [accuracy_scalar, loss_scalar, accuracy], feed_dict={x: mnist.validation.images, y_: mnist.validation.labels})n
step 4:向記錄器添加
運行summary節點的輸出,即可通過文件寫入器的add_summary方法進行添加,
該方法除了接受summary節點的運行輸出值,還接受一個global_step參數來作為序列號:train_writer.add_summary(summary, epoch)nvalidation_writer.add_summary(sum_accuracy_validation, epoch)validation_writer.add_summary(sum_loss_validation, epoch)n
step 5:啟動TensorBoard Server
啟動TensorBoard Server可以與前面的記錄寫入並行,TensorBoard會自動的掃描日誌文件的更新。
重新生成並繪製,只需手工刪除現有數據或者目錄即可。
新啟動一個命令行窗口,激活虛擬環境,鍵入命令tensorboard,其參數logdir指出log文件的存放目錄,可以只給出其上級目錄,TensorBoard會自動遞歸掃描目錄:
tensorboard --logdir=TF1_1/MNIST
TensorBoard Server
當TensorBoard伺服器順利啟動後,即可打開瀏覽器輸入地址:http://127.0.0.1:6006/查看。注意在Windows環境下輸入http://0.0.0.0:6006/無效。下圖就是TensorBoard繪製出的accuracy和loss的圖形:
TensorBoard
圖形「同框」技巧
上圖中的accuracy和loss圖形中,訓練集曲線和驗證集曲線以不同顏色「同框」出現,特別便於對比分析。同框需要滿足以下兩點:
- 要同框的曲線跟蹤的必須是同一個節點,比如跟蹤accuracy節點或loss節點;
- 各曲線的數據記錄在不同的目錄下,可以通過構造兩個不同的文件寫入器來達到;
繪製計算圖
TensorBoard除了繪製動態數據,繪製靜態的graph(計算圖)更是easy,在構造「文件寫入器」多添加一個參數sess.graph即可:
train_writer = tf.summary.FileWriter(MNIST/logs/tf17/train, sess.graph)
重新運行TensorFlow程序後,啟動TensorBoard Server,在瀏覽器打開頁面,點選GRAPHS菜單,即可看到:
Graph
附完整代碼
import argparsenimport sysnfrom tensorflow.examples.tutorials.mnist import input_datanimport tensorflow as tfnnFLAGS = Nonennndef main(_):n # Import datan mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,n validation_size=10000)nn # Create the modeln x = tf.placeholder(tf.float32, [None, 784])n W_2 = tf.Variable(tf.random_normal([784, 100]) / tf.sqrt(784.0))n W_2 = tf.get_variable(n name="W_2",n regularizer=regularizer,n initializer=tf.random_normal([784, 30], stddev=1 / tf.sqrt(784.0)))n b_2 = tf.Variable(tf.random_normal([100]))n z_2 = tf.matmul(x, W_2) + b_2n a_2 = tf.sigmoid(z_2)nn W_3 = tf.Variable(tf.random_normal([100, 10]) / tf.sqrt(100.0))n W_3 = tf.get_variable(n name="W_3",n regularizer=regularizer,n initializer=tf.random_normal([30, 10], stddev=1 / tf.sqrt(30.0)))n b_3 = tf.Variable(tf.random_normal([10]))n z_3 = tf.matmul(a_2, W_3) + b_3n a_3 = tf.sigmoid(z_3)nn # Define loss and optimizern y_ = tf.placeholder(tf.float32, [None, 10])nn tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_2)n tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_3)n regularizer = tf.contrib.layers.l2_regularizer(scale=5.0 / 50000)n reg_term = tf.contrib.layers.apply_regularization(regularizer)nn loss = (tf.reduce_mean(n tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3)) +n reg_term)nn train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)nn sess = tf.InteractiveSession()n tf.global_variables_initializer().run()nn correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))n accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))nn accuracy_scalar = tf.summary.scalar(accuracy, accuracy)n loss_scalar = tf.summary.scalar(loss, loss)n merged = tf.summary.merge_all()n train_writer = tf.summary.FileWriter(n MNIST/logs/tf17/train, sess.graph)n validation_writer = tf.summary.FileWriter(n MNIST/logs/tf17/validation)nn # Trainn best = 0n for epoch in range(30):n for _ in range(5000):n batch_xs, batch_ys = mnist.train.next_batch(10)n sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})n sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})n # Test trained modeln summary, accuracy_currut_train = sess.run(n [merged, accuracy],n feed_dict={x: mnist.train.images,n y_: mnist.train.labels})nn (sum_accuracy_validation,n sum_loss_validation,n accuracy_currut_validation) = sess.run(n [accuracy_scalar, loss_scalar, accuracy],n feed_dict={x: mnist.validation.images,n y_: mnist.validation.labels})nn train_writer.add_summary(summary, epoch)n validation_writer.add_summary(sum_accuracy_validation, epoch)n validation_writer.add_summary(sum_loss_validation, epoch)nn print("Epoch %s: train: %s validation: %s"n % (epoch, accuracy_currut_train, accuracy_currut_validation))n best = (best, accuracy_currut_validation)[n best <= accuracy_currut_validation]nn # Test trained modeln print("best: %s" % best)n train_writer.close()n validation_writer.close()nnnif __name__ == __main__:n parser = argparse.ArgumentParser()n parser.add_argument(--data_dir, type=str, default=../MNIST/,n help=Directory for storing input data)n FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)n
by 黑猿大叔丨首發簡書
推薦閱讀:
※用人工智慧來續寫《權力的遊戲》,靠譜嗎?
※神經網路從被人忽悠到忽悠人(一)
※經驗之談:如何為你的機器學習問題選擇合適的演算法
※國內圍棋 AI 挑戰世界冠軍,並非嘩眾取寵,而是深度學習的勝利
TAG:深度学习DeepLearning | 人工智能算法 | TensorFlow |