Inside TF-Slim(9) evaluation
0. Preface
- The Slim evaluation source code.
- The tensorflow/tensorflow/contrib/training/python/training/evaluation.py source that Slim calls into.
1. Basic concepts
1.1. Purpose
Given a model definition and data to predict on, evaluate the model's performance metrics and record the intermediate results (summaries).
1.2. Official examples
- Evaluation with metrics (loads checkpoint files). Note: the original snippet referenced an undefined metrics_to_values and an empty summary_ops; the version below collects the summary ops into a list so that merge_summary has something to merge.

images, labels = LoadData(...)
predictions = MyModel(images)

names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    "accuracy": slim.metrics.accuracy(predictions, labels),
    "mse": slim.metrics.mean_squared_error(predictions, labels),
})

# Define the summary ops
summary_ops = []
for metric_name, metric_value in names_to_values.items():
    summary_ops.append(tf.summary.scalar(metric_name, metric_value))

# Where the model checkpoints live and where the evaluation logs go
checkpoint_dir = "/tmp/my_model_dir/"
log_dir = "/tmp/my_model_eval/"

# Run an evaluation every ten minutes
num_evals = 1000
slim.evaluation_loop(
    "",
    checkpoint_dir,
    log_dir,
    num_evals=num_evals,
    eval_op=list(names_to_updates.values()),
    summary_op=tf.contrib.deprecated.merge_summary(summary_ops),
    eval_interval_secs=600)
- Evaluation without metrics (summaries only):

images, labels = LoadData(...)
predictions = MyModel(images)

# Define the summaries to write
summary_ops = [
    tf.summary.scalar(...),
    tf.summary.histogram(...),
]

checkpoint_dir = "/tmp/my_model_dir/"
log_dir = "/tmp/my_model_eval/"

# Run an evaluation every ten minutes
slim.evaluation_loop(
    "",
    checkpoint_dir,
    log_dir,
    num_evals=1,
    summary_op=tf.contrib.deprecated.merge_summary(summary_ops),
    eval_interval_secs=600)
2. API and source code
2.1. The API
- wait_for_new_checkpoint: waits, up to a configurable timeout, for a new checkpoint; when one appears, returns its path (a string).
- checkpoints_iterator: keeps yielding the latest checkpoint file. If the polling interval is too long, intermediate checkpoint files may be skipped.
- evaluate_once: runs a single round of evaluation.
- evaluation_loop: differs from the previous function in that it runs repeatedly (in a loop, with a configurable interval between rounds).
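The polling pattern behind wait_for_new_checkpoint can be sketched without TensorFlow. This is a hypothetical, framework-free version for illustration only: `latest_checkpoint_fn` stands in for tf_saver.latest_checkpoint, and the function name is my own, not part of the Slim API.

```python
import time

def wait_for_new_checkpoint_sketch(latest_checkpoint_fn,
                                   last_checkpoint=None,
                                   seconds_to_sleep=1,
                                   timeout=None):
    """Poll latest_checkpoint_fn until it returns a path that differs from
    last_checkpoint; return None if `timeout` seconds pass without one."""
    stop_time = time.time() + timeout if timeout is not None else None
    while True:
        checkpoint_path = latest_checkpoint_fn()
        if checkpoint_path is None or checkpoint_path == last_checkpoint:
            # No checkpoint yet, or the latest one has not changed.
            if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
                return None  # give up after the maximum waiting time
            time.sleep(seconds_to_sleep)
        else:
            return checkpoint_path  # a new checkpoint has appeared
```

With a stubbed-out lookup that returns None twice and then a path, the sketch sleeps twice and then returns the path, which mirrors the control flow of the real function.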
2.2. The source code
# Run a single round of evaluation.
# Returns the value of final_op, or None.
def evaluate_once(master,                     # distributed-execution target
                  checkpoint_path,            # the checkpoint file to restore
                  logdir,                     # where summaries are written
                  num_evals=1,                # how many times to run eval_op
                  initial_op=None,            # model initialization
                  initial_op_feed_dict=None,
                  eval_op=None,               # the update_op(s)
                  eval_op_feed_dict=None,
                  final_op=None,              # op to run after evaluation ends
                  final_op_feed_dict=None,
                  summary_op=_USE_DEFAULT,    # the summary op
                  summary_op_feed_dict=None,
                  variables_to_restore=None,  # variables to restore for evaluation
                  session_config=None,
                  hooks=None):
  if summary_op == _USE_DEFAULT:
    summary_op = summary.merge_all()

  # Stopping condition: stop after num_evals runs of eval_op
  all_hooks = [evaluation.StopAfterNEvalsHook(num_evals)]

  # Summary condition: run summary_op when the evaluation ends
  if summary_op is not None:
    all_hooks.append(evaluation.SummaryAtEndHook(
        log_dir=logdir, summary_op=summary_op,
        feed_dict=summary_op_feed_dict))
  if hooks is not None:
    all_hooks.extend(hooks)

  # Saver for the variables to restore
  saver = None
  if variables_to_restore is not None:
    saver = tf_saver.Saver(variables_to_restore)

  # The function below lives at
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/evaluation.py
  # and is implemented on top of monitored_session.
  return evaluation.evaluate_once(
      checkpoint_path,
      master=master,
      scaffold=monitored_session.Scaffold(
          init_op=initial_op,
          init_feed_dict=initial_op_feed_dict,
          saver=saver),
      eval_ops=eval_op,
      feed_dict=eval_op_feed_dict,
      final_ops=final_op,
      final_ops_feed_dict=final_op_feed_dict,
      hooks=all_hooks,
      config=session_config)


def evaluation_loop(master,
                    checkpoint_dir,
                    logdir,
                    num_evals=1,              # eval_op runs per round
                    initial_op=None,
                    initial_op_feed_dict=None,
                    init_fn=None,
                    eval_op=None,
                    eval_op_feed_dict=None,
                    final_op=None,
                    final_op_feed_dict=None,
                    summary_op=_USE_DEFAULT,
                    summary_op_feed_dict=None,
                    variables_to_restore=None,
                    eval_interval_secs=60,    # minimum time between two rounds
                    max_number_of_evaluations=None,  # maximum number of rounds
                    session_config=None,
                    timeout=None,             # maximum time to wait for checkpoints
                    hooks=None):
  if summary_op == _USE_DEFAULT:
    summary_op = summary.merge_all()

  all_hooks = [evaluation.StopAfterNEvalsHook(num_evals)]

  if summary_op is not None:
    all_hooks.append(evaluation.SummaryAtEndHook(
        log_dir=logdir, summary_op=summary_op,
        feed_dict=summary_op_feed_dict))

  if hooks is not None:
    # Add custom hooks if provided.
    all_hooks.extend(hooks)

  saver = None
  if variables_to_restore is not None:
    saver = tf_saver.Saver(variables_to_restore)

  # Apart from the function it delegates to, this is identical to evaluate_once.
  # The function below lives at
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/evaluation.py
  # Repeatedly searches for a checkpoint in `checkpoint_dir` and evaluates it.
  return evaluation.evaluate_repeatedly(
      checkpoint_dir,
      master=master,
      scaffold=monitored_session.Scaffold(
          init_op=initial_op,
          init_feed_dict=initial_op_feed_dict,
          init_fn=init_fn,
          saver=saver),
      eval_ops=eval_op,
      feed_dict=eval_op_feed_dict,
      final_ops=final_op,
      final_ops_feed_dict=final_op_feed_dict,
      eval_interval_secs=eval_interval_secs,
      hooks=all_hooks,
      config=session_config,
      max_number_of_evaluations=max_number_of_evaluations,
      timeout=timeout)


def wait_for_new_checkpoint(checkpoint_dir,        # directory holding the checkpoints
                            last_checkpoint=None,  # name of the last checkpoint seen
                            seconds_to_sleep=1,    # interval between two polls
                            timeout=None):         # maximum time to wait
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    # Get the latest checkpoint file
    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      # No checkpoint has been produced yet, or the latest one has not changed
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        # The maximum waiting time has been exceeded: give up
        return None
      # Keep waiting
      time.sleep(seconds_to_sleep)
    else:
      # A new checkpoint file has appeared
      logging.info('Found new checkpoint at %s', checkpoint_path)
      return checkpoint_path


# Continuously yield the latest checkpoint file
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  checkpoint_path = None
  while True:
    new_checkpoint_path = wait_for_new_checkpoint(
        checkpoint_dir, checkpoint_path, timeout=timeout)
    if new_checkpoint_path is None:
      if not timeout_fn:
        # Timed out and no timeout_fn was given: stop
        logging.info('Timed-out waiting for a checkpoint.')
        return
      if timeout_fn():
        # timeout_fn was given and returned True: stop
        return
      else:
        # Keep waiting
        continue
    start = time.time()
    checkpoint_path = new_checkpoint_path
    yield checkpoint_path
    time_to_next_eval = start + min_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)
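The generator logic of checkpoints_iterator is also easy to isolate. Below is a hypothetical, framework-free sketch (the function name and the injected `wait_fn` are my own): `wait_fn(last_checkpoint, timeout)` stands in for wait_for_new_checkpoint, returning the next path or None on timeout.

```python
import time

def checkpoints_iterator_sketch(wait_fn, min_interval_secs=0,
                                timeout=None, timeout_fn=None):
    """Yield each new checkpoint path produced by wait_fn; stop when it
    times out (returns None), unless timeout_fn says to keep waiting."""
    checkpoint_path = None
    while True:
        new_path = wait_fn(checkpoint_path, timeout)
        if new_path is None:
            if not timeout_fn:
                return          # timed out, no timeout_fn: stop iterating
            if timeout_fn():
                return          # timeout_fn returned True: stop iterating
            continue            # timeout_fn returned False: keep waiting
        start = time.time()
        checkpoint_path = new_path
        yield checkpoint_path
        # Enforce the minimum interval between two yielded checkpoints
        delay = start + min_interval_secs - time.time()
        if delay > 0:
            time.sleep(delay)
```

Driving it with a stub that produces two paths and then None shows the iterator yielding both paths and then stopping, which is exactly the behavior the evaluation loop relies on.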
3. Other notes
- The main job of metrics is to produce a value and an update_op…
- evaluation is meant for the testing phase, not the training phase. When using it, the model should already be fully trained and saved as a checkpoint file.
- The concrete implementation of evaluation is built on monitored_session, which is also the low-level machinery behind the new-style training; studying that is the next step.
- I have not yet worked out how to combine tf.data.Dataset with evaluation.