數學 · RNN(二)· BPTT 演算法

(這裡是本章會用到的 GitHub 地址)

(感謝評論區 @陌燭 指出本文的諸多錯誤!!真的非常感謝!!【拜】)

RNN 的「前向傳導演算法」

在說明如何進行訓練之前,我們先來回顧一下 RNN 的「前向傳導演算法。在上一章中曾經給過一個沒有激活函數和變換函數的公式:

begin{align} o_{1} &= Vs_{1} = Vleft( Ux_{1} right) = x_{1}  o_{2} &= Vs_{2} = Vleft( Ws_{1} + Ux_{2} right) = 2s_{1} + x_{2}  ldots  o_{t} &= Vs_{t} = Vleft( Ws_{t - 1} + Ux_{t} right) = 2s_{t - 1} + x_{t} end{align}

在實現層面來說,這就是一個循環的事兒,所以代碼寫起來會比較簡單:

import numpy as npnnclass RNN1:n def __init__(self, u, v, w):n self._u, self._v, self._w = np.asarray(u), np.asarray(v), np.asarray(w)n self._states = Nonenn # 激活函數n def activate(self, x):n return xnn # 變換函數n def transform(self, x):n return xnn def run(self, x):n output = []n x = np.atleast_2d(x)n # 初始化 States 矩陣為零矩陣n # 之所以把所有 States 記下來、是因為訓練時(BPTT 演算法)要用到n self._states = np.zeros([len(x)+1, self._u.shape[0]])n for t, xt in enumerate(x):n # 對著公式敲代碼即可 ( σω)σn self._states[t] = self.activate(n self._u.dot(xt) + self._w.dot(self._states[t-1])n )n output.append(self.transform(n self._v.dot(self._states[t]))n )n return np.array(output)n

可以用上一章說過的那個小栗子來測試一下:

  • 假設現在U,V是單位陣,W是單位陣的兩倍

  • 假設輸入序列為:left( 1,0,0,ldots,0 right)^{T} rightarrow left( 0,1,0,ldots,0 right)^{T} rightarrow left( 0,0,1,ldots,0 right)^{T} rightarrow ldots rightarrow left( 0,0,0,ldots,1 right)^{T}

對應的測試代碼如下:

n_sample = 5nrnn = RNN1(np.eye(n_sample), np.eye(n_sample), np.eye(n_sample) * 2)nprint(rnn.run(np.eye(n_sample)))n

程序輸出為:

這和我們上一章推出的理論值left( 1,0,0,ldots,0 right)^{T} rightarrow left( 2,1,0,ldots,0 right)^{T} rightarrow left( 4,2,1,ldots,0 right)^{T} rightarrow ldots rightarrow left( 2^{n - 1},2^{n - 2},2^{n - 3},ldots,1 right)^{T}是一致的(n=5

RNN 的「反向傳播演算法」

簡潔起見,我們採用上一章第一張圖所示的那個樸素網路結構:

然後做出如下符號約定:

  • phi作為隱藏層的激活函數
  • varphi作為輸出層的變換函數
  • L_{t} = L_{t}left( o_{t},y_{t} right)作為模型的損失函數,其中標籤y_{t}是一個 one-hot 向量;由於 RNN 處理的通常是序列數據、所以在接受完序列中所有樣本後再統一計算損失是合理的,此時模型的總損失可以表示為(假設輸入序列長度為n):L = sum_{t = 1}^{n}L_{t}

為了更清晰地表明各個配置,我們可以整理出如下圖所示的結構:

易知o_{t} = varphileft( text{Vs}_{t} right) = varphileft( text{V?}left( Ux_{t} + Ws_{t - 1} right) right),其中s_{0} = mathbf{0 =}left( 0,0,ldots,0 right)^{T}。令:

o_{t}^{*} = text{Vs}_{t},  s_{t}^{*} = Ux_{t} + Ws_{t - 1}

則有:

o_{t} = varphileft( o_{t}^{*} right),  s_{t} = phi(s_{t}^{*})

從而(註:統一使用「*」表示 element wise 乘法,使用「times」表示矩陣乘法):

frac{partial L_{t}}{partial o_{t}^{*}} = frac{partial L_{t}}{partial o_{t}}*frac{partial o_{t}}{partial o_{t}^{*}} = frac{partial L_{t}}{partial o_{t}}*varphi^{}left( o_{t}^{*} right)

frac{partial L_{t}}{partial V} = frac{partial L_{t}}{partial Vs_{t}} times frac{partial Vs_{t}}{partial V} = left( frac{partial L_{t}}{partial o_{t}}*varphi^{}left( o_{t}^{*} right) right) times s_{t}^{T}

可見對矩陣V的分析過程即為普通的反向傳播演算法,相對而言比較平凡。由L = sum_{t = 1}^{n}L_{t}可知,它的總梯度可以表示為:

frac{partial L}{partial V} = sum_{t = 1}^{n}{left( frac{partial L_{t}}{partial o_{t}}*varphi^{}left( o_{t}^{*} right) right) times s_{t}^{T}}

而事實上,RNN 的 BP 演算法的主要難點在於它 State 之間的通信,亦即梯度除了按照空間結構傳播(o_{t} rightarrow s_{t} rightarrow x_{t})以外,還得沿著時間通道傳播(s_{t} rightarrow s_{t - 1} rightarrow ldots rightarrow s_{1}),這導致我們比較難將相應 RNN 的 BP 演算法寫成一個統一的形式(回想之前的「前向傳導演算法」)。為此,我們可以採用「循環」的方法來計算各個梯度

由於是反向傳播演算法,所以t應從n開始降序循環至 1,在此期間(若需要初始化、則初始化為 0 向量或 0 矩陣):

  • 計算時間通道上的「局部梯度」 :

    begin{align} frac{partial L_{t}}{partial s_{t}^{*}} &= frac{partial s_{t}}{partial s_{t}^{*}}*left( frac{partial s_{t}^{T}V^{T}}{partial s_{t}} times frac{partial L_{t}}{partial Vs_{t}} right) = phi(s_t^*)*left[V^{T} times left( frac{partial L_{t}}{partial o_{t}}*varphi^{}left( o_{t}^{*} right) right)right]  frac{partial L_{t}}{partial s_{k - 1}^{*}} &= frac{partial s_{k}^{*}}{partial s_{k - 1}^{*}} times frac{partial L_{t}}{partial s_{k}^{*}} = phi^{}left( s_{k - 1}^{*} right) * left( W^{T} times frac{partial L_{t}}{partial s_{k}^{*}} right),  (k = 1,ldots,t) end{align}
  • 利用時間通道上的「局部梯度」計算UW的梯度:

    begin{align} frac{partial L_{t}}{partial U} &= sum_{k = 1}^{t}{frac{partial L_{t}}{partial s_{k}^{*}} times frac{partial s_{k}^{*}}{partial U}} = sum_{k = 1}^{t}{frac{partial L_{t}}{partial s_{k}^{*}} times x_{k}^{T}}  frac{partial L_{t}}{partial W} &= sum_{k = 1}^{t}{frac{partial L_{t}}{partial s_{k}^{*}} times frac{partial s_{k}^{*}}{partial W}} = sum_{k = 1}^{t}{frac{partial L_{t}}{partial s_{k}^{*}} times s_{k - 1}^{T}} end{align}

以上即為 RNN 反向傳播演算法的所有推導,它比 NN 的 BP 演算法要繁複不少。事實上,像這種需要把梯度沿時間通道傳播的 BP 演算法是有一個專門的名詞來描述的——Back Propagation Through Time(常簡稱為 BPTT,可譯為「時序反向傳播演算法」)

不妨舉一個具體的栗子來加深理解,假設:

  • 激活函數phi為 Sigmoid 函數
  • 變換函數varphi為 Softmax 函數
  • 損失函數L_{t}為 Cross Entropy(感謝評論區 @格子非 指出這裡的錯誤):L_{t}left( o_{t},y_{t} right) = -left[y_{t}log o_{t}+(1-y_t)log(1-o_t)right]

由 NN 處的討論可知這是一個非常經典、有效的配置,其中:

frac{partial L_{t}}{partial o_{t}}*varphi^{}left( o_{t}^{*} right) = o_{t} - y_{t}

phi^{}left( s_{t}^{*} right) = phileft( s_{t}^{*} right)*left( 1 - phileft( s_{t}^{*} right) right) = s_{t}*(1 - s_{t})

從而

frac{partial L}{partial V} = sum_{t = 1}^{n}{left( o_{t} - y_{t} right) times s_{t}^{T}}

tn開始降序循環至 1 的期間中,各個「局部梯度」為:

begin{align} frac{partial L_{t}}{partial s_{t}^{*}} &= V^{T} times left( frac{partial L_{t}}{partial o_{t}}*varphi^{}left( o_{t}^{*} right) right) = left[s_t*(1-s_t)right]*left[ V^{T} times (o_{t} - y_{t})right]  frac{partial L_{t}}{partial s_{k - 1}^{*}} &= W^{T} times left( frac{partial L_{t}}{partial s_{k}^{*}}*phi^{}left( s_{k - 1}^{*} right) right) = [s_{k - 1}*left( 1 - s_{k - 1} right)] * left(W^{T} times frac{partial L_{t}}{partial s_{k}^{*}} right),  (k = 1,ldots,t) end{align}

由此可算出如下相應梯度:

begin{align} frac{partial L_{t}}{partial U} &= sum_{k = 1}^{t}{frac{partial L_{t}}{partial s_{k}^{*}} times x_{k}^{T}}  frac{partial L_{t}}{partial W} &= sum_{k = 1}^{t}{frac{partial L_{t}}{partial s_{k}^{*}} times s_{k - 1}^{T}} end{align}

可以看到形式相當簡潔,所以我們完全可以比較輕易地寫出相應實現:

class RNN2(RNN1):n # 定義 Sigmoid 激活函數n def activate(self, x):n return 1 / (1 + np.exp(-x))nn # 定義 Softmax 變換函數n def transform(self, x):n safe_exp = np.exp(x - np.max(x))n return safe_exp / np.sum(safe_exp)nn def bptt(self, x, y):n x, y, n = np.asarray(x), np.asarray(y), len(y)n # 獲得各個輸出,同時計算好各個 Staten o = self.run(x)n # 照著公式敲即可 ( σω)σn dis = o - yn dv = dis.T.dot(self._states[:-1])n du = np.zeros_like(self._u)n dw = np.zeros_like(self._w)n for t in range(n-1, -1, -1):n st = self._states[t]n ds = self._v.T.dot(dis[t]) * st * (1 - st)n # 這裡額外設定了最多往回看 10 步n for bptt_step in range(t, max(-1, t-10), -1):n du += np.outer(ds, x[bptt_step])n dw += np.outer(ds, self._states[bptt_step-1])n st = self._states[bptt_step-1]n ds = self._w.T.dot(ds) * st * (1 - st)n return du, dv, dwnn def loss(self, x, y):n o = self.run(x)n return np.sum(n -y * np.log(np.maximum(o, 1e-12)) -n (1 - y) * np.log(np.maximum(1 - o, 1e-12))n )n

注意我們設定了在每次沿時間通道反向傳播時、最多往回看 10 步,這是因為我們實現的這種樸素 RNN 的梯度存在著一些不良性質,我們在下一節中馬上就會進行相關的說明

指數級梯度所帶來的問題

注意到由於 RNN 需要沿時間通道進行反向傳播,其相應的「局部梯度」為:

frac{partial L_{t}}{partial s_{k - 1}^{*}} = [s_{k - 1}*left( 1 - s_{k - 1} right)] * left(W^{T} times frac{partial L_{t}}{partial s_{k}^{*}} right)

注意到式中的每個局部梯度frac{partial L_{t}}{partial s_{k}^{*}}都會「攜帶」一個W矩陣和一個s_{k}的 Sigmoid 系激活函數所對應的梯度s_{k}*left( 1 - s_{k} right),這意味著局部梯度受W和各個激活函數的梯度的影響是指數級的。姑且不考慮W而單看激活函數的梯度,回憶我們之前在 NN 處講過的梯度問題,這裡的這種指數級梯度的表現和彼時深層網路梯度的表現是幾乎同理的(事實上 RNN 的時間通道長得確實很像一個深層網路)——當輸入趨近於兩端時,激活函數的梯度會隨著傳播而迅速彌散,這就是 RNN 中所謂的「梯度消失(The Vanishing Gradient)」問題。是故我們在上一小節實現 RNN 時規定在沿時間通道反向傳播時最多只往回看 10 步,這是因為再往下看也沒有太大意義了(可以大概地類比於多於 10 層的、以 Sigmoid 系函數作為激活函數的神經網路)(以下純屬開腦洞:這麼說的話是不是能在時間通道裡面傳遞殘差然後弄一個 Residual RNN 呢……)

這當然是非常令人沮喪的結果,要知道 RNN 的一大好處就在於它能利用上歷史的信息,然而梯度消失卻告訴我們 RNN 能夠利用的歷史信息十分有限。所以針對該問題作出優化是非常有必要的,解決方案大體上分兩種:

  • 選用更好的激活函數
  • 改進 State 的傳遞方式

第二點是 LSTMs 等特殊 RNN 的做法,這裡就主要說說第一點——如何選用更好的激活函數。由 NN、CNN 處的討論不難想到,用 ReLU 作為激活函數很有可能是個不錯的選擇;不過由於梯度是指數級的這一點不會改變,此時我們可能就會面臨另一個問題:「梯度爆炸(The Exploding Gradient)」(註:不是說 Sigmoid 系函數就不會引發梯度爆炸、因為當矩陣W的元素很大時同樣會爆炸,只是相對而言更容易引發梯度消失而已)。不過相比起梯度消失問題來講,梯度爆炸相對而言要顯得更「友好」一些,這是因為:

  • 梯度爆炸一旦發生,是會迅速反映到結果上來的(比如一堆數變成了 NaN)
  • 梯度爆炸可以通過簡單的設定閾值來得到改善

而梯度消失相比之下,既難以直接從結果看出、又沒有特別平凡的解決方案。現有的較常用的方法為調整參數的初值、進行適當的正則化、使用 ReLU(需要小心梯度爆炸)等等

關於為何 LSTMs 能夠解決梯度消失,直觀上來說就是上方時間通道是簡單的線性組合、從而使得梯度不再是指數級的。詳細的推導可以參見各種論文(比如說這篇),我就不在這裡獻醜了 ( σω)σ

以上就大致地說了說 RNN 的 BPTT 演算法,主要要注意的其實就是時間通道上的 BP 演算法。如果把時間通道看成一個神經網路的話,運用局部梯度來反向傳播其實相當自然

希望觀眾老爺們能夠喜歡~


推薦閱讀:

天賦在以下層級的數學學習中所佔比重各是多少?
勒貝格測度初涉(二)
歐拉到底有多厲害?
素數的倒數和是否發散?
一家人坐飛機,是不是應該分乘不同的航班,以防萬一出事,對家庭造成不可彌補的損失?

TAG:数学 | 机器学习 | 神经网络 |