巧妙使用 TensorFlow 之 TensorLayer

TensorFlow 1.0 版本剛剛正式發布,各路媒體大量刷屏,可見其影響力。再加上 TensorFlow-Fold 的出現,讓其建立 dynamic computation graphs 非常方便,具備了 dynet 和 pytorch 類似的功能。同學們可能發現 1.0 版本改動非常大,很多 API 被重新命名、使用方法也變了很多,導致舊版本的 TF 代碼無法在 1.0 版本上直接運行。另外,TF1.0 版本中終於加入了tf.layers,雖然功能有限,但比沒有要好多了。

今天介紹一個工具 TensorLayer(以下簡稱 TL),雖然它實現了各種各樣的層,也提供類似 Keras 的 fit(), test(), predict() 等方法,但我稱其為工具而不是庫,因為它大量的功能以函數形式提供,通過巧妙使用提供的函數,可以非常簡潔高效地實現複雜的應用。比如 TL 提供了大量的數據增強和預處理函數,自己可以根據應用而個性化地組合這些函數(比如 image segmentation 時 X 和 Y 要對應處理)。TL 的 API 設計要求儘可能地輸入 TF 本身的 API,這樣的好處是可以和 TF 方便地交互使用,其 DynamicRNNLayer 和 PoolLayer 是很好的例子(當然它也提供 簡化版本API )。因此在我看來,使用 TL 的設計是為了巧妙地使用 TF,它提供的代碼例子都跟隨一種編寫風格,方便社區統一風格以分享閱讀代碼。

這裡根據最近使用 TL 的經驗,我總結了一些使用小技巧,若寫得不客觀請見諒,當作是自己的筆記吧。

第一次在知乎寫文章,寫得不好看看就好 ??

我現在在美帝讀博,非常喜歡深度學習,歡迎交流 :D

*** 更多小技巧將陸續在(這裡)補充。

1. 安裝

* 為了方便閱讀和拓展 TL 代碼,建議把整個項目下載下來(在terminal中輸入 git clone zsdonghao/tensorlayer),然後把 tensorlayer 的文件夾放到你的項目中。

* 由於近期 TL 發展很快,若想用 pip 安裝,建議安裝 master 版本。

* 對於研究 NLP 的同學,可能需要安裝 NLTK 和 NLTK data 以使用文本分析API,這些功能封裝在 tl.nlp 中。(若不使用則不需要安裝)

2. TF與TL相互轉換

* TF 轉為 TL : 通過 InputLayer 把 tensor 輸入到 tl.layers

* TL 轉為 TF : 通過 network.outputs 獲取 tensor

* 其它途徑 [issues7], 多輸入 [issues31]

3. Training/Testing 切換

* 通過 network.all_drop 來 disable/enable DropoutLayer (這隻當使用 DropoutLayer 時才可以,可參考 tutorial_mnist.py 和 Understand Basic layer

* 更好的方法是把 noise 層的 "is_fix" 設為 "True",然後對 Training 和 Testing 分別建立不同的graph,這需要用到 parameter sharing。除了控制training/testing,這個方法可以讓建立 graph 時使用不一樣的參數,如 batch_size,GaussianNoiseLayer, BatchNormLayer 等。 例子如下:

def mlp(x, is_train=True, reuse=False): with tf.variable_scope("MLP", reuse=reuse): tl.layers.set_name_reuse(reuse) net = InputLayer(x, name=in) net = DropoutLayer(net, 0.8, True, is_train, drop1) net = DenseLayer(net, 800, tf.nn.relu, dense1) net = DropoutLayer(net, 0.8, True, is_train, drop2) net = DenseLayer(net, 800, tf.nn.relu, dense2) net = DropoutLayer(net, 0.8, True, is_train, drop3) net = DenseLayer(net, 10, tf.identity, out) logits = net.outputs net.outputs = tf.nn.sigmoid(net.outputs) return net, logitsx = tf.placeholder(tf.float32, shape=[None, 784], name=x)y_ = tf.placeholder(tf.int64, shape=[None, ], name=y_)net_train, logits = mlp(x, is_train=True, reuse=False)net_test, _ = mlp(x, is_train=False, reuse=True)cost = tl.cost.cross_entropy(logits, y_, name=cost)

4. 獲取variables

TL非常特殊的一點是:需要給每層輸入一個唯一的名字,除非 reuse 該層。我剛開始用時,完全不明白這樣設計的道理,後來發現這樣的好處是杜絕了錯誤重用和方便參數管理。

* 使用 tl.layers.get_variables_with_name 獲取參數列表,盡量少用 net.all_params,下面的例子獲取了上面例子中全部函數,因為上面例子中的函數都置於 「MLP」 之下:

train_vars = tl.layers.get_variables_with_name(MLP, True, True)train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_vars)

* 這個方法常常用於選擇哪些參數需要被更新,比如訓練 GAN 時,可以分別獲取 G 和 D 的參數列表,放到對應的 optimizer 中。

* 其它方法 [issues17], [issues26], [FQA]

5. 使用預訓練的 CNN 及建立 Resnet

很多應用中需要用到預訓練好的CNN模型,比如 image captioning, VQA, 以及在小數據集中 fine-tune 做分類器等等。

* 預訓練的 CNN

。TL 網站上提供了 VGG16, VGG19, Inception 等例子,請見 TL/example

。此外通過 tl.layers.SlimNetsLayer 可以使用 Tf-Slim pre-trained models 中全部預訓練好的模型!

* Resnet

。Implement by "for" loop [issues85]

。 Other methods [by @ritchieng]

6. 數據增強

* 使用 TF 提供的 TFRecord,參考 cifar10 and tfrecord examples; 這裡介紹一個很好的工具: imageflow

* TL提供了 tl.prepro.threading_data 來使用 python-threading,並提供了大量圖像增強的函數: the functions for images augmentation,請參考 tutorial_image_preprocess.py

7. 句子ID化(Sentences tokenization)

NLP中,詞語需要轉換為ID來處理,TL 的 tl.nlp 提供了大量的方法,但我覺得下面的幾個han s基本夠用了。

* 使用 tl.nlp.process_sentence 把句子分隔,對於中文推薦使用 jieba分詞

* 然後使用 tl.nlp.create_vocab 來建立辭彙表並保存成為 txt 文件,該函數還會返回一個 tl.nlp.SimpleVocabulary 實例

* 最後建議從 tl.nlp.create_vocab 保存的 txt 文件中實例化一個 tl.nlp.Vocabulary,以方便詞語和數字ID之間的轉換

* 更多文本處理函數請見 tl.prepro 和 tl.nlp

8. Dynamic RNN 與 sequence length

* 使用 tl.layers.retrieve_seq_length_op2 來幫助 DynamicRNNLayer 自動計算每個句子的 sequence length

* 對一個 batch 的數據做 zero padding:

b_sentence_ids = tl.prepro.pad_sequences(b_sentence_ids, padding=post)

* 其它方法 [issues18]

9. 常見bug

* Matplotlib issue arise when importing TensorLayer [issues] [FQA] (這個問題往往在遠程連接 ubuntu 時出現)

10. 其它小技巧

* TL默認模式下,在執行每一個 layer 時會把相關信息顯示到terminal中。但當你在建立非常深的網路時,這些信息沒有太大幫助。因此可以通過 with tl.ops.suppress_stdout(): 來禁止print輸出:

print("You can see me")with tl.ops.suppress_stdout(): print("You cant see me") # 在這裡建立模型print("You can see me")

Useful links

* TL official sites: [Docs], [中文文檔], [Github]

* Learning Deep Learning with TF and TL

討論

微信交流群
推薦閱讀:

基於深度學習的目標跟蹤演算法是否可能做到實時?
Facenet即triplet network模型訓練,loss不收斂的問題?
如何在torch7上增加一個新的層?
如何評價人們對電腦在圍棋上戰勝人類的時間預測?
AlphaGo 演算法的通用性到底有多廣?

TAG:深度学习DeepLearning | Python | 机器视觉 |