標籤:

調優哪家強——tensorflow命令行參數

歡迎關注我們的微信公眾號「人工智慧LeadAI」(ID:atleadai)

深度學習神經網路往往有過多的Hyperparameter需要調優,優化演算法、學習率、卷積核尺寸等很多參數都需要不斷調整,使用命令行參數是非常方便的。有兩種實現方式,一是利用python的argparse包,二是調用tensorflow自帶的app.flags實現。

利用python的argparse包

argparse介紹及基本使用:

jianshu.com/p/b8b09084b

下面代碼用argparse實現了命令行參數的輸入。

import argparsenimport sysnparser = argparse.ArgumentParser()nparser.add_argument(--fake_data, nargs=?, const=True, type=bool, ndefault=False, nhelp=If true, uses fake data for unit testing.)nparser.add_argument(--max_steps, type=int, default=1000, nhelp=Number of steps to run trainer.)nparser.add_argument(--learning_rate, type=float, default=0.001, nhelp=Initial learning rate)nparser.add_argument(--dropout, type=float, default=0.9, nhelp=Keep probability for training dropout.)nparser.add_argument(--data_dir, type=str, default=/tmp/tensorflow/mnist/input_data, help=Directory for storing input data) parser.add_argument(--log_dir, type=str, default=/tmp/tensorflow/mnist/logs/mnist_with_summaries, nhelp=Summaries log directory) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)n

通過調用python的argparse包,調用函數parser.parse_known_args()解析命令行參數。代碼運行後得到的FLAGS是一個結構體,內部參數分別為:

FLAGS.data_dirnOut[5]: /tmp/tensorflow/mnist/input_datan FLAGS.fake_data Out[6]: False FLAGS.max_stepsnOut[7]: 1000n FLAGS.learning_ratenOut[8]: 0.001n FLAGS.dropoutnOut[9]: 0.9n FLAGS.data_dirnOut[10]: /tmp/tensorflow/mnist/input_datan FLAGS.log_dirnOut[11]: /tmp/tensorflow/mnist/logs/mnist_with_summariesn

利用tf.app.flags組件

首先需要定義一個tf.app.flags對象,調用自帶的DEFINE_string, DEFINE_boolean, DEFINE_integer, DEFINE_float設置不同類型的命令行參數及其默認值。當然,也可以在終端用命令行參數修改這些默認值。

# Define hyperparametersnflags = tf.app.flagsnFLAGS = flags.FLAGSnflags.DEFINE_boolean("enable_colored_log", False, "Enable colored log") n"The glob pattern of train TFRecords files")nflags.DEFINE_string("validate_tfrecords_file", n"./data/a8a/a8a_test.libsvm.tfrecords", n"The glob pattern of validate TFRecords files")nflags.DEFINE_integer("label_size", 2, "Number of label size")nflags.DEFINE_float("learning_rate", 0.01, "The learning rate")n def main(): n # Get hyperparameters nif FLAGS.enable_colored_log: nimport coloredlogs ncoloredlogs.install() nlogging.basicConfig(level=logging.INFO) nFEATURE_SIZE = FLAGS.feature_size nLABEL_SIZE = FLAGS.label_size n... nreturn 0nif __name__ == 『__main__』: main()n

這段代碼採用的是tensorflow庫中自帶的tf.app.flags模塊實現命令行參數的解析。如果用終端運行tf程序,用上述兩種方式都可以,如果用spyder之類的工具,那麼只有第一種方式有用,第二種方式會報錯。

其中有個tf.app.flags組件,還有個tf.app.run()函數。官網幫助文件是這麼說的:

flags module: Implementation of the flags interface. run(...): Runs the program with an optional main function and argv list.

tf.app.run的源代碼:

1."""Generic entry point script.""" n2.from __future__ import absolute_import n3.from __future__ import division n4.from __future__ import print_function n5. n6.import sys n7. n8.from tensorflow.python.platform import flags n9. n10. n11.def run(main=None): n12. f = flags.FLAGS n13. f._parse_flags() n14. main = main or sys.modules[__main__].main n15. sys.exit(main(sys.argv))n

也就是處理flag解析,然後執行main函數。

用shell腳本實現訓練代碼的執行

在終端執行python代碼,首先需要在代碼文件開頭寫入shebang,告訴系統環境變數如何設置,用python2還是用python3來編譯這段代碼。然後修改代碼許可權為可執行,用

./python_code.py

就可以執行。同理,這段代碼也可以用shell腳本來實現。創建.sh文件,運行python_code.py並設置參數max_steps=100

python python_code.py --max_steps 100n

推薦閱讀:

TAG:TensorFlow |