從 IndRNN 來回顧 RNN 模型

從 IndRNN 來回顧 RNN 模型

IndRNN 是最近新提出的一種 RNN 架構,作者認為其能夠更好地解決 gradient vanishing / exploding 以及具有更好的解釋性。

從 RNN 到 IndRNN

首先來複習一下 RNN t 時刻的 hidden state 的計算公式:

這個公式相信都是瞭然於心,如果畫成圖的話,就是下面這個樣子:

圖片引用自 colahblog

但是,請問一下自己,這個 A 裡面是怎麼樣的一個結構?如果不能一下子回答出來,那麼你可能和我一樣,從來都不曾真正的理解 RNN。

在這個 RNN cell 的內部單元,其連接的狀態是這樣的:

Vanilla RNN 內部是兩個全連接

注意: h_{t-1}h_t 之間是一個全連接的關係

而 IndRNN 的計算公式呢,做了一個小小的修改:

從矩陣乘變成了 haramard dot,也就是 element wise 的乘法,能給 RNN 帶來什麼樣的變化呢?

全連接變成了單個神經元自身之間的傳遞,這帶來兩個變化:

  1. 更好地解決梯度爆炸 / 消失
  2. 更好的解釋性

接下來就從這兩個方面入手來深入了解一下 IndRNN

梯度爆炸/消失

首先來複習一下 RNN 的梯度怎麼計算:

假設目標函數為 ,我們對第 個 hidden state 求導,結果為:

frac{partial{J}}{partial{h_T}} Pi_{k=t}^{T-1}diag(sigma(h_{k+1}))mathbf{U^T}

其中 diag 是指對角矩陣(其中對角的元素是激活函數對 h_{k+1} 的導數),其實就是後一 hidden state 通過鏈式法則展開到最前的一個 hiddent state,求導的結果連乘。

矩陣的連乘可以通過對角化來簡便計算,即化為 , 是一個對角陣,主對角線上的元素就是其特徵值。而如果其特徵值小於 1,那麼在連續的乘積之後其值就會接近 0,梯度消失;如果大於 1,那麼就會接近變成 NaN,梯度爆炸。解決的手段分別就是梯度裁剪(gradient clipping)和 合適的初始化 + 更換為 ReLU acitivation(目的是為了選擇合適的特徵值?)

LSTM 解決這一問題的思路是增加 gates,來控制信息的流動,從而較好的解決梯度消失問題(因為梯度爆炸用 clipping 能夠比較粗暴地解決,而 vanishing 並不行)。LSTM 的求導比較繁瑣,可以參考一下 LSTM Forward and Backward 。從繁瑣的公式中比較難看出 LSTM 解決梯度消失和爆炸,我們可以通過下面這張圖來直觀的感受一下 LSTM 的作用:

梯度的流動過程

左邊是 RNN,右邊是 LSTM;顏色的深淺表示了梯度影響的程度。可以看到,RNN 第一個時刻的受最後一個時刻的影響微乎其微,這種情況下就可以認為是出現了梯度消失,無法較好地更新我們的參數;右邊的 LSTM,為了簡化起見,我們將將 input gate 設為 0(圖中的 - 符號),forget gate 始終記憶前一狀態的信息(圖中的 o 符號),我們可以將第一個時刻的信息一直傳遞至我們想要的 4, 6,並且其梯度的也能夠通過這一條路徑成功的回傳。所以,通過控制輸入以及先前狀態的流動方式,LSTM 能夠較好地解決梯度消失的問題。

GRU 則在 LSTM 基礎之上做了簡化:

GRU,圖片引用自 colahblog

  1. 將 Forget Gate 和 Input Gate 合併成一個 Update Gate ,不像 LSTM 是由兩個獨立的門來控制
  2. 使用一個 Reset Gate 來直接控制 h_{t-1}h_t 的貢獻,而 LSTM 在計算 hat c_t 的時候是沒有一個 r_t 來控制 h_{t-1}

這樣做的好處很直觀地一點就是減少參數的數量,能夠加快訓練速度;另外 對於 的控制能夠讓 cell 更好的理清過去時刻狀態對現在狀態的影響程度,但 Update 門的不獨立性又使得它的效用有所下降。我猜測如果直接在 LSTM 的基礎之上加一個 Reset Gate,可能效果會更好,但參數的數量就上去了,所以這裡可能存在一個性能和速度的 tradeoff。

回到 IndRNN,其對於 h_{n,t} 的梯度計算如下(t 時刻 hidden state 的第 n 個單元):

IndRNN 的梯度計算公式

最大的差別就是矩陣連乘變成了一個數的冪次,這樣我們就可以通過控制這個矩陣中元素的大小來避免梯度消失和爆炸。當然,我們也可以選擇一個合適的數值範圍來讓讓梯度更好地流動,加快訓練的速度。

實驗的結果也證明,IndRNN 的長期記憶(也就是梯度能夠傳遞的時間步數)效果遠好於 RNN 和 LSTM(5000 vs 500~1000)

解釋性

因為 h_{t-1}h_t 各個單元之間是相互獨立的,那麼就可以提供了看待 IndRNN 的兩個權重矩陣的新的視角:

  1. W :負責提取輸入的空間特徵
  2. U : 提取時間上的特徵

而且,在這之後加一層全連接層,IndRNN 在一些條件的限制下( two constrains : 1. linear activation; diagonalizable weight matrix)能夠變成一個普通的 RNN 模型。

不過關於這一點,我覺得還需要更好地可視化的手段來幫助解釋,否則這樣的解釋依舊停留在 Intuition 上,不足以說明問題。

更深的 RNN

寫這篇文章的時候突然想到一個問題,為什麼 RNN 裡面都使用 tanh /sigmoid 而不是 ReLU,Leaky ReLU 這種現在的爆款標配 activation 呢?事實上是有的,Hinton 在 IRNN 這篇文章里就嘗試使用一些初始化的 Trick 和 ReLU activation 來解決梯度消失的問題,並且能夠取得近似 LSTM 的效果。

對於 LSTM 來說,其中有三個門,因為門的值需要在 [0, 1] 之間,所以選擇 sigmoid 函數,沒有問題;那麼 Cell State 的計算和最後的 Hidden State 為什麼不嘗試使用 ReLU 呢,應該也沒問題,求導還快,計算也方便,但是同時也就有兩個缺點,一個是不是以 0 為中心,這個在 CS231n 中對於各種激活函數中有講到,ReLU 雖然是 Non-saturated activation,但他的值域是大於零的 ;另外一點,就是如果使用 ReLU,那麼 LSTM 的輸出可能會很大。所以,這裡同樣存在著一個 trade off,考慮到 LSTM 提出的時間以及在各個任務上性能都還不錯,替換的需求不大,所以可能也就這麼用下來了。所以,使用 ReLU 或者 Leaky ReLU 是可行的。

既然如此,那麼 IndRNN 用上 ReLU 更加沒什麼問題,所以作者也提出能夠借鑒 CNN 中的一些方法,使用類似 ResNet 進行堆疊,得到更深的 RNN:

對於 ResNet 的了解不多,所以這裡先暫時擱置,有待後面補上。但看到 Batch Normalization 以及 ResNet,也是在提醒我們 CNN 和 RNN 相互借鑒非常重要。

References

IndRNN

RNN Tutorial-Toronto University

Understanding-LSTM

推薦閱讀:

機器學習中的數學基礎(簡介)
周明:如果用一個詞形容NLP圈的2017,我選「想像」| 人物對話
Model, model告訴我,她到底在想什麼?
從Kaggle賽題: Quora Question Pairs 看文本相似性/相關性
《Attention is all you need》

TAG:RNN | 自然語言處理 |