訓練的神經網路不工作?一文帶你跨過這37個坑

近日,Slav Ivanov 在 Medium 上發表了一篇題為《37 Reasons why your Neural Network is not working》的文章,從四個方面(數據集、數據歸一化/增強、實現、訓練),對自己長久以來的神經網路調試經驗做了 37 條總結,並穿插了不少出色的個人想法和思考,希望能幫助你跨過神經網路訓練中的 37 個大坑。機器之心對該文進行了編譯,原文鏈接請見文末。

神經網路已經持續訓練了 12 個小時。它看起來很好:梯度在變化,損失也在下降。但是預測結果出來了:全部都是零值,全部都是背景,什麼也檢測不到。我質問我的計算機:「我做錯了什麼?」,它卻無法回答。

如果你的模型正在輸出垃圾(比如預測所有輸出的平均值,或者它的精確度真的很低),那麼你從哪裡開始檢查呢?

無法訓練神經網路的原因有很多,因此通過總結諸多調試,作者發現有一些檢查是經常做的。這張列表匯總了作者的經驗以及最好的想法,希望也對讀者有所幫助。

I. 數據集問題

1. 檢查你的輸入數據

檢查饋送到網路的輸入數據是否正確。例如,我不止一次混淆了圖像的寬度和高度。有時,我錯誤地令輸入數據全部為零,或者一遍遍地使用同一批數據執行梯度下降。因此列印/顯示若干批量的輸入和目標輸出,並確保它們正確。

2. 嘗試隨機輸入

嘗試傳遞隨機數而不是真實數據,看看錯誤的產生方式是否相同。如果是,說明在某些時候你的網路把數據轉化為了垃圾。試著逐層調試,並查看出錯的地方。

3. 檢查數據載入器

你的數據也許很好,但是讀取輸入數據到網路的代碼可能有問題,所以我們應該在所有操作之前列印第一層的輸入並進行檢查。

4. 確保輸入與輸出相關聯

檢查少許輸入樣本是否有正確的標籤,同樣也確保 shuffling 輸入樣本同樣對輸出標籤有效。

5. 輸入與輸出之間的關係是否太隨機?

相較於隨機的部分(可以認為股票價格也是這種情況),輸入與輸出之間的非隨機部分也許太小,即輸入與輸出的關聯度太低。沒有一個統一的方法來檢測它,因為這要看數據的性質。

6. 數據集中是否有太多的噪音?

我曾經遇到過這種情況,當我從一個食品網站抓取一個圖像數據集時,錯誤標籤太多以至於網路無法學習。手動檢查一些輸入樣本並查看標籤是否大致正確。

7. Shuffle 數據集

如果你的數據集沒有被 shuffle,並且有特定的序列(按標籤排序),這可能給學習帶來不利影響。你可以 shuffle 數據集來避免它,並確保輸入和標籤都被重新排列。

8. 減少類別失衡

一張類別 B 圖像和 1000 張類別 A 圖像?如果是這種情況,那麼你也許需要平衡你的損失函數或者嘗試其他解決類別失衡的方法。

9. 你有足夠的訓練實例嗎?

如果你在從頭開始訓練一個網路(即不是調試),你很可能需要大量數據。對於圖像分類,每個類別你需要 1000 張圖像甚至更多。

10. 確保你採用的批量數據不是單一標籤

這可能發生在排序數據集中(即前 10000 個樣本屬於同一個分類)。可通過 shuffle 數據集輕鬆修復。

11. 縮減批量大小

巨大的批量大小會降低模型的泛化能力(參閱:https://arxiv.org/abs/1609.04836)

II. 數據歸一化/增強

12. 歸一化特徵

你的輸入已經歸一化到零均值和單位方差了嗎?

13. 你是否應用了過量的數據增強?

數據增強有正則化效果(regularizing effect)。過量的數據增強,加上其它形式的正則化(權重 L2,中途退出效應等)可能會導致網路欠擬合(underfit)。

14. 檢查你的預訓練模型的預處理過程

如果你正在使用一個已經預訓練過的模型,確保你現在正在使用的歸一化和預處理與之前訓練模型時的情況相同。例如,一個圖像像素應該在 [0, 1],[-1, 1] 或 [0, 255] 的範圍內嗎?

15. 檢查訓練、驗證、測試集的預處理

CS231n 指出了一個常見的陷阱:「任何預處理數據(例如數據均值)必須只在訓練數據上進行計算,然後再應用到驗證、測試數據中。例如計算均值,然後在整個數據集的每個圖像中都減去它,再把數據分發進訓練、驗證、測試集中,這是一個典型的錯誤。」此外,要在每一個樣本或批量(batch)中檢查不同的預處理。

III. 實現的問題

16. 試著解決某一問題的更簡易的版本。

這將會有助於找到問題的根源究竟在哪裡。例如,如果目標輸出是一個物體類別和坐標,那就試著把預測結果僅限制在物體類別當中(嘗試去掉坐標)。

17.「碰巧」尋找正確的損失

還是來源於 CS231n 的技巧:用小參數進行初始化,不使用正則化。例如,如果我們有 10 個類別,「碰巧」就意味著我們將會在 10% 的時間裡得到正確類別,Softmax 損失是正確類別的負 log 概率: -ln(0.1) = 2.302。然後,試著增加正則化的強度,這樣應該會增加損失。

18. 檢查你的損失函數

如果你執行的是你自己的損失函數,那麼就要檢查錯誤,並且添加單元測試。通常情況下,損失可能會有些不正確,並且損害網路的性能表現。

19. 核實損失輸入

如果你正在使用的是框架提供的損失函數,那麼要確保你傳遞給它的東西是它所期望的。例如,在 PyTorch 中,我會混淆 NLLLoss 和 CrossEntropyLoss,因為一個需要 softmax 輸入,而另一個不需要。

20. 調整損失權重

如果你的損失由幾個更小的損失函數組成,那麼確保它們每一個的相應幅值都是正確的。這可能會涉及到測試損失權重的不同組合。

21. 監控其它指標

有時損失並不是衡量你的網路是否被正確訓練的最佳預測器。如果可以的話,使用其它指標來幫助你,比如精度。

22. 測試任意的自定義層

你自己在網路中實現過任意層嗎?檢查並且複核以確保它們的運行符合預期。

23. 檢查「冷凍」層或變數

檢查你是否無意中阻止了一些層或變數的梯度更新,這些層或變數本來應該是可學的。

24. 擴大網路規模

可能你的網路的表現力不足以採集目標函數。試著加入更多的層,或在全連層中增加更多的隱藏單元。

25. 檢查隱維度誤差

如果你的輸入看上去像(k,H,W)= (64, 64, 64),那麼很容易錯過與錯誤維度相關的誤差。給輸入維度使用一些「奇怪」的數值(例如,每一個維度使用不同的質數),並且檢查它們是如何通過網路傳播的。

26. 探索梯度檢查(Gradient checking)

如果你手動實現梯度下降,梯度檢查會確保你的反向傳播(backpropagation)能像預期中一樣工作。

IV. 訓練問題

27. 一個真正小的數據集

過擬合數據的一個小子集,並確保其工作。例如,僅使用 1 或 2 個實例訓練,並查看你的網路是否學習了區分它們。然後再訓練每個分類的更多實例。

28. 檢查權重初始化

如果不確定,請使用 Xavier 或 He 初始化。同樣,初始化也許會給你帶來壞的局部最小值,因此嘗試不同的初始化,看看是否有效。

29. 改變你的超參數

或許你正在使用一個很糟糕的超參數集。如果可行,嘗試一下網格搜索。

30. 減少正則化

太多的正則化可致使網路嚴重地欠擬合。減少正則化,比如 dropout、批規範、權重/偏差 L2 正則化等。在優秀課程《編程人員的深度學習實戰》(Practical Deep Learning For Coders-18 hours of lessons for free)中,Jeremy Howard 建議首先解決欠擬合。這意味著你充分地過擬合數據,並且只有在那時處理過擬合。

31. 給它一些時間

也許你的網路需要更多的時間來訓練,在它能做出有意義的預測之前。如果你的損失在穩步下降,那就再多訓練一會兒。

32. 從訓練模式轉換為測試模式

一些框架的層很像批規範、Dropout,而其他的層在訓練和測試時表現並不同。轉換到適當的模式有助於網路更好地預測。

33. 可視化訓練

監督每一層的激活值、權重和更新。確保它們的大小匹配。例如,參數更新的大小(權重和偏差)應該是 1-e3。

考慮可視化庫,比如 Tensorboard 和 Crayon。緊要時你也可以列印權重/偏差/激活值。

尋找平均值遠大於 0 的層激活。嘗試批規範或者 ELUs。

Deeplearning4j 指出了權重和偏差柱狀圖中的期望值:對於權重,一些時間之後這些柱狀圖應該有一個近似高斯的(正常)分布。對於偏差,這些柱狀圖通常會從 0 開始,並經常以近似高斯(這種情況的一個例外是 LSTM)結束。留意那些向 +/- 無限發散的參數。留意那些變得很大的偏差。這有時可能發生在分類的輸出層,如果類別的分布不均勻。

檢查層更新,它們應該有一個高斯分布。

34. 嘗試不同的優化器

優化器的選擇不應當妨礙網路的訓練,除非你選擇了一個特別糟糕的參數。但是,為任務選擇一個合適的優化器非常有助於在最短的時間內獲得最多的訓練。描述你正在使用的演算法的論文應當指定優化器;如果沒有,我傾向於選擇 Adam 或者帶有動量的樸素 SGD。

35. 梯度爆炸、梯度消失

檢查隱蔽層的最新情況,過大的值可能代表梯度爆炸。這時,梯度截斷(Gradient clipping)可能會有所幫助。

檢查隱蔽層的激活值。Deeplearning4j 中有一個很好的指導方針:「一個好的激活值標準差大約在 0.5 到 2.0 之間。明顯超過這一範圍可能就代表著激活值消失或爆炸。」

36. 增加、減少學習速率

低學習速率將會導致你的模型收斂很慢;

高學習速率將會在開始階段減少你的損失,但是可能會導致你很難找到一個好的解決方案。

試著把你當前的學習速率乘以 0.1 或 10。

37. 克服 NaNs

據我所知,在訓練 RNNs 時得到 NaN(Non-a-Number)是一個很大的問題。一些解決它的方法:

減小學習速率,尤其是如果你在前 100 次迭代中就得到了 NaNs。

NaNs 的出現可能是由於用零作了除數,或用零或負數作了自然對數。

Russell Stewart 對如何處理 NaNs 很有心得(http://russellsstewart.com/notes/0.html)。

嘗試逐層評估你的網路,這樣就會看見 NaNs 到底出現在了哪裡。

資源:

CS231n Convolutional Neural Networks for Visual Recognition

russellsstewart.com/not

Neural network always predicts the same class

How to Visualize, Monitor and Debug Neural Network Learning

What does "debugging" a deep net look like? ? r/MachineLearning

Why the prediction or the output of neural network does not change during the test phase?

Neural Network predictions converging to one value

gab41.lab41.org/some-ti

How to debug an artificial neural network algorithm

選自Medium 機器之心編譯


推薦閱讀:

神經網路中,bias有什麼用,為什麼要設置bias,當加權和大於某值時,激活才有意義?
有關神經網路和遺傳演算法?
深度學習在自然語言處理中到底發揮了多大作用?有哪些不足或局限?

TAG:神经网络 | 机器学习 |