pytorch中如何處理RNN輸入變長序列padding

一、為什麼RNN需要處理變長輸入

假設我們有情感分析的例子,對每句話進行一個感情級別的分類,主體流程大概是下圖所示:

思路比較簡單,但是當我們進行batch個訓練數據一起計算的時候,我們會遇到多個訓練樣例長度不同的情況,這樣我們就會很自然的進行padding,將短句子padding為跟最長的句子一樣。

比如向下圖這樣:

但是這會有一個問題,什麼問題呢?比如上圖,句子「Yes」只有一個單詞,但是padding了5的pad符號,這樣會導致LSTM對它的表示通過了非常多無用的字元,這樣得到的句子表示就會有誤差,更直觀的如下圖:

那麼我們正確的做法應該是怎麼樣呢?

這就引出pytorch中RNN需要處理變長輸入的需求了。在上面這個例子,我們想要得到的表示僅僅是LSTM過完單詞"Yes"之後的表示,而不是通過了多個無用的「Pad」得到的表示:如下圖:

二、pytorch中RNN如何處理變長padding

主要是用函數torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()來進行的,分別來看看這兩個函數的用法。

這裡的pack,理解成壓緊比較好。 將一個 填充過的變長序列 壓緊。(填充時候,會有冗餘,所以壓緊一下)

輸入的形狀可以是(T×B×* )。T是最長序列長度,B是batch size,*代表任意維度(可以是0)。如果batch_first=True的話,那麼相應的 input size 就是 (B×T×*)。

Variable中保存的序列,應該按序列長度的長短排序,長的在前,短的在後(特別注意需要進行排序)。即input[:,0]代表的是最長的序列,input[:, B-1]保存的是最短的序列。

參數說明:

input (Variable) – 變長序列 被填充後的 batch

lengths (list[int]) – Variable 中 每個序列的長度。(知道了每個序列的長度,才能知道每個序列處理到多長停止

batch_first (bool, optional) – 如果是True,input的形狀應該是B*T*size。

返回值:

一個PackedSequence 對象。一個PackedSequence表示如下所示:

具體代碼如下:

embed_input_x_packed = pack_padded_sequence(embed_input_x, sentence_lens, batch_first=True)encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)

此時,返回的h_last和c_last就是剔除padding字元後的hidden state和cell state,都是Variable類型的。代表的意思如下(各個句子的表示,lstm只會作用到它實際長度的句子,而不是通過無用的padding字元,下圖用紅色的打鉤來表示):

但是返回的output是PackedSequence類型的,可以使用:

encoder_outputs, _ = pad_packed_sequence(encoder_outputs_packed, batch_first=True)

將encoderoutputs在轉換為Variable類型,得到的_代表各個句子的長度。

三、總結

這樣綜上所述,RNN在處理類似變長的句子序列的時候,我們就可以配套使用torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()來避免padding對句子表示的影響

參考:

pytorch對可變長度序列的處理 - 深度學習1 - 博客園?

www.cnblogs.com圖標趙普:pytorch RNN 變長輸入 padding?

zhuanlan.zhihu.com圖標
推薦閱讀:

《機器學習實戰》學習總結(十一)——隱馬爾可夫模型(HMM)
決策樹與隨機森林
CS231N 課程筆記合集
推薦系統:經典方法
CS259D:數據挖掘與網路安全講義筆記

TAG:機器學習 | 深度學習DeepLearning | PyTorch |