【翻譯 - CS229】對於機器學習應用的建議

前言:本文譯自Andrew Ng的Advice for Applying Machine Learning,原文PDF鏈接為cs229.stanford.edu/mate。筆者認為,在整個CS229的課程中,這節課的內容在實踐中最有價值。其他幾節課涉及到數學和演算法,看英文原文的學習效果也許好於看中文的二手翻譯。但是,我個人認為很有必要把這節課的內容翻譯為中文(而且目前知乎上好像還沒有翻譯)。

內容簡介

  • 如何將機器學習演算法,應用到不同的實際問題上?
  • 不會涉及太多數學
  • 本文的一些內容可能有爭議
  • 本文的一些內容可能不適用於機器學習的學術研究

目錄

1. 診斷:學習演算法出了什麼bug?怎麼debug?

2. 誤差分析

3. 開始解決實際問題

1. 診斷:學習演算法出了什麼bug?怎麼debug?

機器學習應用中,最常見的問題就是誤差太高。針對這種問題,最好的解決方案就是:

  • 進行診斷(diagnostics),找到問題所在。
  • 解決這個問題。

而不是費時費力地把各種方法都試一遍。

導致誤差太高的問題,可以分為三大類:過擬合 vs. 欠擬合優化目標不合適 vs. 優化演算法不合適與項目的現實目標不符合

【例1】我們在做一個垃圾郵件分類的機器學習演算法,特徵是100個單詞。我們選擇的模型是貝葉斯邏輯回歸(Bayesian Logistic Regression,簡稱BLR),訓練方法是梯度下降(gradient descent)。很可惜,模型在測試集上的誤差率是20%,高得不可接受。

BLR的目標函數:J(	heta)=max_{	heta}{sum_{i=1}^{m}{log{p(y^{(i)}|x^{(i)}, 	heta)}} - lambda||	heta||^2}

可選的改進方式:

  • 獲取更多的訓練樣本:解決high variance
  • 減少特徵:解決high variance
  • 增加特徵:解決high bias
  • 改變特徵(如電子郵件的標題、郵件的主體):解決high bias
  • 增加梯度下降的迭代次數:解決優化演算法問題
  • 用牛頓法(Newton"s Method)優化:解決優化演算法問題
  • 改變正則化參數lambda:解決優化目標問題
  • 改用其他模型,如SVM:解決優化目標問題

我們需要根據診斷結果,做出針對性的改進。

1.1. 過擬合與欠擬合

  • 過擬合overfitting):模型在訓練集上的表現很好(誤差小),但是在測試集上的表現很糟糕(誤差大),缺乏泛化(generalization)能力;模型有high variance;往往是因為模型太複雜、特徵太多。

  • 欠擬合underfitting):模型在訓練集和測試集上的表現都很糟糕(誤差大);模型有high bias;往往是因為模型太簡單、特徵太少。

【診斷方法】Learning Curve

Learning Curve圖的橫軸表示訓練樣本的大小,縱軸表示誤差率(模型在訓練集上的誤差率、模型在測試集上的誤差率)

上圖是high variance的learning curve。隨著訓練樣本的增多,測試誤差一直在逐步減小,但是仍然和訓練誤差之間存在比較大的間隙。所以,加入更多的訓練樣本有助於模型的表現。

上圖是high bias的learning curve。訓練誤差和測試誤差之間的間隙很小,但是它們都遠遠高於理想的誤差率。

1.2. 優化目標與優化演算法

【例1】(續)

  • BLR在垃圾郵件上的誤差率是2%,在正常郵件上的誤差率是2%。(後者誤差率太高)

  • SVM在垃圾郵件上的誤差率是10%,在正常郵件上的誤差率是0.01%。(情況比較理想)

  • 我們希望保留BLR,因為BLR計算效率更高。

診斷方法:

  1. 假設BLR的參數為	heta_{BLR},SVM的參數為	heta_{SVM}
  2. 其實,我們真正在乎的目標是加權準確率a(	heta) = max_{	heta}{sum_{i}{w^{(i)}I[h_	heta(x^{(i)})=y^{(i)}]}}。但是,由於它的種種特性(最主要的是不可導),各個模型採用的優化目標都有所不同,這可能會導致模型的優化目標和整體目標不符合。
  3. 從觀察中,我們發現a(	heta_{SVM}) > a(	heta_{BLR})
  4. 診斷的關鍵步驟:是否有J(	heta_{SVM}) > J(	heta_{BLR})?(J是BLR的目標函數)

情況1:J(	heta_{SVM}) > J(	heta_{BLR}),這說明優化演算法有問題,很有可能沒有收斂,	heta_{BLR}未能成功最大化J

情況2:J(	heta_{SVM}) leq J(	heta_{BLR}),這說明優化目標有問題,J(	heta)並不能很好地代表a(	heta)

1.3. 項目的現實目標

【例2】用強化學習(Reinforcement Learning)訓練直升機駕駛程序。

研究步驟:

  1. 建造直升機模擬器
  2. 確定目標函數J(	heta)
  3. 在模擬器中運行RL演算法,得到理想參數:	heta_{RL} = arg min_{	heta}{J(	heta)}

  4. 遇到的問題:駕駛程序的表現比人類駕駛員糟糕很多

三種可能的改進方式:

  1. 改進模擬器
  2. 修改目標函數J(	heta)

  3. 修改RL演算法

這個問題,和之前的優化目標問題類似。在現實中,我們並不是在優化J(	heta),而是在優化一個隱藏的目標函數——直升機開得好不好。J(	heta)只是後者的一個代表。

理想情況下,滿足以下三個條件,自動駕駛程序必然能很好地開直升機:

  1. 直升機模擬器能準確反映現實。
  2. 在模擬器內,RL演算法能成功最小化J(	heta)以得到	heta_{RL} = arg min_{	heta}{J(	heta)}
  3. 優化J(	heta),對應著優秀的直升機自動駕駛。

所以,這三個假設中的一個或多個肯定出了問題。我們可以使用如下的方法診斷:

  1. 程序能在模擬器內開好直升機,但是在現實中開不好 -> 模擬器有問題
  2. 假設人類駕駛員的參數是	heta_{human}J(	heta_{human}) < J(	heta_{RL}) -> RL演算法有問題
  3. J(	heta_{human}) geq J(	heta_{RL}) -> 目標J(	heta)有問題,優化J(	heta)並不代表程序能優秀地自動駕駛

即使沒有遇到嚴重的問題,這種診斷也有不少用處:

  • 理解這個應用問題的本質
  • 方便寫論文,診斷結果能夠為讀者提供很好的insight
  • 從「這個演算法能用」,到「這個演算法因為XXXX所以有用,以下是我的證明……」

2. 誤差分析

誤差分析:分析演算法的誤差來自於哪一個部分。機器學習里有pipeline這種概念,指的是一系列組成部分,結合在一起形成的pipeline整體,如下圖所示:

說到誤差分析時,中文裡似乎沒有區分以下兩種不同分析思路的詞,所以我將保留英文原文。

  • error analysis:試圖分析當前表現完美表現之間的差距,並找到能解釋這種差距的因素
  • ablative analysis:試圖分析基線表現(最糟糕)當前表現之間的差距,並找到能解釋這種差距的因素。

2.1. Error Analysis

從最現有的模型開始,我們把pipeline中每一個組成部分,依次替換為完美的ground truth label,並觀察準確率的提升情況。

如上圖所示,如果pipeline的某一部分,從目前的演算法變成完美的ground truth label後,沒有怎麼提升整體的準確率,那麼這個部分就沒有太大的改進空間。反之,則有極大的改進空間。在這個例子里,臉部識別和眼部識別的改進空間最大。

2.2. Ablative Analysis

從現有的模型開始,我們把它的每一個組成部分逐漸去掉,直到它變成最糟糕的基準線模型,並觀察整體準確率的改變。

如圖可見,郵件的標題特徵基本沒有提升準確率,郵件文本解析特徵極大提升了準確率。

3. 開始解決實際問題

在解決一個實際的機器學習問題時,通常有以下兩種基本思路:

3.1. 仔細設計(careful design)

花很長時間,設計出最好的特徵,收集最好的數據集,建造出最好的演算法架構。

優點:演算法的可擴展性強,可能可以找到新的、更優雅的學習演算法。

缺點:可能會導致過早優化(俗話說,過早優化是萬惡之源)

3.2. 先建模型,再逐漸改進(build-and-fix)

先建立起一個簡單模型,然後進行誤差分析和診斷,再逐步改進。

優點:速度快,耗時短。

缺點:可能不適用於學術研究,如探索新的機器學習演算法。

推薦閱讀:

最簡單的 GAN 解釋 (生成對抗網路)
CS 294: Deep Reinforcement Learning(11)
Python · 樸素貝葉斯(二)· MultinomialNB
圖像識別:基於位置的柔性注意力機制

TAG:机器学习 | 人工智能 |