RNN Part 3-Back Propagation Through Time and Vanishing Gradients(BPTT演算法和梯度消失)

習翔宇:RNN Part1-RNN介紹?

zhuanlan.zhihu.com圖標

中我們介紹了循環神經網路的基本結構,在

習翔宇:back propagation algorithm推導?

zhuanlan.zhihu.com圖標

中我們介紹了前饋神經網路的反向傳播演算法的推導,在本節我們將介紹

  1. RNN中的反向傳播演算法,並解釋為什麼跟傳統的反向傳播不同
  2. 梯度消失現象,導致了LSTM和GRU的發展。梯度消失問題最早由Sepp Hochreiter在1911年提出,並且由於深度學習的流行而重新受到了關注

1. BPTT演算法

在循環神經網路Part1-RNN介紹中,我們給出了RNN的基本方程

s_{t}=f(Ux_{t}+Ws_{t-1})

hat{y_{t}}=softmax(Vs_{t})

我們定義損失為cross entrophy loss(交叉熵損失),定義如下:

E_{t}(y_{t},hat{y_{t}})=-y_{t}loghat{y_{t}}

E(y,hat{y})=sum_{t}^{}{E_{t}(y_{t},hat{y_{t}})}=-sum_{t}^{}{y_{t}loghat{y_{t}}}

其中 hat{y_{t}} 是時刻t的正確word,而 y_{t}是我們的預測結果,將整個句子作為訓練樣本,因此total error是每個時刻點t的error的累加。

我們的目標是先計算出error的梯度,然後採用梯度下降演算法學習出好的參數。就跟我們將誤差累加一樣,我們也累加每個時刻點的梯度,例如 frac{partial E}{partial W}=sum_{t}^{}{frac{partial E_{t}}{partial W}}

為了計算梯度我們採用chain rule of differentiation ,我們採用E3作為例子

frac{partial E_{3}}{partial V}=frac{partial E_{3}}{partial hat{y_{3}}}frac{partial hat{y_{3}}}{partial V}

=frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial z_{3}} frac{partial z_{3}}{partial V}=(hat{y_{3}}-y_{3})otimes s_{3}

其中 otimes 表示兩個vector的outer product,這裡可以看出 frac{partial E_{3}}{partial V} 只取決於在時刻t=3的值包括 hat{y_{3}}、y_{3}、s_{3}

而對於W和U是不同的,我們根據鏈式法則 frac{partial E_{3}}{partial W}=frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial s_{3}} frac{partial s_{3}}{partial W}

其中 s_{3}=tanh(Ux_{3}+Ws_{2}) ,因此 s_{3} 依賴於  s_{2} ,而 s_{2} 依賴於Ws_{1} ,依此類推,因此我們不能將 s_{2} 當作一個常量,需要依次地使用鏈式法則,如下所示

frac{partial E_{3}}{partial W}=frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial s_{3}} frac{partial s_{3}}{partial W}=sum_{k=0}^{3}{frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial s_{3}} frac{partial s_{3}}{partial s_{k}}} frac{partial s_{k}}{partial W}

最後 們將每個時刻的梯度貢獻都累加了起來,換句話說,因為W在每一步都影響到了輸出,因此我們需要將梯度從時刻t=3通過網路反向傳播到時刻t=0

需要注意的是,這與我們在前饋神經網路中使用的標準反向傳播演算法完全相同,主要的差異在於我們將每時刻W的梯度相加,在傳統的神經網路中,我們在層之間並不共享參數,因此不需要相加。BPTT只是標準反向傳播在展開的循環神經網路上應用的一個名字而已。


2. The Vanishing Gradient Problem

在之前的Recurrent Neural Networks Part 2-Tensorflow實現RNN中我們提到了RNN難於學習到long-range dependencies between words that are several steps apart. 然而英文句子的意思通常是由距離不那麼近的單詞決定的。例如

「The man who wore a wig on his head went inside」.

這個句子是關於一個男人進去了,而不是關於wig的,但是一個plain RNN是很難獲取到這樣的信息的,為什麼呢?我們來看一下我們之前計算出來的梯度:

frac{partial E_{3}}{partial W}=frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial s_{3}} frac{partial s_{3}}{partial W}=sum_{k=0}^{3}{frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial s_{3}} frac{partial s_{3}}{partial s_{k}}} frac{partial s_{k}}{partial W}

這裡要注意 frac{partial s_{3}}{partial s_{k}} 本身也是一個chain rule,例如 frac{partial s_{3}}{partial s_{1}}=frac{partial s_{3}}{partial s_{2}} frac{partial s_{2}}{partial s_{1}} ,我們是對一個vector求另一個vector的導數,因此結果是Jacobian Matrix,我們將梯度進行重寫得到如下的式子

frac{partial E_{3}}{partial W}=sum_{k=0}^{3}{frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial s_{3}} frac{partial s_{3}}{partial s_{k}}} frac{partial s_{k}}{partial W}=sum_{k=0}^{3}{frac{partial E_{3}}{partial hat{y_{3}}} frac{partial hat{y_{3}}}{partial s_{3}} (prod_{j=k+1}^{3}frac{partial s_{j}}{partial s_{j-1}}) frac{partial s_{k}}{partial W}}

On the difficulty of training recurrent neural networks這篇論文證明了上述Jacobian矩陣的2-norm上界為1,你可以認為是一個絕值,這是非常直觀的因為tanh函數將所有的值映射到(-1, 1),其導數上界為1;sigmoid函數將所有的值映射到(0, 1), 其導數上界為 frac{1}{4} .

tanh激活函數和導數示意圖

可以看到tanh和sigmoid函數在兩端的導數接近0,它們接近一條平坦的直線,我們稱神經元是飽和的,它們梯度為0並且回將前一層中的其他梯度驅動到0。因此在矩陣和多個矩陣乘法(t-k)的值比較小的情況下,梯度值快速地指數收縮,最終在幾個時間步後完全消失。來自遙遠步驟的梯度貢獻會變為0,這些步驟的狀態不會影響正在學習的內容:你最終無法學習到long-range dependencies, 梯度消失不是RNN獨有的,在deep Feedforward Neural Network. 只不過RNN往往比較深,這使得梯度消失問題在RNN中比較顯著。

因此,雖然簡單循環網路從理論上可以建立長時間間隔的狀態之間的依賴關係(Long-Term Dependencies),但是由於梯度爆炸或消失問題,實際上只能學習到短周期的依賴關係。這就是所謂的長期依賴問題。

很容易想像,基於我們的激活函數和網路參數,可能會產生梯度爆炸如果Jacobian matrix的值比較大的話,這稱為exploding gradient problem.

Vanishing gradients比exploding gradient受到更多關注有兩個原因:

  1. exploding gradients是比較明顯的,梯度會變成NaN並且代碼會崩潰;
  2. 在預定義的閾值前對梯度進行修剪是一個簡單而有效的解決梯度爆炸問題的方案

Vanishing gradients問題更複雜因為它的出現並不明顯,而且不太清楚如何解決。

幸運的是有一些方法可以解決vanishing gradient problem。對於W矩陣適當的初始化可以減少vanishing gradient的效果,正則化也可以。

一個更佳的方法是採用ReLU替代tanh或者sigmoid函數,ReLU導數要麼是0,要麼是1,不會出現梯度消失問題。

另一個更佳流行的方法是採用Long Short-Term Memory(LSTM) 或者Gated Recurrent Unit(GRU) 架構。LSTM是1997年提出來的,並且是NLP中應用最廣泛的模型。

GRUs是2014年提出來的,是LSTM的簡化版本,兩種RNN架構都用來處理梯度消失問題並且俄能夠有效的學習long-range dependencies,我們將在後邊部分介紹。


推薦閱讀:

NMT 如何保證句子通順?
NAACL2018 | 傑出論文:RNN作為識別器,判定加權語言一致性
ENAS的原理和代碼解析
追根溯源:深度學習架構譜系
簡單的Char RNN生成文本

TAG:神經網路 | RNN | 演算法 |