36計多1計!教你應對神經網路的"神經"問題~

該網路已經訓練了近12個小時。 一切看起來都很好:梯度流動,損失正在減少。 但預測時全部為零,全部為背景,什麼都沒有檢測到。 「我做錯了什麼?"——我問我的電腦,但它沒有回答。

當你的模型輸出垃圾時,你從哪裡開始檢查?(例如預測所有輸出的平均值,還是精度差?)網路可能由於多種原因而無法進行訓練。 在過程中有很多調試,而我經常會發現自己做同樣的檢查。在這個名單中我彙編了我的經驗和最好的想法,我希望這也對你有用。

很多事情可能會出錯。 但其中有些更有可能更常出錯。我通常從這簡短的列表開始,作為第一緊急反應:

  • 從一個已知可以用於這類數據的簡單模型開始(例如,圖像用的VGG)。 如果可能,使用標準的損失函數。
  • 省略所有不必要的東西,例如正則化和數據增加。
  • 如果在優化模型,請仔細檢查預處理,因為它應該是與原始模型的訓練相同。
  • 驗證輸入數據是否正確。
  • 從一個非常小的數據集(2-20個樣本)開始。 過度擬合它並逐漸添加更多數據。
  • 開始逐漸添加所有省略的部分:

數據增加/正規化,定製損失函數,嘗試更複雜的模型。

如果上述步驟不執行,請從下面的大表中一項一項驗證。

數據集問題

1.檢查數據輸入

檢查你輸入網路的數據是否正確。例如,我已經不止一次混合了圖像的寬度和高度。有時,我會錯誤地給所有的零。或者我會一次又一次地使用同一批次。因此,顯示幾批輸入和目標輸出,並確保它們正常。

2.嘗試隨機輸入

嘗試使用隨機數而不是實際的數據,看看錯誤的行為是否相同。如果相同,你的網路確實在某些時候將數據轉換成垃圾。嘗試逐層調試,查看出錯的地方。

3.檢查數據載入器

你的數據可能好,但輸入數據到網路的代碼可能壞了。在任何操作之前輸出第一層的輸入並進行檢查。

4.確保輸入是連接到輸出

檢查幾個輸入樣本是否有正確的標籤。並確保用相同方式混洗輸入樣本與輸出標籤。

5.輸入和輸出之間的關係是否太隨機?

也許輸入和輸出之間,非隨機部分與隨機部分相比太小(可以認為股票價格是這樣的),即輸入與輸出無足夠的關係。沒有一種一至的方式來檢測這樣的問題,因為它取決於數據的性質。

6.數據集中是否有太多噪音?

當我從食品網站上爬下一幅圖像數據時,這發生在我身上。網路無法學習的標籤太多了。手動檢查一堆輸入樣本,看看標籤是否有問題。截止點是可爭議的,如這論文使用50%損壞標籤,在MNIST上準確率達到50%以上。

7.重新排列數據集

如果你的數據集沒有被打亂並且具有特定的順序(按照標籤排序),這可能會對學習產生負面影響。隨機打亂混洗你的數據集來避免這種情況。確保將輸入和標籤混在一起。

8.減少類別不平衡

每個B類圖像都有1000個A類圖像嗎?那麼你可能需要平衡損失函數或嘗試其他類別不平衡的方法。

9.你有足夠的訓練案例嗎?

如果你從頭開始訓練一個網路,那麼你可能需要大量的數據。對於圖像分類,人們說每一類需要1000張圖片。

10.確保你的批次不包含單個標籤

這可能發生在排序數據集中(即前10k樣本包含相同的類別)。通過混洗數據集可以輕鬆修復。

11.減少批量大小

這篇論文指出,擁有非常大的批次可以降低模型的泛化能力。

附註1.使用標準數據集(例如mnist,cifar10)

當測試新的網路架構或編寫新的代碼時,先使用標準數據集,而不是自己的數據。 這是因為這些數據集有很多參考結果,它們被證明是「可解決的」。不存在標籤噪音,訓練/測試分配差異,數據集難度過大等問題。

數據歸一化/擴增問題

12.標準化功能

你的輸入是否標準化,具有零均值和零單位差異?

13.你有太多的數據擴容

數據擴容具有正規化效果。過多的數據擴容,加上其他形式的正規化(weight L2,dropout等),可能會導致網路欠擬合。

14.檢查預處理模型的預處理

如果你使用預先訓練的模型,請確保你正在使用與訓練時相同的歸一化和預處理。 例如,圖像像素應該在[0,1],[-1,1]或[0,255]的範圍內?

15.檢查訓練/驗證/測試集的預處理

CS231n指出了一個常見的陷阱:

「...任何預處理統計(例如數據均值)只能在訓練數據上計算,然後應用於驗證/測試數據。 例如計算平均值,並從整個數據集中的每個圖像中減去它,然後將數據分割成訓練/驗證/測試將是一個錯誤。「

另外,檢查每個樣本或批次中是否有不同的預處理。

執行問題

16.嘗試解決一個更簡單的版本

這將有助於找到問題的所在。 例如,如果目標輸出是對象類和坐標,請嘗試將預測限制為對象類。

17. 尋找正確的損失「偶然」

再次從優秀的CS231n得到下面的提示:從小參數開始,不需正則化。 例如我們有10個類別,」偶然」就意味著我們將在10%的機會內獲得正確的類別,並且Softmax損失是正確類的負對數概率,所以:-ln(0.1)= 2.302。 此後,嘗試提高正規化實力,增加損失。

18.檢查你的損失函數

如果你使用了自己的損失函數,請檢查它是否存在錯誤並添加單元測試。通常我的損失會稍微不正確,並以微妙的方式傷害網路的性能。

19.驗證損失輸入

如果你正在使用由你的框架提供的損失函數,請確保你傳給它正確的內容。例如,在PyTorch中,我會混合NLLLoss和CrossEntropyLoss,因為前者需要一個softmax輸入,而後者不用。

20.調整損失權重

如果你的損失由幾個較小的損失函數組成,請確保其相對於每個函數的幅度是相對正確的。這可能涉及測試不同的損失權重組合。

21.監控其他指標

有時損失不是你的網路是否正確被訓練的最佳預測因素。如果可以,請使用其他指標,如準確性。

22.測試任何自定義層

你自己在網路中實現了哪些層?檢查並仔細檢查以確保它們按預期運作。

23.檢查「凍結」層或變數

檢查是否無意中禁用應該可以學習的某些層/變數的梯度更新。

24.增加網路規模

也許網路的表現力不足以捕獲目標函數。嘗試在完全連接的層中添加更多層或更多的隱藏單元。

25.檢查隱藏的維度錯誤

如果你的輸入看起來像(k,H,W)=(64,64,64),很容易錯過與錯誤維度相關的錯誤。對輸入維度使用奇數(例如,每個維度使用不同素數),並檢查它們如何通過網路傳播。

26. 探索梯度檢查

如果你手動實施梯度下降法,則梯度檢查確保你的反向傳播能夠正常運作。 更多信息:1 2 3。

訓練問題

27.解決一個真正的小數據集

過度擬合一小部分數據,並確保它可以正常運作。 例如,訓練1或2個樣本,看看你的網路是否可以學習區分這些。每一類別用更多的樣本來進行。

28.檢查初始化權重

如果不確定,使用Xavier或He初始化。 此外,你的初始化可能會導致你的本地最小值不足,因此請嘗試其他初始化,看是否有幫助。

29.改變你的超級參數

也許你使用一套特別糟糕的超級參數。 如果可行,請嘗試網格搜索。

30.減少正規化

太多的正規化可能會導致網路的過度欠擬合。減少正規化,如dropout,批量規範,權重/偏倚L2正規化等。厲害的「編程人員實踐深層次學習」課程中,傑里米·霍華德(Jeremy Howard)建議首先排除欠擬合。 這意味著你充分訓練數據達到過擬合,屆時再解決這個問題。

31. 給它時間

也許你的網路需要更多的時間訓練才能開始有意義的預測。如果你的損失在穩定下降,那就讓它進行更多的訓練。

32. 從訓練切換到測試模式

某些框架的層,如「批次規範」,「dropout」和其他層,在訓練和測試期間的反應不一樣。切換到適當的模式可能會幫助你的網路正常進行預測。

33.可視化訓練

監控每個層的激活,權重和更新。確保他們的大小匹配。例如,參數更新的大小(權重和偏差)應為1-e3。

考慮像Tensorboard和Crayon這樣的可視化庫。必要時,你也可以輸出權重/偏差/激活。注意層次激活中那些均值大於,嘗試批量標準或ELU。

Deeplearning4j指出了權重和偏差直方圖應該要有的樣子:「對於權重,這些直方圖在一段時間後應該具有近似高斯(正態)的分布。對於偏差,這些直方圖通常從0開始,最終會達到近似高斯(LSTM是個例外)。注意分歧為+/-無限遠的參數。留意變得非常大的偏差。這有時會在分類的輸出層發生,如果類別的分布非常不平衡。檢查層更新,它們應該具有高斯分布。

34.嘗試一個不同的優化器

你選擇的優化器不應該阻止你的網路進行訓練,除非你選擇了特別糟糕的超級參數。然而,適當的優化器可以幫助你在最短的時間內獲得最多的訓練。你正在使用的演算法的論文應該會指定優化器。如果沒有,我傾向於使用Adam或plain SGD包含momentum。

35.爆發/消失的梯度

檢查層更新,因為非常大的值可能代表梯度的爆發。梯度剪裁可能有幫助。檢查層激活。Deeplearning4j提出了一個很好的指導方針:「激活的良好標準差為0.5到2.0。超出此範圍可能代表激活的消失或爆炸。」

36.增加/減少學習率

低學習率會導致模型收斂速度非常慢。高學習率一開始就會很快減少損失,但可能難以找到一個很好的解決方案。 試試將當前學習速率乘以0.1或10。

37. 克服NaNs

在訓練RNNs的時候(我聽說),發生NaN(Non-a-Number)是一個更大的問題。 一些解決方法:降低學習率,特別是如果在前100次迭代中獲得NaNs。 NaN可以由零或自然對數零或負數產生。

羅素·斯圖爾特(Russell Stewart)對如何處理NaNs有很大的指導。 嘗試逐層評估網路,並查看NaN出現的位置。

翻譯:Young Lin

編輯:小咪


推薦閱讀:

word embedding之GLOVE代碼
機器學習與數據挖掘網上資源
[貝葉斯六]之樸素貝葉斯分類器設計
機器學習基礎與實踐(二)----數據轉換
ML4-Brief Introduction of Deep Learning(李宏毅筆記)

TAG:學習方法 | 深度學習DeepLearning | 機器學習 |