Seq2seq模型及注意力機制
在上一篇文章中,我們介紹了傳統的RNN及其變種GRU及LSTM。它們只能處理輸出序列為定長的情況。如果要處理輸出序列為不定長情況的問題呢?例如機器翻譯,例如英文到法語的句子翻譯,輸入和輸出均為不定長。前人提出了seq2seq模型,basic idea是設計一個encoder與decoder,其中encoder將輸入序列編碼為一個包含輸入序列所有信息的context vector,decoder通過對的解碼獲得輸入序列的信息,從而得到輸出序列。encoder及decoder都通常為RNN循環神經網路。
seq2seq模型
編碼器 Encoder
編碼器的作用是把一個不定長的輸入序列轉化成一個定長的context vector. 該context vector編碼了輸入序列的序列。回憶一下循環神經網路,假設該循環神經網路單元為(可以為vanilla RNN, LSTM, GRU),那麼hidden state為
編碼器的context vector是所有時刻hidden state的函數,即:
簡單地,我們可以把最終時刻的hidden state作為context vecter。當然我們也可以取各個時刻hidden states的平均,以及其他方法。
解碼器 Decoder
編碼器最終輸出一個context vector,該context vector編碼了輸入序列的信息。
假設訓練數據中的輸出序列為,我們希望每個時刻的輸出即取決於之前的輸出也取決於context vector,即估計,從而得到輸出序列的聯合概率分布:
並定義該序列的損失函數loss function
通過最小化損失函數來訓練seq2seq模型。
那麼如何估計?
我們使用另一個循環神經網路作為解碼器。解碼器使用函數來表示時刻輸出的概率
為了區分編碼器中的hidden state,其中為時刻解碼器的hidden state。區別於編碼器,解碼器中的循環神經網路的輸入除了前一個時刻的輸出序列,和前一個時刻的hidden state以外,還包含了context vector。即:
其中函數g為解碼器的循環神經網路單元。
什麼叫注意力機制? Attention-based mechanism
我們注意到,在以上的解碼器設計中,各個時刻使用了相同的context vector。
以英語-法語翻譯為例,給定一對輸入序列「they are watching」和輸出序列「Ils regardent」,解碼器在時刻1可以使用更多編碼了「they are」信息的背景向量來生成「Ils」,而在時刻2可以使用更多編碼了「watching」信息的背景向量來生成「regardent」。這看上去就像是在解碼器的每一時刻對輸入序列中不同時刻分配不同的注意力。這也是注意力機制的由來。
我們自然地想到了可以在不同時刻採用不同的context vector。對上述編碼器稍加修改,將改為,代表時刻的context vector。那麼時刻解碼器的hidden state為
如何對不同時刻設計不同context vector?原先我們直接取編碼器hidden state的平均或者最後一個值,那麼一個自然的想法是取的加權平均,即:
如何計算權重?先設計一個,再用計算softmax以輸出概率。
如何設計?可以認為與當前時刻編碼器的hidden state,以及上一個時刻解碼器的hidden state有關,即:
對於函數的選取,研究者提出了不同的方案。在Bahanau的論文中,
。即對與分別乘以一個矩陣轉為向量,再用tanh激活,得到一個向量。因為輸出是一個scalar,再乘以向量轉為標量。其中、、都是需要學習的模型參數。
針對不同的函數的選取,可以得到不同的注意力機制。
推薦閱讀:
※斯坦福CS231n項目實戰(四):淺層神經網路
※CS231n Assignment2
※李飛飛最新論文:構建好奇心驅動的神經網路,複製嬰兒學習能力
※Inception-v2/v3結構解析(原創)
※使用py-faster-rcnn進行目標檢測(object detect)
TAG:深度學習DeepLearning | 機器學習 | 神經網路 |