36計多1計!教你應對神經網路的"神經"問題~
該網路已經訓練了近12個小時。 一切看起來都很好:梯度流動,損失正在減少。 但預測時全部為零,全部為背景,什麼都沒有檢測到。 「我做錯了什麼?"——我問我的電腦,但它沒有回答。
當你的模型輸出垃圾時,你從哪裡開始檢查?(例如預測所有輸出的平均值,還是精度差?)網路可能由於多種原因而無法進行訓練。 在過程中有很多調試,而我經常會發現自己做同樣的檢查。在這個名單中我彙編了我的經驗和最好的想法,我希望這也對你有用。
很多事情可能會出錯。 但其中有些更有可能更常出錯。我通常從這簡短的列表開始,作為第一緊急反應:
- 從一個已知可以用於這類數據的簡單模型開始(例如,圖像用的VGG)。 如果可能,使用標準的損失函數。
- 省略所有不必要的東西,例如正則化和數據增加。
- 如果在優化模型,請仔細檢查預處理,因為它應該是與原始模型的訓練相同。
- 驗證輸入數據是否正確。
- 從一個非常小的數據集(2-20個樣本)開始。 過度擬合它並逐漸添加更多數據。
- 開始逐漸添加所有省略的部分:
數據增加/正規化,定製損失函數,嘗試更複雜的模型。
如果上述步驟不執行,請從下面的大表中一項一項驗證。
數據集問題
1.檢查數據輸入
檢查你輸入網路的數據是否正確。例如,我已經不止一次混合了圖像的寬度和高度。有時,我會錯誤地給所有的零。或者我會一次又一次地使用同一批次。因此,顯示幾批輸入和目標輸出,並確保它們正常。
2.嘗試隨機輸入
嘗試使用隨機數而不是實際的數據,看看錯誤的行為是否相同。如果相同,你的網路確實在某些時候將數據轉換成垃圾。嘗試逐層調試,查看出錯的地方。
3.檢查數據載入器
你的數據可能好,但輸入數據到網路的代碼可能壞了。在任何操作之前輸出第一層的輸入並進行檢查。
4.確保輸入是連接到輸出
檢查幾個輸入樣本是否有正確的標籤。並確保用相同方式混洗輸入樣本與輸出標籤。
5.輸入和輸出之間的關係是否太隨機?
也許輸入和輸出之間,非隨機部分與隨機部分相比太小(可以認為股票價格是這樣的),即輸入與輸出無足夠的關係。沒有一種一至的方式來檢測這樣的問題,因為它取決於數據的性質。
6.數據集中是否有太多噪音?
當我從食品網站上爬下一幅圖像數據時,這發生在我身上。網路無法學習的標籤太多了。手動檢查一堆輸入樣本,看看標籤是否有問題。截止點是可爭議的,如這論文使用50%損壞標籤,在MNIST上準確率達到50%以上。
7.重新排列數據集
如果你的數據集沒有被打亂並且具有特定的順序(按照標籤排序),這可能會對學習產生負面影響。隨機打亂混洗你的數據集來避免這種情況。確保將輸入和標籤混在一起。
8.減少類別不平衡
每個B類圖像都有1000個A類圖像嗎?那麼你可能需要平衡損失函數或嘗試其他類別不平衡的方法。
9.你有足夠的訓練案例嗎?
如果你從頭開始訓練一個網路,那麼你可能需要大量的數據。對於圖像分類,人們說每一類需要1000張圖片。
10.確保你的批次不包含單個標籤
這可能發生在排序數據集中(即前10k樣本包含相同的類別)。通過混洗數據集可以輕鬆修復。
11.減少批量大小
這篇論文指出,擁有非常大的批次可以降低模型的泛化能力。
附註1.使用標準數據集(例如mnist,cifar10)
當測試新的網路架構或編寫新的代碼時,先使用標準數據集,而不是自己的數據。 這是因為這些數據集有很多參考結果,它們被證明是「可解決的」。不存在標籤噪音,訓練/測試分配差異,數據集難度過大等問題。
數據歸一化/擴增問題
12.標準化功能
你的輸入是否標準化,具有零均值和零單位差異?
13.你有太多的數據擴容
數據擴容具有正規化效果。過多的數據擴容,加上其他形式的正規化(weight L2,dropout等),可能會導致網路欠擬合。
14.檢查預處理模型的預處理
如果你使用預先訓練的模型,請確保你正在使用與訓練時相同的歸一化和預處理。 例如,圖像像素應該在[0,1],[-1,1]或[0,255]的範圍內?
15.檢查訓練/驗證/測試集的預處理
CS231n指出了一個常見的陷阱:
「...任何預處理統計(例如數據均值)只能在訓練數據上計算,然後應用於驗證/測試數據。 例如計算平均值,並從整個數據集中的每個圖像中減去它,然後將數據分割成訓練/驗證/測試將是一個錯誤。「
另外,檢查每個樣本或批次中是否有不同的預處理。
執行問題
16.嘗試解決一個更簡單的版本
這將有助於找到問題的所在。 例如,如果目標輸出是對象類和坐標,請嘗試將預測限制為對象類。
17. 尋找正確的損失「偶然」
再次從優秀的CS231n得到下面的提示:從小參數開始,不需正則化。 例如我們有10個類別,」偶然」就意味著我們將在10%的機會內獲得正確的類別,並且Softmax損失是正確類的負對數概率,所以:-ln(0.1)= 2.302。 此後,嘗試提高正規化實力,增加損失。
18.檢查你的損失函數
如果你使用了自己的損失函數,請檢查它是否存在錯誤並添加單元測試。通常我的損失會稍微不正確,並以微妙的方式傷害網路的性能。
19.驗證損失輸入
如果你正在使用由你的框架提供的損失函數,請確保你傳給它正確的內容。例如,在PyTorch中,我會混合NLLLoss和CrossEntropyLoss,因為前者需要一個softmax輸入,而後者不用。
20.調整損失權重
如果你的損失由幾個較小的損失函數組成,請確保其相對於每個函數的幅度是相對正確的。這可能涉及測試不同的損失權重組合。
21.監控其他指標
有時損失不是你的網路是否正確被訓練的最佳預測因素。如果可以,請使用其他指標,如準確性。
22.測試任何自定義層
你自己在網路中實現了哪些層?檢查並仔細檢查以確保它們按預期運作。
23.檢查「凍結」層或變數
檢查是否無意中禁用應該可以學習的某些層/變數的梯度更新。
24.增加網路規模
也許網路的表現力不足以捕獲目標函數。嘗試在完全連接的層中添加更多層或更多的隱藏單元。
25.檢查隱藏的維度錯誤
如果你的輸入看起來像(k,H,W)=(64,64,64),很容易錯過與錯誤維度相關的錯誤。對輸入維度使用奇數(例如,每個維度使用不同素數),並檢查它們如何通過網路傳播。
26. 探索梯度檢查
如果你手動實施梯度下降法,則梯度檢查確保你的反向傳播能夠正常運作。 更多信息:1 2 3。
訓練問題
27.解決一個真正的小數據集
過度擬合一小部分數據,並確保它可以正常運作。 例如,訓練1或2個樣本,看看你的網路是否可以學習區分這些。每一類別用更多的樣本來進行。
28.檢查初始化權重
如果不確定,使用Xavier或He初始化。 此外,你的初始化可能會導致你的本地最小值不足,因此請嘗試其他初始化,看是否有幫助。
29.改變你的超級參數
也許你使用一套特別糟糕的超級參數。 如果可行,請嘗試網格搜索。
30.減少正規化
太多的正規化可能會導致網路的過度欠擬合。減少正規化,如dropout,批量規範,權重/偏倚L2正規化等。厲害的「編程人員實踐深層次學習」課程中,傑里米·霍華德(Jeremy Howard)建議首先排除欠擬合。 這意味著你充分訓練數據達到過擬合,屆時再解決這個問題。
31. 給它時間
也許你的網路需要更多的時間訓練才能開始有意義的預測。如果你的損失在穩定下降,那就讓它進行更多的訓練。
32. 從訓練切換到測試模式
某些框架的層,如「批次規範」,「dropout」和其他層,在訓練和測試期間的反應不一樣。切換到適當的模式可能會幫助你的網路正常進行預測。
33.可視化訓練
監控每個層的激活,權重和更新。確保他們的大小匹配。例如,參數更新的大小(權重和偏差)應為1-e3。
考慮像Tensorboard和Crayon這樣的可視化庫。必要時,你也可以輸出權重/偏差/激活。注意層次激活中那些均值大於,嘗試批量標準或ELU。
Deeplearning4j指出了權重和偏差直方圖應該要有的樣子:「對於權重,這些直方圖在一段時間後應該具有近似高斯(正態)的分布。對於偏差,這些直方圖通常從0開始,最終會達到近似高斯(LSTM是個例外)。注意分歧為+/-無限遠的參數。留意變得非常大的偏差。這有時會在分類的輸出層發生,如果類別的分布非常不平衡。檢查層更新,它們應該具有高斯分布。
34.嘗試一個不同的優化器
你選擇的優化器不應該阻止你的網路進行訓練,除非你選擇了特別糟糕的超級參數。然而,適當的優化器可以幫助你在最短的時間內獲得最多的訓練。你正在使用的演算法的論文應該會指定優化器。如果沒有,我傾向於使用Adam或plain SGD包含momentum。
35.爆發/消失的梯度
檢查層更新,因為非常大的值可能代表梯度的爆發。梯度剪裁可能有幫助。檢查層激活。Deeplearning4j提出了一個很好的指導方針:「激活的良好標準差為0.5到2.0。超出此範圍可能代表激活的消失或爆炸。」
36.增加/減少學習率
低學習率會導致模型收斂速度非常慢。高學習率一開始就會很快減少損失,但可能難以找到一個很好的解決方案。 試試將當前學習速率乘以0.1或10。
37. 克服NaNs
在訓練RNNs的時候(我聽說),發生NaN(Non-a-Number)是一個更大的問題。 一些解決方法:降低學習率,特別是如果在前100次迭代中獲得NaNs。 NaN可以由零或自然對數零或負數產生。
羅素·斯圖爾特(Russell Stewart)對如何處理NaNs有很大的指導。 嘗試逐層評估網路,並查看NaN出現的位置。
翻譯:Young Lin
編輯:小咪
推薦閱讀:
※word embedding之GLOVE代碼
※機器學習與數據挖掘網上資源
※[貝葉斯六]之樸素貝葉斯分類器設計
※機器學習基礎與實踐(二)----數據轉換
※ML4-Brief Introduction of Deep Learning(李宏毅筆記)
TAG:學習方法 | 深度學習DeepLearning | 機器學習 |