通往無限層神經網路 (2):一個富迭代性的微分方程,與幾個小實驗

在上一篇我們看了殘差結構在深層網路中的必要性:對於殘差網路(Residual Network)的一種理解方法,與深層網路的訓練 - 知乎專欄。

現在網路越來深了,那麼最深就是無限深。我們在此運用殘差結構,構造一個無限層的神經網路,或者說,一個"層的深度可以連續變化"的神經網路(不知道之前有沒有人做過這個,如果您有什麼想法,歡迎討論)。

還是拿上一次的例子,令網路只有 1 個輸入,1 個輸出。令網路的輸出為 f,層的深度是從 0 到 1 連續標記。

於是這裡核心的方程是:

f(0,x)=x,;;;frac{partial f(t,x)}{partial t} = N(t,,{w_t},, f(t,x))

這裡用 N(t,, {w_t},, f(t,x)) 代表 "第 t 層" 的操作,其中 {w_t} 是權重,f(t,x) 是之前的層給出的輸出。這個微分方程比較有趣,如果你把它展開會發現它具有很強的迭代性,因為神經網路具有迭代性,所以它的解會出現 e^t 這樣的項。

例子1:

由於 f(t,x) 的情況比較複雜,我們不妨先看更簡單的情況:

f(0,x)=0,;;;frac{partial f(t,x)}{partial t} = N(t,, {w_t},, x)

這裡的殘差結構不是用之前的輸出 f(t,x),而是直接用原始數據 x 作為輸入。這已經足以給出相當複雜的最終網路輸出。不妨看個最簡單的例子。令:

N(t,, {w_t},, x) = Max[0,, x + (w_1 t + w_2)]

即,它擁有 1 個權重固定為 1,而偏置隨著層的深度 t 的加深會線性變化的 ReLU 單元。於是我們積分一下就可以得到網路的輸出:

f(x) = f(1,x) = int_0^1 N(t,, {w_t},, x) ,dt = ?

答案是:

這個輸出就比較複雜,而且還出現了 x^2 這樣的高次項。

例子2:

如果我們把 w_tt 的關係取得更複雜,答案更是會非常複雜。舉例,如果加一個看似不起眼的 t^2 因子:

N(t,, {w_t},, x) = Max[0,, x + (t^2 + w_1 t + w_2)]

最終網路的輸出將是:

可以說它會很容易進入混沌狀態。畢竟它是個無限深的神經網路。

例子3:

再看殘差結構用之前的輸出 f(t,x) 的情況。例如,最簡單的例子是這樣:

f(0,x)=x,;;;frac{partial f(t,x)}{partial t} = Max(0, f(t,x))

這個解出來是:

f(t,x) = IF;;x < 0;;THEN;;x;;ELSE;;x cdot e^t

但是如果仔細思考,會發現這種殘差結構有點問題,容易丟失信息。舉例,如果在中間出現了x_1x_2 滿足 f(x_1,t)=f(x_2,t),那麼在後續它們就會被"鎖定",也就是對於任何 T>t 都有 f(x_1, T) = f(x_2, T)

這是否意味著我們也應該試試在中間層引入原始數據?其實我訓練圍棋策略網路時試過這個,但好像作用不大。對於圖像的問題,也許可以試試在中間層引入縮小到相應尺寸的原始圖像,遲點試試。

總結:

下一步的問題是如何訓練這種網路,自然的想法是 BP 是否也存在一個連續的形式。我們在後續看這個問題。

如果本文對你有啟發,請點個贊,謝謝!~ 如果您有什麼想法,也很歡迎討論。


推薦閱讀:

在機器學習時代,程序如何利用機器學習的原理反機器學習呢?
生成式對抗網路 NIPS 2016 課程 第 0~1 節
為什麼梯度下降能找到最小值?
什麼是強化學習?

TAG:深度学习DeepLearning | 人工智能 | 机器学习 |