基於RNN的seq2seq與基於CNN的seq2seq區別,為什麼後者效果更好?


沒有人評論,還是我自己嘗試來回答一下吧,同時也歡迎聯繫轉載。

提出這個問題原因,在於FaceBook AI實驗室最近發布了一個基於CNN Seq2Seq的機器翻譯模型,取得了比Google翻譯更好的效果。通過最近的學習發現,這個問題不是絕對的。

FaceBook的Convolutional SeqSeq取得了超越Google翻譯的成果,重要原因在於採用了很多的trick,很多工作值得借鑒:

1、Position Embedding,在輸入信息中加入位置向量P=(p1,p2,....),把位置向量與詞向量W=(w1,w2,.....)求和構成向量E=(w1+p1,w2+p2),做為網路輸入,使由CNN構成的Encoder和Decoder也具備了RNN捕捉輸入Sequence中詞的位置信息的功能。

2、層疊CNN構成了hierarchical representation表示。層疊的CNN擁有3個優點:

(1)捕獲long-distance依賴關係。底層的CNN捕捉相聚較近的詞之間的依賴關係,高層CNN捕捉較遠詞之間的依賴關係。通過層次化的結構,實現了類似RNN(LSTM)捕捉長度在20個詞以上的Sequence的依賴關係的功能。

(2)效率高。假設一個sequence序列長度為n,採用RNN(LSTM)對其進行建模 需要進行n次操作,時間複雜度O(n)。相比,採用層疊CNN只需要進行n/k次操作,時間複雜度O(n/k),k為卷積窗口大小。

(3)可以並行化實現。RNN對sequence的建模依賴於序列的歷史信息,因此不能並行實現。相比,層疊CNN正個sequence進行卷積,不依賴序列歷史信息,可以並行實現,模型訓練更快,特別是在工業生產,面臨處理大數據量和實時要求比較高的情況下。

3、融合了Residual connection、liner mapping的多層attention。通過attention決定輸入的哪些信息是重要的,並逐步往下傳遞。把encoder的輸出和decoder的輸出做點乘(dot products),再歸一化,再乘以encoder的輸入X之後做為權重化後的結果加入到decoder中預測目標語言序列。

4、採用GLU做為gate mechanism。GLU單元激活方式如下公式所示:

每一層的輸出都是一個線性映射X*W + b,被一個門gate:o(X*V+c)控制,通過做乘法來控制信息向下層流動的力度,o採用雙曲正切S型激活函數。這個機制類似LSTM中的gate mechanism,對於語言建模非常有效,使模型可以選擇那些詞或特徵對於預測下一個詞是真的有效的。

5、進行了梯度裁剪和精細的權重初始化,加速模型訓練和收斂。

完整的結構分析:

Figure1是論文中給出的的Convolutional Seq2Seq的結構,看起來有點複雜,其實挺簡單的。下面簡要分析下是如何與上述5個trick結合起來的:

上左encoder部分:通過層疊的卷積抽取輸入源語言(英語)sequence的特徵,圖中直進行了一層卷積。卷積之後經過GLU激活做為encoder輸出。

下左decoder部分:採用層疊卷積抽取輸出目標語言(德語)sequence的特徵,經過GLU激活做為decoder輸出。

中左attention部分:把decoder和encoder的輸出做點乘,做為輸入源語言(英語)sequence中每個詞權重。

中右Residualconnection:把attention計算的權重與輸入序列相乘,加入到decoder的輸出中輸出輸出序列。

最後實驗結論:在多個公開數據集上獲得了新的state-of-the-art的成績。在WMT-16、英語-羅馬尼亞語翻譯,高出以前方法1.8 BLEU;在WMT-14、英語-法語翻譯,比以前LSTM模型所取得的成績高出1.5 BLEU;在WMT-14、英語-德語翻譯,比以前方法高出0.5 BLEU。

總結:個人感覺本文採用了很多簡單且非常有效的trick,達到了基於LSTM的NMT方法更好的效果,正因為如此,並不能說,基於CNN seq2seq模型就一定比基於LSTM的Seq2Seq一定好。採用CNN的Seq2Seq最大的優點在於速度快,效率高,缺點就是需要調整的參數太多。上升到CNN和RNN用於NLP問題時,CNN也是可行的,且網路結構搭建更加靈活,效率高,特別是在大數據集上,往往能取得比RNN更好的結果。

參考論文:
《Convolutional Sequence to Sequence Learning》,下載地址:https://arxiv.org/abs/1705.03122
《Language modeling with gated linear units》,下載地址:https://arxiv.org/abs/1612.08083
《A Convolutional Encoder Model for Neural Machine Translation》,下載地址:https://arxiv.org/abs/1611.02344


我倒是不覺得CNN的seq2seq比RNN的效果更好了,至少現在看起來的情況是相差不大,只是說基於Gated CNN做的seq2seq翻譯模型的運行速度顯著的快於RNN,這也主要是由於CNN中的卷積的計算是比較有利於GPU做並行化的,而RNN則在時間維度上需要使用循環來實現。

其實Gated CNN算是Facebook好幾個月以前的工作了,現在突然火了一把,不得不說標題黨還是很重要的……


LSTM估計不久就要成為歷史過客了。他的核心是引入gate,但是操作實現上太複雜了,一個LSTM cell就挺複雜,人為設計痕迹過重,以至於遮住了它的洞察能力,在表達長程相關性上,也沒有達到期望的效果。可能是,他這種技術實現方式--搞4個分支(遺忘門,輸入門,輸出門,選擇門)加一個cell流傳遞的方式--就做不到想要的效果。

CNN估計就不會是歷史過客,它應該是抓住了某些局部和全局處理上的本質抽象能力,技術實現上也是簡潔優美,即便將來網路技術改朝換代,它也總會以某種面貌升級換代保存下來繼續發揚光大。


君不見ICLR2017裡面谷歌自家也刷新了BLEU嗎?我沒有任何否認cnn seq2seq的意思,我就是想說bleu只是一個評分標準而已。


我覺得就目前的情況來看還說不上CNN的seq2seq就比RNN要好,只是引入Gate以及位置信息等trick的CNN在facebook手下達到了和rnn旗鼓相當的效果,亮點是速度更快,不過對於cnn來說速度快是意料之中的。長遠來看,技術的發展不是東風與西風誰壓倒誰的問題,而且互相學習借鑒,螺旋攀升的過程,CNN中gate的引入就是一例。


效果更好,也許是模型更複雜,過擬合更嚴重。


推薦閱讀:

對比JIT和AOT,各自有什麼優點與缺點?
C++ 為什麼沒有 function 關鍵字?
如何用通俗易懂的語言解釋虛擬存儲器?
計算機語言算不算語言?
國人對於國外CS教材是否存在盲目崇拜心理?

TAG:自然語言處理 | 計算機科學 | 深度學習DeepLearning | LSTM | 卷積神經網路CNN |