深度學習任務面臨非平衡數據問題?試試這個簡單方法
來自專欄我是程序員
摘要:本文使用最簡單的過採樣方法解決駝背鯨魚分類任務,取得了很好的效果。
對於數據科學或機器學習研究者而言,當解決任何機器學習問題時,可能面臨的最大問題之一就是訓練數據不平衡的問題。本文將嘗試使用圖像分類問題來揭示訓練數據中不平衡類別的奧秘。
數據不平衡問題是什麼?
在一個分類問題中,當你想要預測一個或多個類中的樣本數量極少時,可能會遇到數據中類不平衡的問題,即部分類的樣本數量遠遠大於其它類中的樣本數量。
例子
- 欺詐預測(真實交易的欺詐數量要低得多);
- 自然災害預測(壞事件發生的頻率將遠遠低於好事);
- 識別圖像分類中的惡性腫瘤(具有腫瘤的圖像將比訓練樣本內的無腫瘤的圖像少得多);
為什麼這會是個問題?
不平衡課程造成問題主要是由於以下兩個原因:
- 由於模型/演算法從來沒有充分地查看全部類別信息,對於實時不平衡的類別沒有得到最優化的結果;
- 由於少數樣本類的觀察次數極少,這會產生一個驗證或測試樣本的問題,即很難在類中進行表示;
解決這個問題的方法有哪些?
解決這個問題的方法主要有三種,三種各有各自的優缺點:
- 下採樣(Undersampling):隨機刪除具有足夠觀察多樣本的類,以便數據中類的數量比較平衡。雖然這種方法非常簡單,但很有可能刪除的數據中可能包含有關預測的重要信息。
- 過採樣(Oversampling):對於不平衡類(樣本數少的類),隨機地增加觀測樣本的數量,這些觀測樣本只是現有樣本的副本,雖然增加了樣本的數量,但過採樣可能導致訓練數據過擬合。
- 合成取樣(SMOT):該技術要求綜合地製造不平衡類的樣本,類似於使用最近鄰分類。問題是當觀察的數目是極其罕見的類時不知道怎麼做。 儘管每種方法都有各自的優點,但沒有什麼固定的使用方式,需要根據實際問題不斷自己嘗試。現在將使用深度學習特定的圖像分類問題來詳細研究這個問題。
圖像分類中的不平衡類
在本節中,將分析一個圖像分類問題(其中存在不平衡類問題),然後使用一種簡單有效的技術來解決它。
問題:在kaggle上選擇了「駝背鯨識別挑戰」任務,期望解決不平衡類別的挑戰(理想情況下,所分類的鯨魚數量少於未分類的鯨類)。
Kagele上任務說明:在這場比賽中,面臨的挑戰是要建立一個演算法來識別圖像中的鯨魚種類。將分析Happy Whale資料庫(包含25,000多張圖像),這些數據來自研究機構和公共貢獻者。通過競賽,你將有助於為全球海洋哺乳動物種群動態開啟豐富的理解領域。查看Happy Whale數據集
由於這是一個多標籤圖像分類問題,首先想要檢查數據是如何在類中分布的。
上圖表明,在4251張訓練圖像中,每個類只有一張圖像的超過了2000張。還有一些類只有2~5張圖像。可見這是一個嚴重的不平衡類問題。我們不能期望深度學習模型每個類別僅使用一張圖像進行訓練。這也會產生一個問題,即如何在訓練和驗證樣本之間創建一個分界線,理想情況下希望每個類都在訓練樣本和驗證樣本中都有表示。
接下來應該做什麼?
本文考慮了兩個特別的選項:
- 選項1:對訓練樣本進行嚴格的數據增強(只需要針對特定類的數據增強,單這可能無法完全解決本文的問題)。
- 選項2:類似於之前提到的過採樣技術。只是使用不同的圖像增強技術將不平衡類的圖像複製到訓練數據中15次。 在開始使用選項2處理數據之前,可以從訓練樣本中查看少量圖像。
從圖像中可以看到,圖像是特定於鯨魚的尾巴,因此,識別將可能與圖像的方向有關。同時注意到數據中有很多圖像是特定的黑白或只有R/G/B通道。
根據這些觀察結果,使用以下代碼對訓練樣本中不平衡類的圖像進行小幅改動並保存:
import osfrom PIL import Imagefrom PIL import ImageFilterfilelist = train[Image].loc[(train[cnt_freq]<10)].tolist()for count in range(0,2): for imagefile in filelist: os.chdir(/home/paperspace/fastai/courses/dl1/data/humpback/train) im=Image.open(imagefile) im=im.convert("RGB") r,g,b=im.split() r=r.convert("RGB") g=g.convert("RGB") b=b.convert("RGB") im_blur=im.filter(ImageFilter.GaussianBlur) im_unsharp=im.filter(ImageFilter.UnsharpMask) os.chdir(/home/paperspace/fastai/courses/dl1/data/humpback/copy) r.save(str(count)+r_+imagefile) g.save(str(count)+g_+imagefile) b.save(str(count)+b_+imagefile) im_blur.save(str(count)+bl_+imagefile) im_unsharp.save(str(count)+un_+imagefile)
以上代碼對不平衡類中的每張圖像(頻率小於10)都進行如下處理:
- 將每張圖像的增強副本保存為R / B&G ;
- 保存每張圖像的增強副本;
- 保存每張圖像未銳化的增強副本; 在上面的代碼中可以看到,使用pillow庫來嚴格執行此練習,現在已經為所有不平衡的類分配了至少10個樣本。接下來進行訓練。
圖像增強:只想確保模型能夠獲得鯨魚fluke的詳細視圖。為此,將縮放合併成圖像增強。
學習率設定:從圖中可以看到,將學習率定為0.01時效果最好。
使用Resnet50模型(第一層參數不變)進行了很少的迭代訓練就能取得很好的效果,這是由於imagenet資料庫中也有鯨魚圖像。
epoch trn_loss val_loss accuracy 0 1.827677 0.492113 0.895976 1 0.93804 0.188566 0.964128 2 0.844708 0.175866 0.967555 3 0.571255 0.126632 0.977614 4 0.458565 0.116253 0.979991 5 0.410907 0.113607 0.980544 6 0.42319 0.109893 0.981097
測試數據集上效果如何?
在kaggle排行榜上可以看到模型在測試集上的效果,本文提出的解決方案在本次比賽中排名34,平均精度均值(MAP)為0.41928。
結論
有時候,最簡單的方法是最合乎邏輯的(如果你沒有更多的數據,只需要複製現有的數據,並有輕微的變化即可),也是最有效的。
數十款阿里雲產品限時折扣中,趕緊點擊領劵開始雲上實踐吧!
以上為譯文,由阿里云云棲社區組織翻譯。文章原標題《Deep Learning Tips and Tricks》譯者:海棠,審校:Uncle_LLD。
文章為簡譯,更為詳細的內容,請查看原文。更多技術乾貨敬請關注云棲社區知乎機構號:阿里云云棲社區 - 知乎
本文為雲棲社區原創內容,未經允許不得轉載。
推薦閱讀:
※2018年5月Top 10 機器學習開源項目
※RNN(循環神經網路)-2
※機器學習基石筆記4:機器學習可行性論證 上
※2.7 蒙特卡洛近似
※[貝葉斯四]之貝葉斯分類器設計
TAG:TensorFlow | 深度學習DeepLearning | 機器學習 |