OpenCV機器學習——支持向量機SVM

OpenCV中集成了多種機器學習演算法供我們方便使用,如果我們要訓練數據進行分類,不用自己寫分類器,只需要調用相應的庫和類即可輕鬆實現。本文重點不在於介紹機器學習原理及數學推導,著重介紹OpenCV中的機器學習相關函數,並且用十分簡單的訓練數據作為例子實現分類。

對於OpenCV的機器學習分類器,大多換湯不換藥,構造方法和實現方法很類似,基本遵循原始數據—訓練分類器—進行分類的步驟,某些演算法可能有特殊的初始化參數,需要額外設置

在實現任何分類器之前,都需要訓練數據。插句題外話,訓練數據的好壞是一個分類器成功與否的決定性條件,數據選取永遠凌駕於分類器演算法選取之上,如果訓練數據選取得當,無論使用任何演算法都會得到不錯的效果,反之如果訓練數據選取不當,分類演算法是無法彌補的。在此我們使用簡單的二維數據作為訓練數據,其標號分別為1和-1,我們用圖像來直觀的表示:

t//設定800*800的二維坐標平面區域ntint width = 800, height = 800;ntMat I = Mat::zeros(height, width, CV_8UC3);nnt//訓練數據集,前10個標記為1,後10個標記為-1ntfloat trainingData[20][2] =nt{ { 100, 100 }, { 200, 100 }, { 400, 100 }, { 200, 200 }, { 500, 200 },nt{ 100, 300 }, { 300, 300 }, { 400, 300 }, { 100, 400 }, { 200, 500 },nt{ 600, 600 }, { 700, 300 }, { 700, 300 }, { 400, 500 }, { 600, 500 },nt{ 200, 700 }, { 300, 600 }, { 500, 600 }, { 600, 300 }, { 400, 700 } };nt//訓練數據集存入矩陣ntMat trainingDataMat(20, 2, CV_32FC1, trainingData);nnt//訓練數據標記,前10個標記為1,後10個標記為-1ntfloat labels[20] =nt{ 1.0, 1.0, 1.0, 1.0, 1.0,nt1.0, 1.0, 1.0, 1.0, 1.0,nt-1.0, -1.0, -1.0, -1.0, -1.0,nt-1.0, -1.0, -1.0, -1.0, -1.0 };nt//訓練數據標記存入矩陣ntMat labelsMat(20, 1, CV_32FC1, labels);nnt//將訓練數據用不同顏色畫出:1為綠色,-1為藍色ntfor (int i = 0; i < 20; i++)nt{nttif (labels[i] == 1.0)ntttcircle(I, Point(trainingData[i][0], trainingData[i][1]), 2, Scalar(255, 0, 0), 2);nttelse ntttcircle(I, Point(trainingData[i][0], trainingData[i][1]), 2, Scalar(0, 255, 0), 2);nt}ntimshow("dataset", I);n

注意訓練數據集矩陣類型一定是CV_32FC1型,長寬分別為數據個數和維度(20個訓練數據,2維);訓練數據標記矩陣是一維向量,也建議使用CV_32FC1型,還可用CV_32SC1型,長度為數據個數,要和訓練數據一一對應(如例子中前10個數據標記為1,後10個數據標記為-1)

接下來是SVM參數設定,建議設定方法是初始化一個空類,需要什麼參數單獨設定,具體如下:

CvSVMParams params;nparams.svm_type = CvSVM::C_SVC;nparams.kernel_type = CvSVM::LINEAR;nparams.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, FLT_EPSILON);n

其中,CvSVMParams可設置的參數有:(具體分類涉及SVM數學原理,不進行展開)

int svm_type:用來設定SVM的類型,分為C_SVC=100,nNU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104這5種,通常使用C_SVC=100作為一般的SVM分類器

int kernel_type:用來設定SVM所用核函數類型,分為LINEAR=0,nPOLY=1, RBF=2, SIGMOID=3這三種,文中訓練數據分類較為簡單,用線性核LINEAR=0即可

double degree:用來設定多項式內核函數(POLY=1)的冪次

double gamma:用來設定內核函數(POLY/ RBF/nSIGMOID)的參數gamma(多項式係數)

double coef0:用來設定內核函數(POLY/ RBF/nSIGMOID)的參數coef0(常數項)

double C、double nu、double p、CvMat* class_weights:用來設定非C_SVC=100類型的相應參數

CvTermCriteria term_crit:用來設定SVM迭代終止條件,其構造類型為(intntype, int max_iter, double expsolon),三個參數分別意為結束方式(迭代次數為基準的CV_TERMCRIT_ITER或誤差值為基準的CV_TERMCRIT_EPS),最大迭代次數,最小誤差值

nnnnnnnnnnnnnnnnnn綜上所述,文中SVM參數設置為:一般SVM分類器,線性核,循環終止,100次循環,最小誤差值為定義FLT_EPSILON(1.192092896e-07F)。

設置完參數後,就該是SVM訓練了,由於類初始化需要CvMat*數據類型,依舊建議初始化一個空類,需要什麼參數用函數添加,具體如下:

CvSVM SVM;nSVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);n

svm.train即為訓練函數,其參數為

bool train( const cv::Mat& trainData, const cv::Mat& responses,nconst cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),nCvSVMParams params=CvSVMParams() )n

const cv::Mat& trainData:訓練數據集,前文設定20*2的Mat trainingDataMat,再次提醒格式一定是CV_32FC1

const cv::Mat& responses:響應數據,即前文的訓練數據標記,20*1的向量Mat labelsMat,格式最好也是CV_32FC1

const cv::Mat& varIdx=cv::Mat(), constncv::Mat& sampleIdx=cv::Mat():兩個參數表示感興趣的特徵和樣本,如沒有感興趣對象則設為空矩陣Mat()即可

nnnnnnCvSVMParamsnparams=CvSVMParams():SVM參數設定,即前文設定的CvSVMParams params

運行完train後,樣本訓練過程結束,可用SVM.predict()函數進行分類,用SVM.get_support_vector_count()函數和SVM.get_support_vector()函數查看支持向量,下面分別介紹三個函數:

float predict( const cv::Mat& sample, bool returnDFVal=false ) constn

函數作用:判斷sample的類別

參數const Mat& sample:待分類向量,文中訓練數據是二維數據,因此待分類向量應是1*2的Mat矩陣,數據類型應為float型(CV_32F)

參數bool returnDFVal=false:判斷是否為二分類器,通常情況下不用設定,默認false即可

返回值:const Mat& sample的分類結果,文中返回值應為前文設定的訓練數據標記種類1或-1

簡單例子:

float temp[2] = { i, j };nMat sampleMat(1, 2, CV_32F, temp);nfloat response = SVM.predict(sampleMat);n

int get_support_vector_count() constnconst float* get_support_vector(int i) constn

兩個函數作用是獲得支持向量,通常需要結合使用。int get_support_vector_count()得到支持向量個數,將結果遍歷帶入float*nget_support_vector(int i)的參數i便可獲得每個支持向量

簡單例子:

int c = SVM.get_support_vector_count();nfor (int i = 0; i < c; ++i)n{const float* v = SVM.get_support_vector(i);}n

SVM函數大體如此,完整代碼及注釋:

#include <iostream>n#include <opencv.hpp>nusing namespace std;nusing namespace cv;nnvoid main()n{nt//設定800*800的二維坐標平面區域ntint width = 800, height = 800;ntMat I = Mat::zeros(height, width, CV_8UC3);nnt//訓練數據集,前10個標記為1,後10個標記為-1ntfloat trainingData[20][2] =nt{ { 100, 100 }, { 200, 100 }, { 400, 100 }, { 200, 200 }, { 500, 200 },nt{ 100, 300 }, { 300, 300 }, { 400, 300 }, { 100, 400 }, { 200, 500 },nt{ 600, 600 }, { 700, 300 }, { 700, 300 }, { 400, 500 }, { 600, 500 },nt{ 200, 700 }, { 300, 600 }, { 500, 600 }, { 600, 300 }, { 400, 700 } };nt//訓練數據集存入矩陣ntMat trainingDataMat(20, 2, CV_32FC1, trainingData);nnt//訓練數據標記,前10個標記為1,後10個標記為-1ntfloat labels[20] =nt{ 1.0, 1.0, 1.0, 1.0, 1.0,nt1.0, 1.0, 1.0, 1.0, 1.0,nt-1.0, -1.0, -1.0, -1.0, -1.0,nt-1.0, -1.0, -1.0, -1.0, -1.0 };nt//訓練數據標記存入矩陣ntMat labelsMat(20, 1, CV_32FC1, labels);nnt//將訓練數據用不同顏色畫出:1為綠色,-1為藍色ntfor (int i = 0; i < 20; i++)nt{nttif (labels[i] == 1.0)ntttcircle(I, Point(trainingData[i][0], trainingData[i][1]), 2, Scalar(255, 0, 0), 2);nttelse ntttcircle(I, Point(trainingData[i][0], trainingData[i][1]), 2, Scalar(0, 255, 0), 2);nt}ntimshow("dataset", I);nnt//SVM參數設置ntCvSVMParams params;ntparams.svm_type = CvSVM::C_SVC;ntparams.kernel_type = CvSVM::LINEAR;ntparams.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, FLT_EPSILON);nnt//SVM訓練ntCvSVM SVM;ntSVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);nnt//SVM分類結果顯示:1區域為綠色,-1區域為藍色ntfor (int i = 0; i < I.rows; ++i)ntfor (int j = 0; j < I.cols; ++j)nt{nttfloat temp[2] = { i, j };nttMat sampleMat(1, 2, CV_32FC1, temp);nttfloat response = SVM.predict(sampleMat);nnttif (response == 1)ntttI.at<Vec3b>(j, i) = Vec3b(255, 0, 0);nttelse if (response == -1)ntttI.at<Vec3b>(j, i) = Vec3b(0, 255, 0);nt}ntfor (int i = 0; i < 20; i++)nt{nttif (labels[i] == 1.0)ntttcircle(I, Point(trainingData[i][0], trainingData[i][1]), 2, Scalar(255, 255, 255), 2);nttelsentttcircle(I, Point(trainingData[i][0], trainingData[i][1]), 2, Scalar(0, 0, 0), 2);nt}nnt//支持向量標註,用紅圈圈出ntint c = SVM.get_support_vector_count();ntfor (int i = 0; i < c; ++i)nt{nttconst float* v = SVM.get_support_vector(i);nttcircle(I, Point((int)v[0], (int)v[1]), 6, Scalar(0, 0, 255), 2, 8);nt}ntimshow("result", I);nntwaitKey();n}n

分類結果:


推薦閱讀:

如今Weex與ReactNative哪個好?
如何實現Pixhawk與matlab捷聯慣導?
有哪些像扇貝那樣的學習網站,領域可以是計算機、數學等?
如何對指定文件夾進行簡單加密?

TAG:OpenCV | 计算机技术 | 机器学习 |