Seq2seq模型及注意力機制

在上一篇文章中,我們介紹了傳統的RNN及其變種GRU及LSTM。它們只能處理輸出序列為定長的情況。如果要處理輸出序列為不定長情況的問題呢?例如機器翻譯,例如英文到法語的句子翻譯,輸入和輸出均為不定長。前人提出了seq2seq模型,basic idea是設計一個encoder與decoder,其中encoder將輸入序列編碼為一個包含輸入序列所有信息的context vectorc,decoder通過對c的解碼獲得輸入序列的信息,從而得到輸出序列。encoder及decoder都通常為RNN循環神經網路。

seq2seq模型

seq2seq

編碼器 Encoder

編碼器的作用是把一個不定長的輸入序列x_1, x_2,ldots,x_T轉化成一個定長的context vectorc. 該context vector編碼了輸入序列x_1, x_2,ldots,x_T的序列。回憶一下循環神經網路,假設該循環神經網路單元為f(可以為vanilla RNN, LSTM, GRU),那麼hidden state為

h_t=f(x_t, h_{t-1})

編碼器的context vector是所有時刻hidden state的函數,即:

c=q(h_1,ldots,h_T)

簡單地,我們可以把最終時刻的hidden stateh_T作為context vecter。當然我們也可以取各個時刻hidden states的平均,以及其他方法。

解碼器 Decoder

編碼器最終輸出一個context vectorc,該context vector編碼了輸入序列x_1, x_2,ldots,x_T的信息。

假設訓練數據中的輸出序列為y_1, y_2,ldots,y_T,我們希望每個t時刻的輸出即取決於之前的輸出也取決於context vector,即估計mathbb{P}(y_{t}|y_1,ldots,y_{t-1}, c),從而得到輸出序列的聯合概率分布:

mathbb{P}(y_1,ldots,y_{T})=prod_{t=1}^{T}mathbb{P}(y_{t}|y_1,ldots,y_{t-1},c)

並定義該序列的損失函數loss function

-logmathbb{P}(y_1,ldots,y_{T})

通過最小化損失函數來訓練seq2seq模型。

那麼如何估計mathbb{P}(y_{t}|y_1,ldots,y_{t-1}, c)

我們使用另一個循環神經網路作為解碼器。解碼器使用函數p來表示t時刻輸出y_{t}的概率

mathbb{P}(y_{t}|y_1,ldots,y_{t-1}, c)=p(y_{t-1},s_{t},c)

為了區分編碼器中的hidden stateh_t,其中s_{t}t時刻解碼器的hidden state。區別於編碼器,解碼器中的循環神經網路的輸入除了前一個時刻的輸出序列y_{t-1},和前一個時刻的hidden states_{t-1}以外,還包含了context vectorc。即:

s_{t}=g(y_{t-1},s_{t-1},c)

其中函數g為解碼器的循環神經網路單元。

什麼叫注意力機制? Attention-based mechanism

我們注意到,在以上的解碼器設計中,各個時刻使用了相同的context vector。

以英語-法語翻譯為例,給定一對輸入序列「they are watching」和輸出序列「Ils regardent」,解碼器在時刻1可以使用更多編碼了「they are」信息的背景向量來生成「Ils」,而在時刻2可以使用更多編碼了「watching」信息的背景向量來生成「regardent」。這看上去就像是在解碼器的每一時刻對輸入序列中不同時刻分配不同的注意力。這也是注意力機制的由來。

我們自然地想到了可以在不同時刻採用不同的context vector。對上述編碼器稍加修改,將c改為c_{t}c_{t}代表t時刻的context vector。那麼t時刻解碼器的hidden state為

s_{t}=g(y_{t-1},s_{t-1},c_{t})

如何對不同時刻設計不同context vectorc_{t}?原先我們直接取編碼器hidden stateh_t的平均或者最後一個值,那麼一個自然的想法是取h_t的加權平均,即:

c_{t}=sum_{t=1}^T alpha_{tt}h_t

如何計算權重alpha_{tt}?先設計一個e_{tt},再用e_{tt}計算softmax以輸出概率。

alpha_{tt} = frac{exp(e_{tt})}{ sum_{k=1}^T exp(e_{tk}) }

如何設計e_{tt}?可以認為e_{tt}與當前時刻編碼器的hidden stateh_t,以及上一個時刻解碼器的hidden states_{t-1}有關,即:

e_{tt}=a(s_{t-1}, h_t)

對於函數a的選取,研究者提出了不同的方案。在Bahanau的論文中,

e_{tt}=v^Ttanh(W_ss_{t-1}+W_hh_t)。即對s_{t-1}h_t分別乘以一個矩陣轉為向量,再用tanh激活,得到一個向量。因為輸出是一個scalar,再乘以向量v^T轉為標量。其中vW_sW_h都是需要學習的模型參數。

針對不同的函數a的選取,可以得到不同的注意力機制。


推薦閱讀:

斯坦福CS231n項目實戰(四):淺層神經網路
CS231n Assignment2
李飛飛最新論文:構建好奇心驅動的神經網路,複製嬰兒學習能力
Inception-v2/v3結構解析(原創)
使用py-faster-rcnn進行目標檢測(object detect)

TAG:深度學習DeepLearning | 機器學習 | 神經網路 |