
Renaming variables in a TensorFlow checkpoint

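In TensorFlow, variable names are usually set with tf.variable_scope(). The scope becomes a prefix of each variable's name (for example, scope1/...), and that full name is the key under which the variable is stored in a checkpoint.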

Now suppose you want to replace the variable_scope inside an already-trained checkpoint, without rebuilding the model's Graph to do it.

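Someone has written a script that does exactly this. It targets TensorFlow 1.x, because it relies on tf.contrib and tf.Session, both of which were removed in TensorFlow 2.x: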

```python
import sys, getopt

import tensorflow as tf  # TensorFlow 1.x only (uses tf.contrib and tf.Session)

usage_str = ('python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ '
             '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run')


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable's value from the checkpoint (as a NumPy array)
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name: substring replacement first, then the prefix
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Re-create the variable under its new name in the default graph
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save every re-created variable, overwriting the latest checkpoint in place
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)


def main(argv):
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False

    try:
        # 'help' takes no value; listing it as 'help=' would make --help demand an argument
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True

    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)

    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])
```
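One caveat about the script: new_name.replace(replace_from, replace_to) rewrites every occurrence of the substring anywhere in the key. With replace_from=scope1, a key under scope10/ or a key such as other/scope1/... would also be changed. If only a leading scope should be renamed, a prefix-only rule is safer. The sketch below uses a hypothetical helper, rename_scope. It is not part of the original gist. It matches a key only when the key equals old_scope or starts with old_scope + "/".

```python
def rename_scope(var_name, old_scope, new_scope, add_prefix=""):
    """Hypothetical helper: rename only a leading scope of a checkpoint key.

    A key matches when it equals old_scope or starts with old_scope + "/",
    so keys like "scope10/..." or "other/scope1/..." are left untouched.
    """
    if var_name == old_scope:
        new_name = new_scope
    elif var_name.startswith(old_scope + "/"):
        new_name = new_scope + var_name[len(old_scope):]
    else:
        new_name = var_name
    return add_prefix + new_name


# Compare with plain substring replacement on a hypothetical key
key = "scope10/dense/kernel"
print(key.replace("scope1", "scope1/model"))         # scope1/model0/dense/kernel
print(rename_scope(key, "scope1", "scope1/model"))  # scope10/dense/kernel
```

Swapping this helper in for the two new_name lines in rename() is an assumption on our part. The gist itself uses plain substring replacement.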

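Example usage. With --dry_run, the script only prints the planned renames. Without it, the latest checkpoint in checkpoint_dir is overwritten in place, so keep a backup:

```shell
python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir --replace_from=scope1 --replace_to=scope1/model --add_prefix=abc/
```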
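To see what this command would do without loading TensorFlow at all (the same idea as --dry_run), you can copy the naming rule from rename() into plain Python. The function name preview_new_name and the checkpoint keys below are hypothetical. Real keys would come from list_variables().

```python
def preview_new_name(var_name, replace_from=None, replace_to=None, add_prefix=None):
    """Reproduce the renaming rule of rename() without touching any checkpoint."""
    new_name = var_name
    if None not in [replace_from, replace_to]:
        new_name = new_name.replace(replace_from, replace_to)  # every occurrence
    if add_prefix:
        new_name = add_prefix + new_name  # applied to every key
    return new_name


# Hypothetical checkpoint keys
for name in ["scope1/dense/kernel", "scope1/dense/bias", "global_step"]:
    print("%s -> %s" % (name, preview_new_name(name, "scope1", "scope1/model", "abc/")))
```

This prints scope1/dense/kernel -> abc/scope1/model/dense/kernel, scope1/dense/bias -> abc/scope1/model/dense/bias and global_step -> abc/global_step. Note that --add_prefix is applied to every key, including keys outside scope1.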

Code source:

https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96

Original question on Stack Overflow:

Rename variable scope of saved model in TensorFlow (stackoverflow.com)