Tensorflow實戰google深度學習框架代碼學習八(模型的保存和使用)
來自專欄 tensorflow代碼學習+pythonAI
#保存模型一共會有3個文件,ckpt保存變數的值,meta保存的是圖的結構,checkpoint保存此文件中所有模型的列表import tensorflow as tf a1 = tf.Variable(tf.truncated_normal(shape=[2],seed=2),name=a1)a2 = tf.Variable(tf.truncated_normal(shape=[2],seed=2),name=a2)result = a1 +a2#保存模型saver = tf.train.Saver()with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess,saveckpt/model.ckpt)
import tensorflow as tfa1 = tf.Variable(tf.truncated_normal(shape=[2]),name=a1)a2 = tf.Variable(tf.truncated_normal(shape=[2]),name=a2)result = a1 +a2#載入模型saver = tf.train.Saver()with tf.Session() as sess: saver.restore(sess,saveckpt/model.ckpt) print(sess.run(result))
結果:
INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt[-1.71622169 -0.39324597]
#直接載入持久化的圖import tensorflow as tfsaver = tf.train.import_meta_graph(saveckpt/model.ckpt.meta)with tf.Session() as sess: saver.restore(sess,saveckpt/model.ckpt) print(sess.run(a1))
結果:
INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt[-0.85811085 -0.19662298]
#變數重命名import tensorflow as tfv1 = tf.Variable([1.0,2.1],name=v1)v2 = tf.Variable([2.0,3.0],name=v2)saver = tf.train.Saver(var_list={a1:v1,a2:v2})with tf.Session() as sess: saver.restore(sess,saveckpt/model.ckpt) print(sess.run(v1))
結果:
INFO:tensorflow:Restoring parameters from saveckpt/model.ckpt[-0.85811085 -0.19662298]
推薦閱讀:
※機器學習篇-指標:AUC
※如何訓練模型?(1)——最小二乘法
※圖解機器學習:參數起點設計的重要性,symmetry breaking
※機器學習項目流程清單
TAG:機器學習 |