Tensorflow模型保存和載入

Tensorflow模型保存和載入

1 人贊了文章

  1. Tensorflow模型文件

我們在checkpoint_dir目錄下保存的文件結構如下:

1.1 checkpoint文件:

該文件是文本文件,記錄了保存的最新的checkpoint文件以及其它checkpoint文件列表。

1.2 meta文件:

MyModel.meta文件保存的是圖結構信息,meta文件是pb(protocol buffer)格式文件,包含變數、op、集合等。

1.3 ckpt文件:

ckpt文件是二進位文件,保存了所有的weights、biases、gradients等變數。在tensorflow 0.11之前,保存在.ckpt文件中。0.11後,通過兩個文件保存,如:

2. 保存Tensorflow模型

tensorflow 提供了tf.train.Saver類來保存模型,值得注意的是,在tensorflow中,變數是存在於Session環境中,也就是說,只有在Session環境下才會存有變數值,因此,保存模型時需要傳入session:

看一個簡單例子:

執行後,在checkpoint_dir目錄下創建模型文件如下:

另外,如果想要在1000次迭代後,再保存模型,只需設置global_step參數即可:

保存的模型文件名稱會在後面加-1000,如下:

另一種比較實用的是,如果你希望每2小時保存一次模型,並且只保存最近的5個模型文件:

如果我們不對tf.train.Saver指定任何參數,默認會保存所有變數。如果你不想保存所有變數,而只保存一部分變數,可以通過指定variables/collections。在創建tf.train.Saver實例時,通過將需要保存的變數構造list或者dictionary,傳入到Saver中:

3. 導入訓練好的模型

此處借上次研究論文attention-based lstm for aspect-level sentiment classification的機會,自己寫了一下導入模型進行預測的過程。代碼會跟論文實現代碼相關。

checkpoint_file = tf.train.latest_checkpoint() #此處為checkpoint路徑graph = tf.Graph()with graph.as_default(): session_conf = tf.ConfigProto( allow_soft_placement=True, log_device_placement=False) sess = tf.Session(config=session_conf) with sess.as_default(): # Load the saved meta graph and restore variables saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) saver.restore(sess, checkpoint_file) # Get the placeholders from the graph by name input_x = graph.get_operation_by_name("inputs/x").outputs[0] print (input_x) aspect = graph.get_operation_by_name("inputs/aspect_id").outputs[0] print (aspect) sen_len = graph.get_operation_by_name("inputs/sen_len").outputs[0] keep_prob1_sxl = graph.get_operation_by_name("dropout_keep_prob1").outputs[0] keep_prob2_sxl = graph.get_operation_by_name("dropout_keep_prob2").outputs[0] # Tensors we want to evaluate losss = graph.get_operation_by_name("predict/predictions").outputs[0] result = sess.run(losss ,{input_x: a, aspect: b,sen_len:[11], keep_prob1_sxl:1.0, keep_prob2_sxl : 1.0}) #這裡的result是我自己的測試print (result)

完整代碼已上傳到github:

sxlprince/attention-based-lstm-for-aspect-level-sentiment-classification

但這段代碼似乎有點問題,通過這段代碼預測的輸出總是跟正確輸出錯位。比如正確輸出的one-hot形式是[0,1, 0],代表情感極性是1,但是預測的輸出是0,而如果正確輸出是[1,0, 0],則預測輸出就是1。重點是在對測試集進行測試時列印出的預測輸出卻是對的,一直沒找到原因,以後回過頭來再看QAQ。


推薦閱讀:

線性代數與TensorFlow程序用例
win10下GPU版本的Tensorflow最新安裝教程
TensorFlow的Summary
怎樣使用tensorflow導入已經下載好的mnist數據集?
MobileNetV2模型的簡單理解及其代碼實現

TAG:深度學習DeepLearning | TensorFlow |