pytorch中如何處理RNN輸入變長序列padding
一、為什麼RNN需要處理變長輸入
假設我們有情感分析的例子,對每句話進行一個感情級別的分類,主體流程大概是下圖所示:
思路比較簡單,但是當我們進行batch個訓練數據一起計算的時候,我們會遇到多個訓練樣例長度不同的情況,這樣我們就會很自然的進行padding,將短句子padding為跟最長的句子一樣。
比如向下圖這樣:
但是這會有一個問題,什麼問題呢?比如上圖,句子「Yes」只有一個單詞,但是padding了5的pad符號,這樣會導致LSTM對它的表示通過了非常多無用的字元,這樣得到的句子表示就會有誤差,更直觀的如下圖:
那麼我們正確的做法應該是怎麼樣呢?
這就引出pytorch中RNN需要處理變長輸入的需求了。在上面這個例子,我們想要得到的表示僅僅是LSTM過完單詞"Yes"之後的表示,而不是通過了多個無用的「Pad」得到的表示:如下圖:
二、pytorch中RNN如何處理變長padding
主要是用函數torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()來進行的,分別來看看這兩個函數的用法。
這裡的pack,理解成壓緊比較好。 將一個 填充過的變長序列 壓緊。(填充時候,會有冗餘,所以壓緊一下)
輸入的形狀可以是(T×B×* )。T是最長序列長度,B是batch size,*代表任意維度(可以是0)。如果batch_first=True的話,那麼相應的 input size 就是 (B×T×*)。
Variable中保存的序列,應該按序列長度的長短排序,長的在前,短的在後(特別注意需要進行排序)。即input[:,0]代表的是最長的序列,input[:, B-1]保存的是最短的序列。
參數說明:
input (Variable) – 變長序列 被填充後的 batch
lengths (list[int]) – Variable 中 每個序列的長度。(知道了每個序列的長度,才能知道每個序列處理到多長停止)
batch_first (bool, optional) – 如果是True,input的形狀應該是B*T*size。
返回值:
一個PackedSequence 對象。一個PackedSequence表示如下所示:
具體代碼如下:
embed_input_x_packed = pack_padded_sequence(embed_input_x, sentence_lens, batch_first=True)encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)
此時,返回的h_last和c_last就是剔除padding字元後的hidden state和cell state,都是Variable類型的。代表的意思如下(各個句子的表示,lstm只會作用到它實際長度的句子,而不是通過無用的padding字元,下圖用紅色的打鉤來表示):
但是返回的output是PackedSequence類型的,可以使用:
encoder_outputs, _ = pad_packed_sequence(encoder_outputs_packed, batch_first=True)
將encoderoutputs在轉換為Variable類型,得到的_代表各個句子的長度。
三、總結
這樣綜上所述,RNN在處理類似變長的句子序列的時候,我們就可以配套使用torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()來避免padding對句子表示的影響
參考:
pytorch對可變長度序列的處理 - 深度學習1 - 博客園趙普:pytorch RNN 變長輸入 padding推薦閱讀:
※《機器學習實戰》學習總結(十一)——隱馬爾可夫模型(HMM)
※決策樹與隨機森林
※CS231N 課程筆記合集
※推薦系統:經典方法
※CS259D:數據挖掘與網路安全講義筆記
TAG:機器學習 | 深度學習DeepLearning | PyTorch |