【翻譯 - CS229】對於機器學習應用的建議
內容簡介:
- 如何將機器學習演算法,應用到不同的實際問題上?
- 不會涉及太多數學
- 本文的一些內容可能有爭議
- 本文的一些內容可能不適用於機器學習的學術研究
目錄:
1. 診斷:學習演算法出了什麼bug?怎麼debug?
2. 誤差分析
3. 開始解決實際問題
1. 診斷:學習演算法出了什麼bug?怎麼debug?
機器學習應用中,最常見的問題就是誤差太高。針對這種問題,最好的解決方案就是:
- 進行診斷(diagnostics),找到問題所在。
- 解決這個問題。
而不是費時費力地把各種方法都試一遍。
導致誤差太高的問題,可以分為三大類:過擬合 vs. 欠擬合,優化目標不合適 vs. 優化演算法不合適,與項目的現實目標不符合。
【例1】我們在做一個垃圾郵件分類的機器學習演算法,特徵是100個單詞。我們選擇的模型是貝葉斯邏輯回歸(Bayesian Logistic Regression,簡稱BLR),訓練方法是梯度下降(gradient descent)。很可惜,模型在測試集上的誤差率是20%,高得不可接受。
BLR的目標函數:
可選的改進方式:
- 獲取更多的訓練樣本:解決high variance
- 減少特徵:解決high variance
- 增加特徵:解決high bias
- 改變特徵(如電子郵件的標題、郵件的主體):解決high bias
- 增加梯度下降的迭代次數:解決優化演算法問題
- 用牛頓法(Newton"s Method)優化:解決優化演算法問題
- 改變正則化參數:解決優化目標問題
- 改用其他模型,如SVM:解決優化目標問題
我們需要根據診斷結果,做出針對性的改進。
1.1. 過擬合與欠擬合
- 過擬合(overfitting):模型在訓練集上的表現很好(誤差小),但是在測試集上的表現很糟糕(誤差大),缺乏泛化(generalization)能力;模型有high variance;往往是因為模型太複雜、特徵太多。
- 欠擬合(underfitting):模型在訓練集和測試集上的表現都很糟糕(誤差大);模型有high bias;往往是因為模型太簡單、特徵太少。
【診斷方法】Learning Curve
Learning Curve圖的橫軸表示訓練樣本的大小,縱軸表示誤差率(模型在訓練集上的誤差率、模型在測試集上的誤差率)
上圖是high variance的learning curve。隨著訓練樣本的增多,測試誤差一直在逐步減小,但是仍然和訓練誤差之間存在比較大的間隙。所以,加入更多的訓練樣本有助於模型的表現。上圖是high bias的learning curve。訓練誤差和測試誤差之間的間隙很小,但是它們都遠遠高於理想的誤差率。1.2. 優化目標與優化演算法
【例1】(續)
- BLR在垃圾郵件上的誤差率是2%,在正常郵件上的誤差率是2%。(後者誤差率太高)
- SVM在垃圾郵件上的誤差率是10%,在正常郵件上的誤差率是0.01%。(情況比較理想)
- 我們希望保留BLR,因為BLR計算效率更高。
診斷方法:
- 假設BLR的參數為,SVM的參數為。
- 其實,我們真正在乎的目標是加權準確率。但是,由於它的種種特性(最主要的是不可導),各個模型採用的優化目標都有所不同,這可能會導致模型的優化目標和整體目標不符合。
- 從觀察中,我們發現
- 診斷的關鍵步驟:是否有?(是BLR的目標函數)
情況1:,這說明優化演算法有問題,很有可能沒有收斂,未能成功最大化。
情況2:,這說明優化目標有問題,並不能很好地代表。
1.3. 項目的現實目標
【例2】用強化學習(Reinforcement Learning)訓練直升機駕駛程序。
研究步驟:
- 建造直升機模擬器
- 確定目標函數
- 在模擬器中運行RL演算法,得到理想參數:
- 遇到的問題:駕駛程序的表現比人類駕駛員糟糕很多
三種可能的改進方式:
- 改進模擬器
- 修改目標函數
- 修改RL演算法
這個問題,和之前的優化目標問題類似。在現實中,我們並不是在優化,而是在優化一個隱藏的目標函數——直升機開得好不好。只是後者的一個代表。
理想情況下,滿足以下三個條件,自動駕駛程序必然能很好地開直升機:
- 直升機模擬器能準確反映現實。
- 在模擬器內,RL演算法能成功最小化以得到。
- 優化,對應著優秀的直升機自動駕駛。
所以,這三個假設中的一個或多個肯定出了問題。我們可以使用如下的方法診斷:
- 程序能在模擬器內開好直升機,但是在現實中開不好 -> 模擬器有問題
- 假設人類駕駛員的參數是。 -> RL演算法有問題
- -> 目標有問題,優化並不代表程序能優秀地自動駕駛
即使沒有遇到嚴重的問題,這種診斷也有不少用處:
- 理解這個應用問題的本質
- 方便寫論文,診斷結果能夠為讀者提供很好的insight
- 從「這個演算法能用」,到「這個演算法因為XXXX所以有用,以下是我的證明……」
2. 誤差分析
誤差分析:分析演算法的誤差來自於哪一個部分。機器學習里有pipeline這種概念,指的是一系列組成部分,結合在一起形成的pipeline整體,如下圖所示:
說到誤差分析時,中文裡似乎沒有區分以下兩種不同分析思路的詞,所以我將保留英文原文。
- error analysis:試圖分析當前表現與完美表現之間的差距,並找到能解釋這種差距的因素
- ablative analysis:試圖分析基線表現(最糟糕)與當前表現之間的差距,並找到能解釋這種差距的因素。
2.1. Error Analysis
從最現有的模型開始,我們把pipeline中每一個組成部分,依次替換為完美的ground truth label,並觀察準確率的提升情況。
如上圖所示,如果pipeline的某一部分,從目前的演算法變成完美的ground truth label後,沒有怎麼提升整體的準確率,那麼這個部分就沒有太大的改進空間。反之,則有極大的改進空間。在這個例子里,臉部識別和眼部識別的改進空間最大。2.2. Ablative Analysis
從現有的模型開始,我們把它的每一個組成部分逐漸去掉,直到它變成最糟糕的基準線模型,並觀察整體準確率的改變。
如圖可見,郵件的標題特徵基本沒有提升準確率,郵件文本解析特徵極大提升了準確率。3. 開始解決實際問題
在解決一個實際的機器學習問題時,通常有以下兩種基本思路:
3.1. 仔細設計(careful design)
花很長時間,設計出最好的特徵,收集最好的數據集,建造出最好的演算法架構。
優點:演算法的可擴展性強,可能可以找到新的、更優雅的學習演算法。
缺點:可能會導致過早優化(俗話說,過早優化是萬惡之源)
3.2. 先建模型,再逐漸改進(build-and-fix)
先建立起一個簡單模型,然後進行誤差分析和診斷,再逐步改進。
優點:速度快,耗時短。
缺點:可能不適用於學術研究,如探索新的機器學習演算法。
推薦閱讀:
※最簡單的 GAN 解釋 (生成對抗網路)
※CS 294: Deep Reinforcement Learning(11)
※Python · 樸素貝葉斯(二)· MultinomialNB
※圖像識別:基於位置的柔性注意力機制