tensorflow保存和恢復模型的兩種方法介紹

tensorflow保存和恢復模型的兩種方法介紹

一、前言

本文將會介紹tensorflow保存和恢復模型的兩種方法,一種是傳統的Saver類save保存和restore恢復方法,還有一種是比較新穎的SavedModelBuilder類的builder保存和loader文件里的load恢復方法。通過了解這兩種方法,我們可以解決如何保存和恢復一個已經訓練好的神經網路模型用於推理預測的現實需求,也可以輔助查看分析一個長時間訓練的模型性能,最重要的是我們可以預防因長時間訓練中途出現斷電、宕機、出錯退出等問題導致的訓練功虧一簣問題!可見,掌握tensorflow保存和恢復模型的方法,對我們工程應用有多麼大的幫助,同時,這也是我們必須要掌握的基礎技能,下面我將分別介紹它們!

二、模型保存恢復之save/restore方法

save和restore方法主要在Saver類里實現,源代碼位於tensorflow/python/training/saver.py

2-1)不管是save還是restore,我們首先都是要新建一個Saver,使用方法如下:

saver = tf.train.Saver(...)

注意一點:位於 tf.train.Saver()之後的變數將不會被存儲!

Saver的構造函數如下:

__init__( var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=tf.train.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None)

對我們來說比較關注的有以下幾個配置參數:

保存模型時:

var_list:特殊需要保存和恢復的變數和可保存對象列表或字典,默認為空,將會保存所有的可保存對象;

max_to_keep:保存多少個最新的checkpoint文件,默認為5,即保存最近五個checkpoint文件;

keep_checkpoint_every_n_hours:多久保存checkpoint文件,默認為10000小時,相當于禁用了這個功能;

save_relative_paths:為True時,checkpoint文件將不會記錄完整的模型路徑,而只會僅僅記錄模型名字,這方便於將保存下來的模型複製到其他目錄並使用的情況;

恢復模型時:

reshape:為True時,允許從已保存checkpoint文件里恢復並重新設定形狀不一樣的張量,默認為false;

sharded:碎片化checkpoint文件到每一個設備,默認false;

restore_sequentially:為True時,會在每個設備中順序地恢復不同的變數,同時可以在恢複比較大的模型時節省內存;

2-2)使用Saver類的save介面保存模型

saver.save(...)

save介面如下:

save( sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix=meta, write_meta_graph=True, write_state=True)

該方法運行為保存變數的構造函數所添加的ops,它需要一個已經建好圖的會話,同時要求所有變數均已經被初始化,該函數返回保存模型的絕對路徑,可用於restore時使用。

其參數說明如下:

sess:一個建好圖的會話,用以運行保存操作;

save_path:包含模型名字的絕對路徑,最終會自動在模型名字添加相應後綴

global_step:該參數會自動添加到save_path名字用以區別不同步驟保存的模型;

latest_filename:生成檢查點文件的名字,默認是「checkpoint」;

meta_graph_suffix:MetaGraphDef元圖後綴,默認為「meta」;

write_meta_graph:指明是否要保存元圖數據,默認為True;

write_state:指明是否要寫CheckpointStateProto,默認為True;

2-3)獲取最近保存的所有模型

last_ckpt = saver.last_checkpoints

或者使用如下方法:

# get_checkpoint_state(checkpoint_dir, latest_filename=None)ckpt = tf.train.get_checkpoint_state("/home/xsr-ai/study/mnist/mnist-model")

這將會得到一個包含有最近保存模型的列表,但是不包括checkpoint檢查點文件,如下;

我們要恢復哪一個模型,可以使用如下任一種類似方法:

saver.restore(last_ckpt[-1])saver.restore(last_ckpt[0])saver.restore(ckpt.model_checkpoint_path)saver.restore(ckpt.all_model_checkpoint_paths[-1])

2-4)使用restore恢復已保存模型

saver.restore(sess, save_path)

該函數恢復一個已保存的模型,它需要一個已建好圖結構的會話,恢復模型得到的變數無需初始化,在恢復過程中已有對保存變數做了初始化操作。

sess:用以恢復參數模型的會話;

save_path:已保存模型的路徑,通常包含模型名字;

2-5)圖存儲和載入write_graph/import_graph_def方法

有時候我們建立好一個會話圖後,需要保存,以供將來使用,那麼以下方法是很有效的!

圖存儲方法:

def write_graph(graph_or_graph_def, logdir, name, as_text=True):

該函數存儲一個tensorflow圖原型到文件里,其參數含義如下:

graph_or_graph_def:tensorflow Graph或GraphDef;

logdir:保存圖或圖原型的目錄;

as_text:默認為True,即以ASCII方式寫到文件里

return:返回圖或圖原型保存的路徑

使用例子如下:

v = tf.Variable(0, name=my_variable)sess = tf.Session()# tf.train.write_graph(sess.graph, /tmp/my-model, train.pbtxt) --> that is oktf.train.write_graph(sess.graph_def, /tmp/my-model, train.pbtxt)

圖載入方法:

def import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None):

該函數可載入已存儲的"graph_def"到當前默認圖裡,並從系列化的tensorflow [`GraphDef`]協議緩衝里提取所有的tf.Tensor和tf.Operation到當前圖裡,其參數如下:

graph_def:一個包含圖操作OP且要導入GraphDef的默認圖;

input_map:字典關鍵字映射,用以從已保存圖裡恢復出對應的張量值;

return_elements:從已保存模型恢復的Ops或Tensor對象;

return:從已保存模型恢復後的Ops和Tensorflow列表,其名字位於return_elements;

使用例子如下:

with tf.Session() as _sess: with gfile.FastGFile("/tmp/tfmodel/train.pbtxt",rb) as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) _sess.graph.as_default() tf.import_graph_def(graph_def, name=tfgraph)

2-6)MetaGraph導出和導入export_meta_graph/ import_meta_graph方法

先了解一下什麼是MetaGraph:

一個MetaGraph既包含了tensorflow GraphDef,也包含了在跨越進程邊界時在圖形中運行計算所需的相關元數據,它也可以用來長期存儲tensorflow圖結構。MetaGraph包含繼續訓練、執行評估或在先前訓練的圖形上運行推理所需的信息。

MetaGraph包含的信息被表示為一個MetaGraphDef協議緩衝,它包含如下幾方面:

MetaInfoDef:元信息,比如版本信息和用戶信息;

GraphDef:用於描述一個圖結構;

SaverDef:用於Saver;

CollectionDef :映射進一步描述模型的其他組件,比如變數或tensorflow隊列;

MetaGraph導出方法:

def export_meta_graph(filename=None, collection_list=None, as_text=False, export_scope=None, clear_devices=False, clear_extraneous_savers=False):

該函數可以導出tensorflow元圖及其所需的數據,其參數如下:

filename:保存路徑及其文件名;

collection_list:要收集的字元串鍵的列表;

as_text:為True時導出的文本格式為ASCII編碼;

export_scope:導出的名字空間,用以刪除;

clear_devices:導出時將與設備相關的信息去掉,即導出文件不與特定設備環境關聯;

clear_extraneous_savers:從圖中刪除與此導出操作無關的任何saver相關信息(保存/恢復操作和SaverDefs)。

return:MetaGraphDef proto;

官方提供的使用常式:

# Build the model...with tf.Session() as sess: # Use the model ...# Export the default running graph and only a subset of the collections.meta_graph_def = tf.train.export_meta_graph( filename=/tmp/my-model.meta, collection_list=["input_tensor", "output_tensor"])

MetaGraph導入方法:

def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs):

該函數以「MetaGraphDef」協議緩衝區作為輸入,如果其參數是一個包含「MetaGraphDef」協議緩衝區的文件,它將以文件內容構造一個協議緩衝區,然後將「graph_def」欄位中的所有節點添加到當前圖形,並重新創建所有由collection_list收集的列表內容,最後返回由「saver_def」欄位構造的saver以供使用,其參數如下:

meta_graph_or_file:`MetaGraphDef`協議緩衝區或者包含MetaGraphDef且帶有路徑的文件名;

clear_devices:導入時將與設備相關的信息去掉,即不與導出時的圖設備環境關聯,可兼容當前設備環境;

import_scope:導入名字空間,用以刪除;

**kwargs:可選的參數;

return:在「MetaGraphDef」中由「saver_def」構造的存儲模型,如果MetaGraphDef沒有保存的變數則會直接返回None;

官方提供的使用常式:

...# Create a saver.saver = tf.train.Saver(...variables...)# Remember the training_op we want to run by adding it to a collection.tf.add_to_collection(train_op, train_op)sess = tf.Session()for step in xrange(1000000): sess.run(train_op) if step % 1000 == 0: # Saves checkpoint, which by default also exports a meta_graph # named my-model-global_step.meta. saver.save(sess, my-model, global_step=step)with tf.Session() as sess: new_saver = tf.train.import_meta_graph(my-save-dir/my-model-10000.meta) new_saver.restore(sess, my-save-dir/my-model-10000) # tf.get_collection() returns a list. In this example we only want the # first one. train_op = tf.get_collection(train_op)[0] for step in xrange(1000000): sess.run(train_op)

三、舉例說明save/restore方法

下面我們將基於mnist寫一個例子來說明如何使用save/restore方法保存和恢復模型,這是一個基於softmax的mnist常式,為了執行這個程序,我們需要事先下載mnist數據,可到網站

MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges下載

執行命令時需要指定命令行參數「--data_dir」到你存放mnist數據的目錄,例如:

python mnist_softmax.py --data_dir /home/xsr-ai/study/mnist/

"""A very simple MNIST classifier.See extensive documentation athttps://www.tensorflow.org/get_started/mnist/beginners"""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport argparseimport sysfrom tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tfFLAGS = Nonedef main(_): # Import data mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) # Create the model x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.matmul(x, W) + b # Define loss and optimizer y_ = tf.placeholder(tf.float32, [None, 10]) # The raw formulation of cross-entropy, # # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), # reduction_indices=[1])) # # can be numerically unstable. # # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw # outputs of y, and then average across the batch. cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession() tf.global_variables_initializer().run() # Train saver = tf.train.Saver() for index in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) if index % 100 == 0: print("index: %d" % index) path = saver.save(sess, "/home/xsr-ai/study/mnist/mnist-model/model.ckpt", global_step=index) # , latest_filename="hello" # Test trained model correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) ckpt = tf.train.get_checkpoint_state("/home/xsr-ai/study/mnist/mnist-model") saver.restore(sess, ckpt.all_model_checkpoint_paths[0]) print(ckpt) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))if __name__ == __main__: parser = argparse.ArgumentParser() parser.add_argument(--data_dir, type=str, default=/tmp/tensorflow/mnist/input_data, help=Directory for storing input data) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

執行如上程序會得到如下終端列印信息:

我們看到準確率並不是很高,那是我們沒有使用包含空間信息的卷積神經網路結構來訓練,同時在模型保存的目錄下面出現幾個保存的模型:

包含一個checkpoint文件,它記錄了max_to_keep個最新的保存模型信息,如下:

同時,按照默認max_to_keep等於5則包含五個模型信息,其中有 5 個 model.ckpt-{global_step}.data-00000-of-00001 文件,是訓練過程中保存的模型,5 個 model.ckpt-{global_step}.meta 文件,是訓練過程中保存的元數據(TensorFlow 默認只保存最近 5 個模型和元數據,刪除前面沒用的模型和元數據),5 個 model.ckpt-{global_step}.index 文件,{global_step}代表迭代次數。

實際上,我有在程序後面使用saver.restore方法恢復了保存的模型,然後進行了預測:

ckpt = tf.train.get_checkpoint_state("/home/xsr-ai/study/mnist/mnist-model") saver.restore(sess, ckpt.all_model_checkpoint_paths[0]) print(ckpt) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

因為使用的模型是五個中比較早並非最近的一個,所以預測準確率只有0.9125,這比最後一次預測的準確率0.918要差一點點的。

四、模型保存恢復之builder/loader方法

builder/loader方法也是可以保存和恢復tensorflow模型的,只是他們源代碼是在不同文件里,builder其源代碼在tensorflow/python/saved_model/builder_impl.py,而loader的源代碼則位於tensorflow/python/saved_model/loader_impl.py。相較於save和restore方法會生成比較多的模型文件,builder和loader方法則會更簡單一些,同時也是saver提供的更高級別的系列化,它也更適合於商業化,按照創作者的說法「它顯然是未來!」

使用builder方法保存模型:

我們主要使用SavedModelBuilder類來新建一個builder,SavedModelBuilder的參數很簡單,就一個export_dir參數即要保存模型的路徑,但要確保所保存的目錄是未有建立的,否則會導致出錯!

獲取builder方法如下:

builder = tf.saved_model.builder.SavedModelBuilder("/home/xsr-ai/study/mnist/saved-model")

在訓練完後,我們調用如下命令保存模型:

builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)builder.save()

add_meta_graph_and_variables的介紹如下:

def add_meta_graph_and_variables(sess,tags,signature_def_map=None,assets_collection=None,legacy_init_op=None,clear_devices=False,main_op=None):

該函數可以將當前元圖添加到SavedModel並保存變數,其參數如下:

sess:用於執行添加元圖和變數功能的會話;

tags:用於保存元圖的標籤;

signature_def_map:用於保存元圖的簽名;

assets_collection:使用SavedModel保存的資源集合;

legacy_init_op:在恢復模型操作後,對Op和Ops組的遺留支持;

clear_devices:如果默認圖形上的設備信息應該被清除,則應該設置為true;

main_op:在載入圖時執行Op或Ops組的操作。請注意,當main_op被指定時,它將在載入恢復op後運行;

return:無返回

save()的介紹:

def save(as_text=False):

該函數將「SavedModel」協議緩衝區的數據寫入到硬碟里,其參數只有一個as_text,主要用於指明是否按照ASCII編碼格式寫入到文件里,其返回的是保存模型的路徑。

使用loader方法恢復模型:

我們主要使用load(...)來恢復模型:

def load(sess, tags, export_dir, **saver_kwargs):

該函數可以從標籤指定的SavedModel載入模型,其參數如下:

sess:恢復模型的會話;

tags:用於恢復元圖的標籤,需與保存時的一致,用於區別不同的模型;

export_dir:存儲SavedModel協議緩衝區和要載入的變數的目錄;

**saver_kwargs:可選的關鍵字參數傳遞給saver;

return:在提供的會話中載入的「MetaGraphDef」協議緩衝區,這可以用於進一步提取signature-defs, collection-defs等;

load通常使用方法如下:

with tf.Session() as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], "/home/xsr-ai/study/mnist/saved-model")

一定要注意標籤和模型路徑都要與保存模型時一致,然後使用相應的變數時,需要保存時的名字空間!

五、舉例說明builder/loader方法

與save/restore方法一樣,我們也用mnist來舉例說明如何使用builder/loader方法來保存恢復模型,但這次我們用卷積神經網路的方法,順便看看準確率是不是有很大的提高!

"""A simple MNIST classifier which displays summaries in TensorBoard.This is an unimpressive MNIST model, but it is a good example of usingtf.name_scope to make a graph legible in the TensorBoard graph explorer, and ofnaming summary tags so that they are grouped meaningfully in TensorBoard.It demonstrates the functionality of every TensorBoard dashboard."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport argparseimport osimport sysimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataFLAGS = Nonedef train(): # Import data mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True, fake_data=FLAGS.fake_data) sess = tf.InteractiveSession() # Create a multilayer model. # Input placeholders with tf.name_scope(input): x = tf.placeholder(tf.float32, [None, 784], name=x-input) y_ = tf.placeholder(tf.float32, [None, 10], name=y-input) with tf.name_scope(input_reshape): image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) tf.summary.image(input, image_shaped_input, 10) # We cant initialize these variables to 0 - the network will get stuck. def weight_variable(shape): """Create a weight variable with appropriate initialization.""" initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial) def bias_variable(shape): """Create a bias variable with appropriate initialization.""" initial = tf.constant(0.1, shape=shape) return tf.Variable(initial) def variable_summaries(var): """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" with tf.name_scope(summaries): mean = tf.reduce_mean(var) tf.summary.scalar(mean, mean) with tf.name_scope(stddev): stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) tf.summary.scalar(stddev, stddev) tf.summary.scalar(max, tf.reduce_max(var)) tf.summary.scalar(min, tf.reduce_min(var)) tf.summary.histogram(histogram, var) def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu): """Reusable code for making a simple neural net layer. It does a matrix multiply, bias add, and then uses ReLU to nonlinearize. It also sets up name scoping so that the resultant graph is easy to read, and adds a number of summary ops. """ # Adding a name scope ensures logical grouping of the layers in the graph. with tf.name_scope(layer_name): # This Variable will hold the state of the weights for the layer with tf.name_scope(weights): weights = weight_variable([input_dim, output_dim]) variable_summaries(weights) with tf.name_scope(biases): biases = bias_variable([output_dim]) variable_summaries(biases) with tf.name_scope(Wx_plus_b): preactivate = tf.matmul(input_tensor, weights) + biases tf.summary.histogram(pre_activations, preactivate) activations = act(preactivate, name=activation) tf.summary.histogram(activations, activations) return activations hidden1 = nn_layer(x, 784, 500, layer1) with tf.name_scope(dropout): keep_prob = tf.placeholder(tf.float32) tf.summary.scalar(dropout_keep_probability, keep_prob) dropped = tf.nn.dropout(hidden1, keep_prob) # Do not apply softmax activation yet, see below. y = nn_layer(dropped, 500, 10, layer2, act=tf.identity) with tf.name_scope(cross_entropy): # The raw formulation of cross-entropy, # # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.softmax(y)), # reduction_indices=[1])) # # can be numerically unstable. # # So here we use tf.nn.softmax_cross_entropy_with_logits on the # raw outputs of the nn_layer above, and then average across # the batch. diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y) with tf.name_scope(total): cross_entropy = tf.reduce_mean(diff) tf.summary.scalar(cross_entropy, cross_entropy) with tf.name_scope(train): train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize( cross_entropy) with tf.name_scope(accuracy): with tf.name_scope(correct_prediction): correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) with tf.name_scope(accuracy): accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.summary.scalar(accuracy, accuracy) # Merge all the summaries and write them out to # /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default) merged = tf.summary.merge_all() train_writer = tf.summary.FileWriter(FLAGS.log_dir + /train, sess.graph) test_writer = tf.summary.FileWriter(FLAGS.log_dir + /test) tf.global_variables_initializer().run() # Train the model, and also write summaries. # Every 10th step, measure test-set accuracy, and write test summaries # All other steps, run train_step on training data, & add training summaries def feed_dict(train): """Make a TensorFlow feed_dict: maps data onto Tensor placeholders.""" if train or FLAGS.fake_data: xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data) k = FLAGS.dropout else: xs, ys = mnist.test.images, mnist.test.labels k = 1.0 return {x: xs, y_: ys, keep_prob: k} for i in range(FLAGS.max_steps): if i % 100 == 0: # Record summaries and test-set accuracy summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False)) test_writer.add_summary(summary, i) print(Accuracy at step %s: %s % (i, acc)) else: # Record train set summaries, and train if i % 100 == 99: # Record execution stats run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True), options=run_options, run_metadata=run_metadata) train_writer.add_run_metadata(run_metadata, step%03d % i) train_writer.add_summary(summary, i) print(Adding run metadata for, i) else: # Record a summary summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True)) train_writer.add_summary(summary, i) builder = tf.saved_model.builder.SavedModelBuilder("/home/xsr-ai/study/mnist/saved-model") builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING]) builder.save() train_writer.close() test_writer.close()def main(_): if tf.gfile.Exists(FLAGS.log_dir): tf.gfile.DeleteRecursively(FLAGS.log_dir) tf.gfile.MakeDirs(FLAGS.log_dir) train()if __name__ == __main__: parser = argparse.ArgumentParser() parser.add_argument(--fake_data, nargs=?, const=True, type=bool, default=False, help=If true, uses fake data for unit testing.) parser.add_argument(--max_steps, type=int, default=1000, help=Number of steps to run trainer.) parser.add_argument(--learning_rate, type=float, default=0.001, help=Initial learning rate) parser.add_argument(--dropout, type=float, default=0.9, help=Keep probability for training dropout.) parser.add_argument( --data_dir, type=str, default=os.path.join(os.getenv(TEST_TMPDIR, /tmp), tensorflow/mnist/input_data), help=Directory for storing input data) parser.add_argument( --log_dir, type=str, default="/home/xsr-ai/study/mnist/logdir", help=Summaries log directory) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

執行完該程序後,終端列印信息如下:

很明顯,使用卷積神經網路,準確率大大提高到了0.9684,那麼會保存哪些東西呢?

進入到SavedModelBuilder指定的路徑「/home/xsr-ai/study/mnist/saved-model」,發現生成了如下東西:

一個pb文件,以及一個variables文件夾,裡面存放的是variables.data-00000-of-00001和

variables.index,與save/restore方法比,沒有checkpoint檢查點文件以及以「.meta」為後綴的元數據文件,但是多了一個pb文件,這是這兩種tensorflow保存和恢復模型方法的區別!

那麼又如何恢復由builder保存的模型呢?我使用如下例子來說明如何使用loader來恢復模型,代碼比較簡潔,主要是測試恢復模型後,可否正常獲取到特定的變數權值:

import tensorflow as tfwith tf.Session() as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], "/home/xsr-ai/study/mnist/saved-model") var = sess.run(layer2/biases/Variable:0) print(var)

在jupyter notebook里執行該程序,可得到如下輸出:

列印出來的layer2/biases/Variable即是模型訓練時的最終值,可見,我們保存一個模型後,也是可以恢復然後再進行分析的!


推薦閱讀:

機器學習:彈性伸縮的雲端託管服務
TensorFlow基本使用
Windows下TensorBoard的使用
TensorFlow安裝、開發環境設定、使用入門
Github|如何用TensorFlow實現DenseNet和DenseNet-BC(附源代碼)

TAG:TensorFlow | 深度學習DeepLearning | 機器學習 |