ICML 2017 Best Paper Award論文解讀(1)
------2017.12.16 更新PPT 鏈接-------
(本系列解讀文章涉及較多細節,篇幅較長,故分為三個部分,本篇為第一部分。)
ICML 2017 基本已經落下帷幕,在大會第一天就公布今年的Best Paper,課題組上周組會要討論這篇文章,於是我就花了幾天時間仔細的研究了這篇Best Paper:Understanding Black-box Predictions via Influence Functions。
為了組會講解方便,我也做了粗糙的一個PPT,傳送門在此。
目錄
第一部分
- 總體工作介紹
- 機器學習基本概念的數學表述
第二部分
- Influence Function演算法介紹(二階可微和強凸條件下)
- Influence Function快速計算(二階優化技巧)
- HVPs
- Conjecture Gradient Method
- Stochastic estimation
第三部分
- Influence Function擴展至非可微(Non-differentiable)和非凸(Non-convex)模型
- Influence Function 應用舉例
- 輔助理解機器學習模型到底學到了什麼
- 針對神經網路模型進行Training-Attacks
- 用於debug數據集的分布是否一致
- 用於對數據集中的錯誤標註進行修復
正文
1. 總體工作介紹
這篇工作最大的亮點,就是直面了一個許多深度學習和計算機視覺研究人員十分關心的一個問題,就是對於一個深度神經網路,它為什麼能輸出這樣的預測結果(Why did it make this prediction?)。換言之,就是我們一直以為是Black-box深度神經網路到底從訓練集中到底學到了什麼?。
更重要的是這篇文章所採用的一套非常嚴謹和規範化的做machine learning研究的範式。我想這也是它之所以被選為Best paper 的重要原因之一,也是各位大佬希望能對於機器學習和計算機視覺的研究風氣起到一個引領的作用,現在很多的研究只重視實驗數據的漂亮,而忽略了 對演算法本身的探討,比如很多研究會迴避掉為什麼這樣的模型和網路結構會work的問題,當然這也只是我個人的猜測。
1.1 提出問題
回到剛才提到的做研究的範式,首先是大家都經常說的一個話題,就是要能夠提一個好問題,同時把問題細化和分解,這篇文章的作者在Black-box的問題背景下提了一個非常好(counterfactual)的問題:既然決定一個模型預測結果的是模型的參數(或者稱之為權重),而模型的參數是取決於選用的演算法和數據,那我們更能不能進一步的通過跟蹤訓練數據的變化來觀察預測結果有什麼變化?
這篇文章更具體的來把問題分解,第一步,如果我們把訓練集中的一個樣本去掉,那我們重新訓練得到的新模型做出的預測會有什麼樣的變化?第二步,如果我們只是對訓練集中的一個樣本數據進行一個微小的擾動,那我的新的模型參數帶來的預測結果會有什麼樣的變化?為什麼要分兩步呢?其實第一步是第二步的基礎,在後續的演算法解讀的部分會進行細節的解讀。
除了提一個好的問題,這篇文章在理論推導和證明上也是可圈可點,首先提出了在二次可微和強凸條件下的證明,再通過一些優化的技巧,將方法推廣至非可微和非凸模型上。這種漸進的研究方式也是機器學習領域常用的套路。
1.2 引入Influence Function
在這樣的問題背景下,作者借鑒了robust statistics中的influence function的技術,用來對模型參數的變化進行衡量。具體的講就是,使用influence function研究了兩個問題,(1)是對training set中一個樣本數據加入一個擾動,觀察模型參數產生的變化;(2)是對tranining set中的一個樣本數據加入一個擾動,觀察對於test set中的一個樣本,相應的loss的變化。
這樣,通過influence function就可以判斷出訓練集中的樣本對於演算法模型的訓練有沒有影響,有多大的影響,是好的影響還是壞的影響。
對於很多研究者來說,更關心的是influence function能怎麼用,能怎麼樣幫助我修改我的演算法或者實驗過程,我們就先對應用場景進行簡單介紹,數學證明和推導放在後邊進行解讀。
1.3 Influence Function應用舉例
文章主要提了四個方面的應用:
(1)使用influence function,來判斷一個演算法模型到底學到了什麼信息。文章首先在SVM和CNN模型上分別訓練一個二分類分類器,然後通過influence function 計算得出哪些訓練集中圖片對模型參數影響較大,進而通過判斷不同的模型學到的到底是圖片的層次的信息。
(2)Security in Machine Learning也是最近機器學習領域比較熱門的研究方向,文章提出influence function可以指導如何進行Training-attacks,區別於Test-attacks,對於針對測試集中的圖片,選取訓練集中的一張或幾張圖片,在influence function的指導下,對訓練集中選取的圖片進行處理,得到視覺上幾乎不可見,卻又能是模型參數發生較大變化的新的訓練數據,進而使模型將測試集中的圖片誤判。通過這種方式,也可以指導我們涉及出一些安全保護機制,提升機器學習模型的安全性。
(3)Debug domian mismatch, 文章以測試集和訓練集數據分布不一致為例,舉例講解了如何使用influence function來找到哪些樣本對模型參數影響最大,對數據分布影響最大,可以幫助我們檢查和修正我們設計的資料庫。
(4)修正被錯誤標註的數據;現在數據標註的需求越來越大,而無論是專家標註,還是眾包平台的標註,都難免產生數據標註錯誤的情況,我們也可以使用influence function 來找到這些被標註錯誤的數據,進行修正。
1.4 Influence function 高效計算
在引入了influence function的基礎上,作者花了相當的筆墨在如何高效計算influence function上,通過使用二階優化(second-order optimization)的一些技巧,使influence function能夠基於現有的深度學習框架實現快速計算,作者也已經將相關的代碼開源。
2.機器學習基本概念的數學表述
作為一個搞視覺的孩子,初讀這種Machine Learning的文章,分分鐘被虐哭,也終於體會到數學基礎薄弱的痛苦。於是在這裡從頭對機器學習的一些概念,從統計學和概率角度進行簡單復盤,部分表述來自MIT 9.520(鏈接在最下方)和李航博士的統計學習方法。
2.1 統計學習定義
Learning is viewed as a generalization/inference problem from
usually small sets of high dimensional, noisy data.(from MIT 9.520 Lecture 2)
我們現在常說的機器學習,其實就是統計機器學習。
對於我們常說的監督學習(supervised learning),訓練集 ,輸入數據 ,標籤為 ,定義我們用於訓練的數據是獨立地從滿足分布 , 中採樣得到的,訓練集定義 : 。也即 。我們假設數據是滿足獨立同分布(i.i.d)的。
我們有以下兩點需要注意:
- 聯合數據分布 是固定的,同時也是未知的
- 我們的目標是得到條件概率 ,
2.2 泛化能力
機器學習的目標是能夠通過在訓練集訓練得到一個能在測試集或者未來數據上做出正確預測的模型,而我們常常用泛化能力(Gener)一詞來衡量一個模型的在測試集或者未來數據上的預測能力。
回到我們剛才提到的機器學習的目標,我們希望通過訓練得到一個函數 ,給定一個輸入 ,給出結果 。我們定義一個模型的假設空間(Hypothesis Spach) ,其中包含所有的函數 ,因此我們的目標就是找到 中,預測性能最好,泛化能力最強的函數 。因此定義損失函數 作為對預測性能的度量,針對Classification 和 Regression通常使用不同的損失函數。
有了Loss function,我們就可以通過它來找到最好的模型,由於模型的輸入輸出是隨機變數,且滿足聯合概率分布 ,我們就可以得到損失函數的期望: ,這個也是理論上模型 關於聯合概率分布 的平均意義下的損失,我們稱之為期望風險(Expectd Risk),由於聯合概率是未知的,我們無法計算得到 。於是我們引入另一個定義,經驗風險(Empirical Risk),記作 ,定義式為: ,根據大數定理,當樣本容量趨於無窮大的時候,經驗風險趨於期望風險,於是我們用經驗風險來近似估計期望風險。
2.3 ERM 與 SRM
基於上面談的幾點,在假設空間,損失函數和訓練數據都確定的情況下,我們可以採用ERM(Empirical Risk Minimization)策略來確定最優的模型,公式化的表達為: 當樣本容量足夠大時,ERM可以得到不錯的學習效果;但是當樣本容量很小時,ERM就容易出現過擬合。
SRM(Structural Risk Minimization)就是為了防止過擬合而提出來的策略,公式化表示為: ,其中 是懲罰項(或正則化項), 是係數。
通過以上的基本概念,我們看到監督學習問題也就變成了ERM或者SRM的優化問題。
(第一部分完)
參考文獻
[1] 統計學習方法 李航 清華大學出版社
[2] 9.520/6.860: Statistical Learning Theory and Applications, Fall 2016
[3] Understanding Black-box Predictions via Influence Functions
推薦閱讀:
※Coursera吳恩達《神經網路與深度學習》課程筆記(1)-- 深度學習概述
※10分鐘快速入門PyTorch (3)
TAG:深度学习DeepLearning | 神经网络 |