<EYD與機器學習>五:Support Vector Machines
各位知乎兒大家好,這是<EYD與機器學習>專欄的第五篇文章,這篇文章以《Hands-on Machine Learning with Scikit-Learn and TensorFlow》[1](後面簡稱為HMLST)第五章的內容為主線,其間會根據成員自己的理解對章節內容進行一定的調整和補充。
本章主要講解了一些關於支持向量機(SVM)的問題。支持向量機是大家比較熟悉的一個機器學習模型,其最重要的性質和特點是:訓練後的模型僅與支持向量有關。SVM主要用於模式識別、異常檢測,分類和回歸等方面。講解支持向量機的文獻眾多,本文側重於講解在Scikit_Learn框架下,SVM分類和回歸的演算法實現,問題集中於核函數的選擇與參數的調試,希望大家在看完這篇文章後可以嘗試找一個數據集利用SVM演算法對其進行分類或者回歸,這也是本章最後幾個習題要求大家做的。本文中針對SVM的理論推導不會涉及太多,如果各位知乎兒想要對理論推導方面有更深入的了解,可以查看HMLST中,第五章第四小節(P151-P159)以及附錄C中的內容。
Hands-On Machine Learning with Scikit-Learn and TensorFlow第五章:Support Vector Machines
支持向量機(SVM)是一種非常強大且多功能的機器學習模型,能夠執行線性或非線性分類,回歸,甚至異常值檢測。它是機器學習中最受歡迎的模型之一,任何對機器學習感興趣的人都應該將它放在工具箱中。支持向量機特別適用於複雜但小型或中型數據集的分類。
5.1 線性SVM分類
當我們學習一個方法時,總是習慣於由簡單切入,往複雜發展。本文也是一樣,讓我們先著眼於一個線性的分類問題,嘗試用線性SVM解決它。在本節中,依然採用了大家熟悉的鳶尾花數據集,選取的特徵為花瓣寬度和長度,故本節的問題總結起來應該是一個二維的分類問題:
如上圖所示,山鳶尾(Iris-Setosa)和雜色鳶尾(Iris-Verscicolor)在這個由花瓣寬度和花瓣長度組成的二維空間中,是一個線性可分的。我們有很多種分類方法可以實現它們的分類,如圖5.1左圖中紅線和紫線所示,相較於圖中綠色虛線那樣甚至無法正確分類的模型來說它們當然是好的。但是當模型中引入新變數時,這兩種分類似乎就有些雞肋,由於它們離樣本點太近了,新樣本稍有偏差就會被錯誤分類。SVM分類的內在思想就是不僅要找到可以正確劃分各類別的「線」,還期望我們所找到的「線」離現有的樣本點越遠越好(例如圖5.1右圖中那條黑色直線)。我們可以把SVM分類看作是在類與類之間找到一條最寬的街道(圖5.1右圖中的虛線所示),這也叫作最大間隔分類問題。
當我們訓練好一個SVM分類器後,如果向模型添加在虛線外的樣本點時,對目前的模型是沒有影響的,但是當添加的樣本點在虛線之間,那麼模型就會發生變化(「街道」會變窄)。那麼在虛線上的樣本點就有了特殊的意義,它們被稱為支持向量(圖中被紅色標亮的樣本點)。
值得注意的是,SVM對特徵的維度敏感,當你的數據維度差異較大時,記得在運用SVM分類器之前進行特徵縮放(StandardScale),這點非常重要。如圖5.2所示,左圖中縱坐標維度比橫坐標維度大太多,求得的最大間隔幾乎與橫軸平行,其間距很窄。但是當經過特徵縮放處理後再次進行SVM分類,我們就可以得到右圖中更好的分類效果。
5.1.1 軟間隔分類
上節中我們簡單粗暴用一條街道分開了所有樣本點,這樣的分類方法叫做硬間隔分類。但是現實問題並不總能保證數據的線性可分性,也不能保證沒有離群值的出現。當我們遇到一些如圖5.3所示的數據集時,硬間隔分類演算法通常表現缺缺:左圖中由於異常值導致數據集不再線性可分,故SVM演算法沒能給出分類方案;右圖中,為了解決出現在中間的異常值,SVM犧牲了間隔寬度去擬合異常點,直接影響了演算法的泛化能力。
為了避免以上問題,我們通常會在處理實際問題時使用軟間隔分類方法。軟間隔方法的目標是在最大化間隔和最小化違約懲罰之間尋求一個平衡(違約懲罰即對於不滿足約束條件的樣本點進行懲罰)。那麼在演算法中這種平衡要怎麼控制呢?讓我們從SVM的數學本質中尋找答案:
線性SVM分類器的目標是通過對數據集的學習,得到一組權重 和偏移量 ,當對新的樣本 進行預測時,通常計算決策方程 :當結果大於零,則判斷為正類(1),否則判斷為負類(0)。
讓我們把鳶尾花數據集放在一個三維的空間中查看(圖5.4),我們所得到的決策方程就是圖中網格代表的超平面。圖中用來劃分類別的判定邊界由一系列使決策方程等於0的點組成,也就是圖中兩個平面的交線(圖中黑色直線)。而圖中的虛線則分別代表了決策方程等於-1和1時的樣本點,這兩條虛線相互平行且關於決策邊界對稱。訓練一個線性SVM分類器的目的就是為了找到一組 和 ,最大化兩條虛線的間隔,避免(硬間隔)或最小化(軟間隔)不滿足約束條件的樣本點數目。
對於決策超平面來說,它的斜率為權重向量 的範數, 。如圖5.5所示,權值向量 取得越小,那麼得到的間隔越大。故SVM的目的是最小化 以最大化間隔。由於 在 處不可微,故在實際問題中我們通常將目標函數設為可微又相對簡潔的 ,因為這兩者求得的結果是一樣的。
當求解的目標確定了,我們還需要考慮的是問題的約束條件。對於一個硬間隔SVM分類演算法,我們需要始終保證對於所有正例,決策函數的值應該大於1;對於所有反例,決策函數的值則需要小於-1(誰也別想走到中間的「街道」上!)。定義對於所有正例( ), ;對於所有負例( ), 。那麼這個約束就可以表示為: 。那麼以下就是硬分隔SVM的表達式:
軟間隔分類器是基於硬分隔分類器的衍生,對於上式引入一個表徵樣本點 不滿足約束的程度的鬆弛變數 ,原式變為:
那麼現在目標函數中就包含了兩個對立的目標:1>最小化鬆弛變數以縮小間隔,減少不符合約束的樣本點;2>最小化去權重以增大間隔。超參數C扮演的正是權衡這兩個目標的角色,可以由人為預先設定,當C取無窮大時,上式等價於硬分隔演算法的表達式。下面我們將通過實例看一下超參數C的改變對模型的影響。
圖5.6展示了對於同樣的非線性可分數據集,當超參數C分別為1和100時的分類效果。C越小,則意味著更寬的間隔和更多的不滿足約束的樣本點。
在Scikit_Learn框架下,實現這一分類的演算法有以下兩種:
1.from sklearn.svm import LinearSVC lin_clf = LinearSVC(loss="hinge", C=1, random_state=42)2.from sklearn.svm import SVC svm_clf = SVC(kernel="linear", C=1)
演算法2的運算速度較演算法1慢很多,具體原因後面會詳細解釋。當處理線性分類問題,特別是大規模數據時,建議選用LinearSVC。除此之外,我們還可以使用第三章提到的隨機梯度下降演算法進行同樣的分類:SGDClassifier(loss=hinge,alpha=1/(m*C))。(PS:這裡反覆提及的hinge是一種衡量樣本點不滿足約束的程度的損失函數, )
5.2 非線性SVM分類
上節我們討論了線性分類的問題,但現實中存在很多線性不可分的問題。對於線性不可分問題,有一個解決方案是升維,即將原問題映射到更高維的空間中,增大其線性可分的概率。這樣說起來可能有些抽象,我們舉個例子:
圖5.7左圖展示的是在一維狀態下樣本點的分布情況,顯然,此數據集在一維下並不線性可分。但是當添加了一個新的特徵 後,如右圖所示,數據集就可以用圖中紅色虛線所代表的超平面進行線性分類。
基於這個思想,我們在處理當前維度的線性不可分問題時,先通過PolynomialFeatures()對數據集中的數據進行升維,將現有樣本的屬性投影到高維空間,以提高其線性可分的可能性,然後再用線性的SVM演算法對其進行分類。下面將對半月數據集上進行試驗:
polynomial_svm_clf = Pipeline(( ("poly_features",PolynomialFeatures(degree=3)), ("scaler",StandardScaler()), ("svm_clf",LinearSVC(C=10,Loss="hinge")) ))
5.2.1 多項式核
添加多項式特徵是有效提高模型線性可分性的一種簡單可行的方法,而且可廣泛應用於許多機器學習演算法。但是低度多項式無法解決複雜問題,高度多項式又會給模型增加太多參數,影響模型的運算速度。
由於SVM中核函數的存在,高度多項式增加太多參數的問題得到了有效解決。因為實際上,這些參數並沒有真正地被真正添加,這聽起來或許讓人摸不著頭腦,其中原理後面會在書中有詳細的解釋,有興趣的可以查看[1]。下面我們先通過實驗驗證一下效果,以下實驗的數據集依然是半月數據集。
from sklearn.svm import SVCsvm_clf = SVC(kernal="poly", degree=3, coef0=1, C=5)
圖5.9展示了C值和r值分別設定為不同值情況下,SVM的分類效果。在一個分類任務中,如果出現過擬合,可以適當的降低多項式度C。反之亦然。超參數coef0的作用則是控制高度項和低度項的權重比。那麼,還記得如何找到合適的超參數值嗎?(回顧第二章的網格搜索)
5.2.2 增加其他特徵
前文中,我們向原樣本增加了多項式特徵,以求提高樣本的線性可分性。除此之外,我們還可以增加其他類型的特徵,採用與多項式函數類似的其他函數對樣本特徵進行轉化。例如,向之前提到的簡單一維數據集中引入兩個地標點 ,然後我們向模型引入高斯徑向基函數(RBF),以 對原樣本點進行轉化。這樣聽起來可能有些抽象,圖5.10展示了轉化前後的樣本點。
高斯徑向基函數:
在圖5.10右圖中,我們分別以距離地標點 的距離作為橫縱坐標,繪製了經過高斯基函數映射後的樣本點。之前在一維空間中線性不可分的樣本點,經過高斯RBF的映射後,變成了線性可分的點(圖中紅色虛線所示)。將數據集中每個樣本點都設置為地標點是最簡單粗暴的提高模型線性可分性的方法。對於一個有m的樣本點,n維特徵的數據集,經過映射後,將變成一個擁有m個樣本點,m維特徵的新數據集。當然,這帶來的後果是你的模型的特徵數目會變得非常巨大。
5.2.3 高斯徑向基函數核
下面,我們又要請出神奇的核函數技術,即是5.2.1節提到的可以在不真正增加特徵的情況下,達到與添加了特徵一樣的效果的技術。這會有效解決上一節中映射結束後,特徵數目過大的問題。下面我們將在半月數據集上應用高斯核函數:
from sklearn.svm import SVCsvm_clf = SVC(kernal="rbf", gamma=5, C=0.001)
圖5.11展示了不同 和 值時,分類器的運行結果。 越大,表現在圖5.10左圖上,意味著更陡峭的坡度和更窄的鈴形曲線,分類意義上,則表現為到每個點的距離對分類的影響變小。表現在圖5.11上,決策邊界最終變得更加不規則,圍繞著個體實例擺動。所以 可以看做一個正則化超參數:如果模型過擬合,適當降低;如果模型欠擬合,適當增大。
除開高斯RBF以外,SVC還提供了其他的核函數,但是應用相對較少:
線性核:
拉普拉斯核:
Sigmoid核:
還有一些為特定的任務制定的特殊核函數,例如對於用於分類文檔或者DNA序列的基於Levenshtein距離的字元串核。
選取核函數時,我們通常採取以下準則:1.首先嘗試線性核,特別是當訓練集規模或特徵維較大時;其次,當訓練集不是非常大時,可以嘗試RBF核函數;再次,如果你有多餘的時間,可以嘗試以下其他的核函數。
5.2.4 計算複雜度
完成線性分類有三種函數:SVC(kernal="linear"),LinearSVC和SGDClassifier。下面我們將針對如何在上述三種函數中選取最好的函數進行討論:
LinearSVC是基於liblinear library的函數,並不支持核函數技術,但是它的複雜度基本與數據集規模線性相關,時間複雜度為 。默認loss為 ,在進行多分類時,默認採取One-versus-All(參見第三章),也就是說針對 個類構建 個分類器。
SVC(kernal="linear")是基於libsvm library的函數,即在函數中採用核函數技術。其時間複雜度在 到 之間,這意味著當樣本點數目 增加時,時間複雜度會大幅增加。故這個函數比較適用於特徵維度較大,但是數據集規模不大的項目,尤其適用於稀疏特徵數據集,在這種情況下,演算法使用的是每個實例的非零特性的平均數量。該函數默認loss為 ,在進行多分類時,採取One-versus-one,即針對 個類構建 個分類器。
綜上,我們可以得出一個簡單的結論:無論是對單分類還是多分類問題,LinearSVC通常比SVC(kernal="linear")計算速度快。
5.3 SVM回歸
在前面的章節中,我們探討了SVM在分類問題中的應用,分類問題可以形象地解釋為在樣本點間尋找一條儘可能寬的街道,將它們分隔開來。而當將SVM用於回歸問題時,則恰恰相反,我們試圖找到一條街道,使儘可能多的樣本點落在上面。「街道」的寬度由超參數 控制,下面我們將在一組隨機數據上採用線性SVR演算法。
from sklearn.svm import LinearSVRsvm_reg = LinearSVR(epsilon=1.5)
當要進行非線性回歸時,同樣可以引入SVM中的核函數技術。圖5.13展示了在一個隨機二次數據集上採用2度多項式核的SVR演算法時的回歸效果:
from sklearn.svm import SVRsvm_poly_reg = SVR(kernel="poly",degree=2,C=100,epsilon=0.1)
演算法中C控制著正則化程度,越大的C意味著更小的正則化程度,當出現過擬合是,適當減小C,反之亦然。
5.4 總結
在這一章,我們探討了SVM在分類和回歸中的應用問題,其中提及的核函數技術是一種重要的方法,可以幫助我們學習進行非線性分類和回歸。由於筆者寫這篇文章的初衷是讓各位知乎兒在看完本文後,可以學會如何利用Scikit_Learn框架進行實際的SVM編程,故沒有在文中描述SVM的數學原理。下面的習題是書中所附的習題,希望感興趣的知乎兒可以試著做一下,答案在完整代碼中可以找到。
1.利用SVM分類器對MNIST數據集進行分類。選用one-versus-all解決這個10分類問題,並採用驗證集對演算法中的超參數進行調節。
2.在第二章的加州房價數據集上應用SVM回歸技術。
最後,若本次文章中出現了什麼問題以及大家對我們的工作有什麼建議歡迎提出!我們將及時對文章進行改進。
——EVA 編輯
參考文獻:
[1]Géron A. Hands-on machine learning with Scikit-Learn and TensorFlow: concepts, tools, and techniques to build intelligent systems[M]. " OReilly Media, Inc.", 2017.
完整代碼:
ageron/handson-ml
推薦閱讀:
※71歲吳宇森回歸《追捕》?
※評台灣改變政體「回歸」大陸,台灣方面誰有資格決定這件事?yolfilm的回答
※暴走大事件終於回歸,那麼回歸後的暴大會有什麼不同?