可視化展示神經網路是如何將分類正確率提升的

這篇不會涉及數學推導和模型原理,僅從可視化的角度形象說明神經網路到底如何提升分類正確率

問題背景

N個64維向量,其標籤分為兩類,記為類別1和類別2,類別1以紅色表示,類別2以藍色表示。

使用PCA技術,將64維向量的3個最強主成分抽取出來,繪圖得到如下結果

從中可以看到,

  1. 除了一個維度的坐標範圍是[4,8.5],其餘兩個維度坐標範圍都在1e-15及以下,說明原始數據的主成分集中在一維
  2. 紅色和藍色混雜分布,考慮到上一點提到的一維分布,不易區分

其中第一點可以進一步圖形化說明,將PCA降維之後強度最強的一個維度的分布畫出來,如下所示。這裡為了避免紅色和藍色點重疊,將紅色點的縱坐標向上移動一個微小距離,將藍色點的縱坐標向下移動一個微小距離。

從上面的一維分布圖來看,紅色和藍色存在大量的交叉區域,僅僅依靠PCA找到的最強的一個維度,不可能得到性能較好的分類器。

神經網路是怎麼解決這個問題的

由於涉及到實際工作中的數據、程序,這裡略過神經網路設計和訓練方法細節,只考慮神經網路的全連接輸出層。全連接層的輸入是形狀(N,64)的tensor,輸出是(N,2)的tensor。這裡(?,64)的tensor可以視為神經網路將64維數據經過複雜運算之後投射到另一個64維空間的效果,將其畫出來圖形如下

對比原始數據的分布,上圖可以看出兩點

  1. 3個維度的取值範圍接近,說明神經網路將此前集中在1維的數據成功的擴展到3維空間,數據點之間的距離加大,分類面可以更好的將這兩類數據分開
  2. 紅色數據的集中度明顯提升,這樣分類的難度就會大幅下降,分類精度提升

雖然以上圖這個角度,肉眼看紅藍兩色在一些區域還有重疊,但考慮到這是3D圖,通過旋轉等手段可以看到這些點只是在視覺的投影方向重疊,其本身在3維空間並沒有重疊。同時,3個維度的取值範圍都很大,神經網路已經可以在這些點之間生成靈活的分類面,實現高精度(Precision和Recall都達到99%)的分類器

總結

本文通過原始數據和經過訓練的神經網路運算得到的輸出層數據可視化對比,說明了神經網路是如何將原始數據投射到另一個空間,實現高準確率分類。

同時,這種可視化也是一種理解和檢驗神經網路性能的方法,通過將輸入和輸出數據可視化可以看出神經網路的性能以及可能的問題。

推薦閱讀:

TAG:深度學習DeepLearning | 神經網路 | 機器學習 |