TF Boys (TensorFlow Boys ) 養成記(三):TensorFlow 變數共享

上次說到了 TensorFlow 從文件讀取數據,這次我們來談一談變數共享的問題。

為什麼要共享變數?我舉個簡單的例子:例如,當我們研究生成對抗網路GAN的時候,判別器的任務是,如果接收到的是生成器生成的圖像,判別器就嘗試優化自己的網路結構來使自己輸出0,如果接收到的是來自真實數據的圖像,那麼就嘗試優化自己的網路結構來使自己輸出1。也就是說,生成圖像和真實圖像經過判別器的時候,要共享同一套變數,所以TensorFlow引入了變數共享機制。

變數共享主要涉及到兩個函數: tf.get_variable(<name>, <shape>, <initializer>) 和 tf.variable_scope(<scope_name>)。

先來看第一個函數: tf.get_variable。

tf.get_variable 和tf.Variable不同的一點是,前者擁有一個變數檢查機制,會檢測已經存在的變數是否設置為共享變數,如果已經存在的變數沒有設置為共享變數,TensorFlow 運行到第二個擁有相同名字的變數的時候,就會報錯。

例如如下代碼:

def my_image_filter(input_images):n conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),n name="conv1_weights")n conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases")n conv1 = tf.nn.conv2d(input_images, conv1_weights,n strides=[1, 1, 1, 1], padding=SAME)n return tf.nn.relu(conv1 + conv1_biases)n

有兩個變數(Variables)conv1_weighs, conv1_biases和一個操作(Op)conv1,如果你直接調用兩次,不會出什麼問題,但是會生成兩套變數;

# First call creates one set of 2 variables.nresult1 = my_image_filter(image1)n# Another set of 2 variables is created in the second call.nresult2 = my_image_filter(image2)n

如果把 tf.Variable 改成 tf.get_variable,直接調用兩次,就會出問題了:

result1 = my_image_filter(image1)nresult2 = my_image_filter(image2)n# Raises ValueError(... conv1/weights already exists ...)n

為了解決這個問題,TensorFlow 又提出了 tf.variable_scope 函數:它的主要作用是,在一個作用域 scope 內共享一些變數,可以有如下幾種用法:

1)

with tf.variable_scope("image_filters") as scope:n result1 = my_image_filter(image1)n scope.reuse_variables() # or n #tf.get_variable_scope().reuse_variables()n result2 = my_image_filter(image2)n

需要注意的是:最好不要設置 reuse 標識為 False,只在需要的時候設置 reuse 標識為 True。

2)

with tf.variable_scope("image_filters1") as scope1:n result1 = my_image_filter(image1)nwith tf.variable_scope(scope1, reuse = True)n result2 = my_image_filter(image2)n

通常情況下,tf.variable_scope 和 tf.name_scope 配合,能畫出非常漂亮的流程圖,但是他們兩個之間又有著細微的差別,那就是 name_scope 只能管住操作 Ops 的名字,而管不住變數 Variables 的名字,看下例:

with tf.variable_scope("foo"):n with tf.name_scope("bar"):n v = tf.get_variable("v", [1])n x = 1.0 + vnassert v.name == "foo/v:0"nassert x.op.name == "foo/bar/addn

參考資料:

1. tensorflow.org/how_tos/

推薦閱讀:

譯文 | 與TensorFlow的第一次接觸(一)
Windows 10安裝Tensorflow手記
學習筆記TF039:TensorBoard

TAG:TensorFlow | 深度学习DeepLearning |