Inside TF-Slim(2) arg_scope

0. 前言

  • arg_scope 源碼地址
  • 參考博客:slim.arg_scope原理分析。
  • 源碼使用到了Python特性中的 裝飾器(即類似Java注釋) 和 上下文管理(即with語句)

1. 基本功能與使用

1.1. 基本功能

  • arg_scope的主要功能是為一些操作提供默認參數。
  • arg_scope可以疊加使用。
  • API定義:

def arg_scope(list_ops_or_scope, **kwargs):# list_ops_or_scope: 需要添加默認參數的ops和scope# **kwargs: 默認參數列表,如 param1=abcd

1.2. 官方實例

  • 官方實例1,如何使用arg_scope

from third_party.tensorflow.contrib.layers.python import layers arg_scope = tf.contrib.framework.arg_scope with arg_scope([layers.conv2d], padding=SAME, initializer=layers.variance_scaling_initializer(), regularizer=layers.l2_regularizer(0.05)): net = layers.conv2d(inputs, 64, [11, 11], 4, padding=VALID, scope=conv1) net = layers.conv2d(net, 256, [5, 5], scope=conv2) # 其中,第一個conv2d相當於:layers.conv2d(inputs, 64, [11, 11], 4, padding=VALID, initializer=layers.variance_scaling_initializer(), regularizer=layers.l2_regularizer(0.05), scope=conv1) # 第二個conv2d相當於:layers.conv2d(inputs, 256, [5, 5], padding=SAME, initializer=layers.variance_scaling_initializer(), regularizer=layers.l2_regularizer(0.05), scope=conv2)

  • 官方實例2,如何復用arg_scope

with arg_scope([layers.conv2d], padding=SAME, initializer=layers.variance_scaling_initializer(), regularizer=layers.l2_regularizer(0.05)) as sc: net = layers.conv2d(net, 256, [5, 5], scope=conv1) ....# 可以直接通過arg_scope對象with arg_scope(sc): net = layers.conv2d(net, 256, [5, 5], scope=conv2)

  • 官方實例3,如何令自己創建的函數也可用於arg_scope

@tf.contrib.framework.add_arg_scopedef conv2d(*args, **kwargs)

1.3. 使用總結

  • 使用步驟:
  1. 使用@slim.add_arg_scope修飾目標操作(即函數)。
  2. 使用with slim.arg_scope(...):設置默認參數。
  • 自己的一些測試:

import tensorflow.contrib.slim as slim# 定義一些自己的函數,並使用add_arg_scope裝飾@slim.add_arg_scopedef my_function1(param1=param1, param2=param2): print(my_function1, param1, param2)@slim.add_arg_scopedef my_function2(param2=param2, param3=param3): print(my_function2, param2, param3) # 操作一 # 普通的調用arg_scope的方式with slim.arg_scope([my_function1], param2=param2_modify): my_function1() # 輸出 my_function1 param1 param2_modify# 操作二# 如果設置的一些參數,不存在於ops列表中,則會報錯# 以下實例中,my_function1函數中沒有參數param3,所以會報錯with slim.arg_scope([my_function1], param3=param2_modify): my_function1()# 操作三# 優先順序# 1. 優先順序最高的是函數本身的實參# 2. 有多層arg_scope時,優先順序最高的時最內側的arg_scope# 3. 優先順序最低的是函數定義中的默認參數with slim.arg_scope([my_function1], param2=param2_modify1): with slim.arg_scope([my_function1], param2=param2_modify2): my_function1(param2=2) # 輸出為 my_function1 param1 2# 操作四# 當命名參數與非命名參數同時使用時要注意# 如函數 my_function1中的參數列表為param1, param2# 則以下函數會報錯with slim.arg_scope([my_function1], param1=1): my_function1(1, 2) # 錯誤 my_function1() got multiple values for argument param1 # 錯誤分析:從源碼來看,非命名參數與命名參數是分開處理的,所以以上代碼等價於 # my_function1(1, 2, param1=1)# 操作五# 對於以arg_scope作為參數傳遞時with slim.arg_scope([my_function1], param2=22) as s: passwith slim.arg_scope([my_function1], param1=111): with slim.arg_scope(s): my_function1() # 輸出 myfunction1 param1 22 # 由此可見,最外層的arg_scope完全不起作用with slim.arg_scope([my_function2], param3=123): with slim.arg_scope(s): my_function2() # 輸出 my_function2() param2 param3 # 由此可見,最外層arg_scope完全沒有效果 # 即使s對象完全沒有對my_function2進行操作,最外層arg_scope也沒有起作用

2.源碼理解

2.1. arg_scope數據存儲介紹

# 列表,實現"棧"結構,用於多層arg_scope# 列表中每個元素為字典,代表一層arg_scope,且融合了之前所有arg_scope的內容# (有特例,就是使用arg_scope對象、即字典對象,作為參數傳遞到arg_scope中,具體參考arg_scope函數源碼)# 元素key為str(func)# 元素value為字典:即 參數名(字元串) -> 默認參數值_ARGSTACK = [{}]# 字典,用於存儲所有被@add_arg_scope修飾的函數# key為函數名稱,即str(func)# value為函數命名參數列表_DECORATED_OPS = {}

2.2. 函數功能介紹

  • 私有函數

# 獲取 _ARGSTACK 對象def _get_arg_stack(): if _ARGSTACK: return _ARGSTACK else: _ARGSTACK.append({}) return _ARGSTACK# 獲取函數的屬性_key_op的值,該屬性不存在,則返回str(op)def _key_op(op): return getattr(op, _key_op, str(op))# 獲取當前函數的模塊名稱與函數名稱def _name_op(op): return (op.__module__, op.__name__)# 獲取當前函數所有命名參數(有默認數值屬性)的列表def _kwarg_names(func): kwargs_length = len(func.__defaults__) if func.__defaults__ else 0 return func.__code__.co_varnames[-kwargs_length:func.__code__.co_argcount]# 將當前函數添加到 _DECORATED_OPS 中# 從這兒可以看出,_DECORATED_OPS的key為函數名,value為有默認屬性值命名參數的列表def _add_op(op): key_op = _key_op(op) if key_op not in _DECORATED_OPS: _DECORATED_OPS[key_op] = _kwarg_names(op)

  • 共有函數

# 獲取當前arg_scope# arg_scope可以疊加使用,該方法獲取最內側的arg_scope對象# 換句話說,獲取棧頂部的arg_scope字典對象def current_arg_scope(): stack = _get_arg_stack() return stack[-1]# 最重要的函數,作用為:為選定的操作,添加默認參數# 根據定義可以看出,該函數使用了上下文管理器,簡單說就是為了使用 with 而必須的操作。@tf_contextlib.contextmanagerdef arg_scope(list_ops_or_scope, **kwargs): if isinstance(list_ops_or_scope, dict): # 當輸入的list_ops_or_scope為其他arg_scope對象時,就會輸入字典對象 if kwargs: # 當list_ops_or_scope為字典對象時,kwargs必須為空,否則報錯 raise ValueError(When attempting to re-use a scope by suppling a dictionary, kwargs must be empty.) current_scope = list_ops_or_scope.copy() try: # 從以下代碼可以看出,當使用arg_scope對象作為參數時,之前所有arg_scope對象都不起作用了 # 即1.5.節中的操作五 _get_arg_stack().append(current_scope) yield current_scope finally: _get_arg_stack().pop() else: # 當輸入 list_ops_or_scope 為list或tuple時 if not isinstance(list_ops_or_scope, (list, tuple)): raise TypeError(list_ops_or_scope must either be a list/tuple or reused scope (i.e. dict)) try: # 第一步:複製前一層arg_scope中所有操作的默認參數 current_scope = current_arg_scope().copy() # 第二步:遍歷所有op(即函數),判斷該函數是否被@add_arg_scope修飾, # 最後將當前arg_scope中的參數列表與之前arg_scope中的參數列表合併 for op in list_ops_or_scope: key_op = _key_op(op) # 判斷是否被@add_arg_scope修飾 if not has_arg_scope(op): raise ValueError(%s is not decorated with @add_arg_scope, _name_op(op)) # 合併參數列表 if key_op in current_scope: current_kwargs = current_scope[key_op].copy() current_kwargs.update(kwargs) current_scope[key_op] = current_kwargs else: current_scope[key_op] = kwargs.copy() # 第三步:將當前arg_scope對象放置到棧中 _get_arg_stack().append(current_scope) # 上下文管理器標誌位,之前都是__enter__操作,之後都是__exit__操作 yield current_scope finally: # 離開with塊時,將當前arg_scope刪除 _get_arg_stack().pop()# 該函數設計到Python中的 修飾器# 為了使得arg_socpe有效,相關操作(函數)必須使用@add_arg_scope修飾# 主要作用:將函數添加到內部數據結構中,使得可以通過arg_scope操作def add_arg_scope(func): # 用以下函數修飾輸入函數func # 從該函數的輸入列表中可以看出,命名參數與非命名參數是分開處理的 # 且合併的參數僅僅為命名參數(不對非命名參數進行處理) # 即1.5.節中的操作四 def func_with_args(*args, **kwargs): current_scope = current_arg_scope() current_args = kwargs key_func = _key_op(func) if key_func in current_scope: current_args = current_scope[key_func].copy() current_args.update(kwargs) return func(*args, **current_args) # 被修飾的函數,都要被添加到_DECORATED_OPS中 _add_op(func) setattr(func_with_args, _key_op, _key_op(func)) return tf_decorator.make_decorator(func, func_with_args)# 查看一個函數是否被add_arg_scope修飾def has_arg_scope(func): return _key_op(func) in _DECORATED_OPS# 通過arc_scope可為該函數設置哪些變數的默認參數# 起始該方法不準確,因為也可以通過arg_scope設置一些非命名參數的默認值,但容易出錯……def arg_scoped_arguments(func): assert has_arg_scope(func) return _DECORATED_OPS[_key_op(func)]

推薦閱讀:

My solutions for `Google TensorFlow Speech Recognition Challenge`
TensorFlow的checkpoint中變數的重命名
卷積神經網路模型(2)-AlexNet解讀
28 款 GitHub 最流行的開源機器學習項目(附地址)

TAG:TensorFlow | 深度學習DeepLearning |