深度學習中雜訊標籤的影響和識別

問題導入

在機器學習領域中,常見的一類工作是使用帶標籤數據訓練神經網路實現分類、回歸或其他目的,這種訓練模型學習規律的方法一般稱之為監督學習。在監督學習中,訓練數據所對應的標籤質量對於學習效果至關重要。如果學習時使用的標籤數據都是錯誤的,那麼不可能訓練出有效的預測模型。同時,深度學習使用的神經網路往往結構複雜,為了得到良好的學習效果,對於帶標籤的訓練數據的數量也有較高要求,即常被提到的大數據或海量數據。

矛盾在於:給數據打標籤這個工作在很多場景下需要人工實現,海量、高質量標籤本身費時費力,在經濟上相對昂貴。因此,實際應用中的機器學習問題必須面對噪音標籤的影響,即我們拿到的每一個帶標籤數據集都要假定其中是包含雜訊的。進一步,由於樣本量很大,對於每一個帶標籤數據集,我們不可能人工逐個檢查並校正標籤。

基於上述矛盾現狀,在實際工作中必須面對以下兩點問題

1. 訓練集帶標籤樣本中噪音達到什麼水平對於模型預測結果會有致命影響

2. 對於任意給定帶標籤訓練集,如何快速找出可能是噪音的樣本

本文接下來將圍繞這兩點通過實驗給出分析

數據、神經網路設計和代碼

本文以Tensorflow教程中提及的MNIST問題[1]為數據來源和問題定義。此問題為圖像識別問題,圖片為手寫的0-9字元,每個圖片格式為28*28灰度圖。訓練集數據包括55000張手寫數字和標籤,驗證集包括約10000張圖片和標籤。通過訓練神經網路從而實現當輸入一張驗證集中的圖片後,神經網路能夠正確預測這張圖片的標籤。

對於MNIST問題本身,Tensorflow教程[2]描述的包含2個卷積池化層的CNN網路已經足以實現99%左右的預測精度,因此在本實驗中,筆者直接引用Tensorflow官方樣例中的CNN網路[3]作為預測模型的神經網路。

本文所有代碼可以在筆者的Github項目中獲得:wangyaobupt/NoisyLabels

雜訊標籤對於分類器性能的影響

考慮到MNIST是機器學習領域使用多年的資料庫,且在其數據上訓練的模型已經得到了較好的結果,由此可以合理推斷其標籤本身的雜訊含量較低(這個推理將在下一個章節通過實驗證實)。因此,在這一節的實驗中,我們假定原始的MNIST的訓練集和驗證集標籤都是無雜訊的。

使用如下步驟給標籤添加雜訊

1. 根據給定的雜訊比例noiseLevel,從N個總樣本中選擇出K個樣本,K = N*noiseLevel

2. 對於選出的K個樣本中的每一個樣本,將其原始標籤替換為0-9之間扣除原始標籤之外的隨機數

上述演算法的代碼實現如下,testcase2.py提供了完整的可執行程序

# Add random noise to MNIST training set# input:# mnist_data: data structure that follow tensorflow MNIST demo# noise_level: a percentage from 0 to 1, indicate how many percentage of labels are wrongdef addRandomNoiseToTrainingSet(mnist_data, noise_level): # the data structure of labels refer to DataSet in tensorflow/tensorflow/contrib/learn/python/learn/datasets/mnist.py label_data_set = mnist_data.train.labels #print label_data_set.shape totalNum = label_data_set.shape[0] corruptedIdxList = randomSelectKFromN(int(noise_level*totalNum),totalNum) #print DEBUG: 1st elements in corruptedIdxList is: , corruptedIdxList[0], length = , len(corruptedIdxList) for cIdx in corruptedIdxList: #print "DEBUG: convert index = ", cIdx correctLabel = label_data_set[cIdx] #print DEBUG: Correct label = , correctLabel wrongLabel = convertCorrectLabelToCorruptedLabel(correctLabel) #print DEBUG: Wrong label = , wrongLabel label_data_set[cIdx] = wrongLabel# uniform randomly select K integers from range [0,N-1]def randomSelectKFromN(K, N): #print DEBUG: K = ,K, N = , N resultList =[] seqList = range(N) while (len(resultList) < K): index = (int)(np.random.rand(1)[0] * len(seqList)) #index = 0 # for DEBUG ONLY resultList.append(seqList[index]) seqList.remove(seqList[index]) #print resultList return resultList# Convert correct ont-hot vector label to a wrong label, the error pattern is randomly selected, i.e. not considering the content of imagedef convertCorrectLabelToCorruptedLabel(correctLabel): correct_value = np.argmax(correctLabel, 0) target_value = int(np.random.rand(1)[0]*10)%10 if target_value == correct_value: target_value = ((target_value+1) % 10) result = np.zeros(correctLabel.shape) result[target_value] = 1.0 return result

這樣,當給定雜訊水平之後,上述演算法完成添加雜訊,進一步用帶雜訊的訓練集訓練出模型,最終在驗證集上對模型評價精度。下圖是雜訊標籤比例在0-100%範圍內變化時,模型精度的變化。

從上圖可以看出,在雜訊標籤佔比不超過60%的情況下,驗證集精度保持在96%以上,即便雜訊標籤佔比達到70%,驗證集精度仍然能達到93%。在雜訊標籤佔比超過70%之後,精度結果快速下降,當雜訊佔比達到88%時,預測精度已經下降到7%。這個水平已經低於純隨機預測,考慮到此問題為10分類問題,在完全隨機的情況下,預期精度的數學期望也在10%左右。

這裡就引出了兩個問題:

1. 為什麼在噪音標籤佔比70%的情況下,模型抗雜訊性能這麼好?

2. 70%之後的快速下降又是由什麼導致的

為了回答上述問題,要重新審視此前加雜訊標籤的方法。在加雜訊的第一步,我們均勻的隨機抽取出一定比例的標籤,考慮到原始數據10類標籤的分布是基本均勻的,那麼抽出來的K個樣本中10類標籤的數量基本一致。在第二步,對於每個標籤,我們將正確標籤抹去,從正確標籤之外的9個字元中選擇一個作為標籤,由於選擇演算法本身也是隨機的,那麼,錯誤標籤是均勻分布在其他9類的。

上述解釋如果還不夠直觀,那麼可以看下圖。假設有1000條正確標籤為2的數據,在70%的雜訊條件下,只有300條數據標籤為2,其餘700條數據的標籤均勻分布在其他9類。這樣,正確標籤(300條『2』標籤)相比其他任何一個類別,仍然佔有明顯數量優勢,所以CNN才可以根據這個數量優勢學習到正確標籤2.

而當雜訊比例進一步增加後,數量對比優勢會逐漸弱化,例如下圖。這種情況下正確標籤雖然佔比仍然多於其他分類,但是數量上已經只有2倍的差異。在模型訓練中,正確標籤帶來的梯度下降增益不足以對抗錯誤標籤的影響,神經網路傾向於學習到隨機標籤。

由上述兩張圖可以看出,如果在多分類問題中噪音標籤是均勻分布的,同時正確標籤相對於每個類別的錯誤標籤有數倍的數量優勢,那麼訓練過程有可能承受較高的雜訊標籤水平得到相對精確的模型。但如果噪音標籤已經與正確標籤數量接近,那麼很難訓練出有意義的模型。

如何快速識別出疑似雜訊的標籤

在真實應用中,我們顯然不會人工在訓練數據集上添加雜訊。但如前文所述,訓練數據集本身是含有雜訊的,除了人工逐個審查,有沒有辦法快速找出疑似是雜訊的標籤呢?

為了解決這個問題,我們回到基於CNN網路的MNIST分類器最後一層來看。在分類器的最後一層,全連接網路包含10個神經元,輸出10個運算結果,可以看作一個10維向量。這個10維向量經過softmax運算可以轉為離散概率分布,其和為1,每個維度代表分類器預測當前圖片屬於某一類的概率。最終的預測結果就是取離散概率分布中概率值最高的一類作為預測結果。

在實驗中觀察不同樣本的概率分布,可以看到有以下兩種情況

  • 當一張圖片清晰且無歧義時,神經網路輸出的離散概率分布是集中在一個標籤的,例如正確標籤概率為0.999,其餘9種類別的概率接近於0.
  • 當一張圖片存在歧義時,神經網路輸出的離散概率分布就不會只集中在一個標籤,有可能最強的標籤概率只有0.6,第二強的標籤概率0.39,其餘8個類別概率為0 這樣的結果意味著神經網路認為這張標籤有二義性。

基於這個認識,就可以設計出一種方法,讓神經網路把自己認為存在二義性的樣本和標籤篩選出來,即實現了非人工快速找出疑似噪音標籤。

下面是二義性判斷的代碼實現,二義性在這裡定量的定義為:分類器認為最有可能類別的概率低於70%,同時第二可能類別概率高於15%。下列代碼是挑選二義性概率分布的實現,是simpleCNN.py的一部分,testcase3.py提供了篩選二義性樣本的可執行程序

# Filter out images with low SNR. # The term low SNR is defined as: in the probability distribution of this sample, the largest value is <= 0.7, while the 2nd largest value >= 0.15 # the raw images data (in shape of 1*784 vector), labels, and top 2 possibilities by CNN will be returned # Parameter: # train_or_test, 0 means train data, 1 means test data def filterLowSNRSamples(self, mnist, train_or_test=0): if train_or_test == 1: data = mnist.test else: data = mnist.train resultList = [] for sample_idx in range(data.images.shape[0]): prob_dist, label=self.sess.run([self.output_prob_distribution, self.label], feed_dict={ self.x: np.reshape(data.images[sample_idx], (1, 784)), self.y_: np.reshape(data.labels[sample_idx], (1,10)), self.keep_prob:1.0}) raw_prob_array = prob_dist[0] #search for position of the largest value and the 2nd largest value top_1_pos , top_2_pos = findPosOfLargestTwoElement(raw_prob_array, 10) #Low SNR criteria if raw_prob_array[top_1_pos] <= 0.7 and raw_prob_array[top_2_pos] >= 0.15: resultList.append((sample_idx, data.images[sample_idx], label, top_1_pos, top_2_pos)) if (sample_idx % 1000 == 0): print "DEBUG, current idx = %d, num_of_low_SNR = %d" % (sample_idx, len(resultList)) return resultList

使用這套方法,在MNIST的55000個訓練數據和標籤中篩選出408個疑似有二義性的圖片,下圖是部分典型圖片。由此來看,MNIST本身的標籤質量是較高的。下圖中不少標籤人工識別也存在困難,這恰恰說明了找出的標籤很大程度上就是「疑似雜訊標籤」

小結

本文對於MNIST數據集,使用CNN分類器,考察了雜訊對模型預測精度的影響,實驗結果表明,在均勻分布的隨機雜訊條件下,CNN模型可以在雜訊標籤佔比70%的情況下預測精度無明顯下降。進一步,為了識別原始訓練集中的疑似雜訊樣本,文中使用訓練好的CNN模型通過預測向量的概率分布,識別存在二義性的標籤,實現了低代價找出訓練集雜訊標籤的目的。

參考文獻

[1] tensorflow.org/get_star

[2] tensorflow.org/get_star

[3] TF MNIST code


推薦閱讀:

TAG:深度學習DeepLearning | 機器學習 | TensorFlow |