第二十九章 Decision Tree演算法(上)
這一篇文章中,討論一種被廣泛使用的分類演算法——決策樹(decision tree)。決策樹的優勢在於構造過程不需要任何領域知識或參數設置,因此在實際應用中,對於探測式的知識發現,決策樹更加適用。
決策樹案例
通俗來說,決策樹分類的思想類似於找對象。現想像一個女孩的母親要給這個女孩介紹男朋友,於是有了下面的對話:
女兒:多大年紀了?
母親:26。
女兒:長的帥不帥?
母親:挺帥的。
女兒:收入高不?
母親:不算很高,中等情況。
女兒:是公務員不?
母親:是,在稅務局上班呢。
女兒:那好,我去見見。
這個女孩的決策過程就是典型的分類樹決策。相當於通過年齡、長相、收入和是否公務員對將男人分為兩個類別:見和不見。假設這個女孩對男人的要求是:30歲以下、長相中等以上並且是高收入者或中等以上收入的公務員,那麼這個可以用下圖表示女孩的決策邏輯。
上圖完整表達了這個女孩決定是否見一個約會對象的策略,其中綠色節點表示判斷條件,橙色節點表示決策結果,箭頭表示在一個判斷條件在不同情況下的決策路徑,圖中紅色箭頭表示了上面例子中女孩的決策過程。
這幅圖基本可以算是一顆決策樹,說它「基本可以算」是因為圖中的判定條件沒有量化,如收入高中低等等,還不能算是嚴格意義上的決策樹,如果將所有條件量化,則就變成真正的決策樹了。
有了上面直觀的認識,我們可以正式定義決策樹了:
決策樹(decision tree)是一個樹結構(可以是二叉樹或非二叉樹)。其每個非葉節點表示一個特徵屬性上的測試,每個分支代表這個特徵屬性在某個值域上的輸出,而每個葉節點存放一個類別。使用決策樹進行決策的過程就是從根節點開始,測試待分類項中相應的特徵屬性,並按照其值選擇輸出分支,直到到達葉子節點,將葉子節點存放的類別作為決策結果。
可以看到,決策樹的決策過程非常直觀,容易被人理解。目前決策樹已經成功運用於醫學、製造產業、天文學、分支生物學以及商業等諸多領域。知道了決策樹的定義以及其應用方法,下面介紹決策樹的構造演算法。
決策樹的構造
不同於貝葉斯演算法,決策樹的構造過程不依賴領域知識,它使用屬性選擇度量來選擇將元組最好地劃分成不同的類的屬性。所謂決策樹的構造就是進行屬性選擇度量確定各個特徵屬性之間的拓撲結構。
構造決策樹的關鍵步驟是分裂屬性。所謂分裂屬性就是在某個節點處按照某一特徵屬性的不同劃分構造不同的分支,其目標是讓各個分裂子集儘可能地「純」。儘可能「純」就是盡量讓一個分裂子集中待分類項屬於同一類別。分裂屬性分為三種不同的情況:
1、屬性是離散值且不要求生成二叉決策樹。此時用屬性的每一個劃分作為一個分支。
2、屬性是離散值且要求生成二叉決策樹。此時使用屬性劃分的一個子集進行測試,按照「屬於此子集」和「不屬於此子集」分成兩個分支。
3、屬性是連續值。此時確定一個值作為分裂點split_point,按照>split_point和<=split_point生成兩個分支。
構造決策樹的關鍵性內容是進行屬性選擇度量,屬性選擇度量是一種選擇分裂準則,是將給定的類標記的訓練集合的數據劃分D「最好」地分成個體類的啟發式方法,它決定了拓撲結構及分裂點split_point的選擇。
屬性選擇度量演算法有很多,一般使用自頂向下遞歸分治法,並採用不回溯的貪心策略。這裡介紹ID3和c4.5兩種常用演算法。
ID3演算法
從資訊理論知識中我們直到,期望信息越小,信息增益越大,從而純度越高。所以ID3演算法的核心思想就是以信息增益度量屬性選擇,選擇分裂後信息增益最大的屬性進行分裂。下面先定義幾個要用到的概念。
設D為用類別對訓練元組進行的劃分,則D的熵表示為:
其中pi表示第i個類別在整個訓練元組中出現的概率,可以用屬於此類別元素的數量除以訓練元組元素總數量作為估計。熵的實際意義表示是D中元組的類標號所需要的平均信息量。現在我們假設將訓練元組D按屬性A進行劃分,則A對D劃分的期望信息為:
而信息增益即為兩者的差值:
ID3演算法就是在每次需要分裂時,計算每個屬性的增益率,然後選擇增益率最大的屬性進行分裂。下面我們繼續用SNS社區中不真實賬號檢測的例子說明如何使用ID3演算法構造決策樹。為了簡單起見,我們假設訓練集合包含10個元素:
其中s、m和l分別表示小、中和大。
設L、F、H和R表示日誌密度、好友密度、是否使用真實頭像和賬號是否真實,下面計算各屬性的信息增益。
因此日誌密度的信息增益是0.276。
用同樣方法得到H和F的信息增益分別為0.033和0.553。
因為F具有最大的信息增益,所以第一次分裂選擇F為分裂屬性,分裂後的結果如下圖表示:
在上圖的基礎上,再遞歸使用這個方法計運算元節點的分裂屬性,最終就可以得到整個決策樹。
上面為了簡便,將特徵屬性離散化了,其實日誌密度和好友密度都是連續的屬性。對於特徵屬性為連續值,可以如此使用ID3演算法:先將D中元素按照特徵屬性排序,則每兩個相鄰元素的中間點可以看做潛在分裂點,從第一個潛在分裂點開始,分裂D並計算兩個集合的期望信息,具有最小期望信息的點稱為這個屬性的最佳分裂點,其信息期望作為此屬性的信息期望。
C4.5演算法
ID3演算法存在一個問題,就是偏向於多值屬性,例如,如果存在唯一標識屬性ID,則ID3會選擇它作為分裂屬性,這樣雖然使得劃分充分純凈,但這種劃分對分類幾乎毫無用處。ID3的後繼演算法C4.5使用增益率的信息增益擴充,試圖克服這個偏倚。
C4.5演算法首先定義了「分裂信息」,其定義可以表示成:
其中各符號意義與ID3演算法相同,然後,增益率被定義為:
C4.5選擇具有最大增益率的屬性作為分裂屬性,其具體應用與ID3類似,不再贅述。
如果屬性用完了怎麼辦
在決策樹構造過程中可能會出現這種情況:所有屬性都作為分裂屬性用光了,但有的子集還不是純凈集,即集合內的元素不屬於同一類別。在這種情況下,由於沒有更多信息可以使用了,一般對這些子集進行「多數表決」,即使用此子集中出現次數最多的類別作為此節點類別,然後將此節點作為葉子節點。
關於剪枝
在實際構造決策樹時,通常要進行剪枝,這時為了處理由於數據中的雜訊和離群點導致的過分擬合問題。剪枝有兩種:
先剪枝——在構造過程中,當某個節點滿足剪枝條件,則直接停止此分支的構造。
後剪枝——先構造完成完整的決策樹,再通過某些條件遍歷樹進行剪枝。
流程圖形式的決策樹
上圖流程圖就是一個假想的郵件分類系統決策樹,正方形代表判斷模塊,橢圓形代表終止模塊,表示已經得出結論,可以終止運行。判斷模塊引出的左右箭頭稱為分支,它可以到達另一個判斷模塊或者終止模塊。決策樹的主要優勢在於數據形式非常容易理解。
優點
計算複雜度不高,對中間值的缺失不敏感,可以處理不相干特徵數據,輸出結果易於理解。
缺點
可能會產生過度匹配問題。
適用數據類型
數值型跟標稱型
ID3演算法python2實現
from math import logimport operatordef loadDataSet(fileName): #general function to parse tab -delimited floats dataMat = [] #assume last column is target value fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split(,) fltLine = list(curLine) #map all elements to float() dataMat.append(fltLine) return dataMat#def createDataSet(): #dataSet = [[1, 1, yes], #[1, 1, yes], #[1, 0, no], #[0, 1, no], #[0, 1, no]] #labels = [no surfacing,flippers] #print dataSet #change to discrete values #return dataSet, labelsdef calcShannonEnt(dataSet):#計算給定數據集的香農熵 numEntries = len(dataSet)#計算數據集中實例的總數 #print numEntries labelCounts = {}#創建一個數據字典 for featVec in dataSet: #the the number of unique elements and their occurance currentLabel = featVec[-1]#鍵值是最後一列數值 #print currentLabel if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0#如果鍵值不存在,則擴展字典並將當前鍵值加入字典 labelCounts[currentLabel] += 1#當前鍵值加入字典 #print labelCounts shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries# 使用所有類標籤的發生概率計算類別出現的概率 #print prob shannonEnt -= prob * log(prob,2) #log base 2 用這個概率計算香農熵 #print shannonEnt return shannonEntdef splitDataSet(dataSet, axis, value):#劃分數據集,dataSet為待劃分的數據集,axis為劃分數據集的特徵,value為特徵的返回值 retDataSet = []#創建新的list對象 for featVec in dataSet: #print featVec if featVec[axis] == value:#將符合特徵的數據抽取出來 reducedFeatVec = featVec[:axis] #chop out axis used for splitting #print reducedFeatVec reducedFeatVec.extend(featVec[axis+1:])#extend()函數只接受一個列表作為參數,並將該參數的每個元素都添加到原有的列表中 #print reducedFeatVec retDataSet.append(reducedFeatVec)#append()向列表的尾部添加一個元素,任意,可以是tuple return retDataSetdef chooseBestFeatureToSplit(dataSet):#遍歷整個數據集,循環計算香農熵和splitDataSet()函數,找到最好的劃分方式 numFeatures = len(dataSet[0]) - 1 #the last column is used for the labels判斷當前數據集包含多少特徵屬性 baseEntropy = calcShannonEnt(dataSet)#計算整個數據集的原始香農熵,保存最初的無序度量值,用於與劃分完之後的數據集計算的熵值進行比較 bestInfoGain = 0.0; bestFeature = -1 for i in range(numFeatures): #iterate over all the features 遍曆數據集中的所有特徵 featList = [example[i] for example in dataSet]#create a list of all the examples of this feature使用列表推導創建新的列表 uniqueVals = set(featList) #get a set of unique values將數據集中所有可能存在的值寫入featlist中,並從列表中創建集合 newEntropy = 0.0 for value in uniqueVals:#遍歷當前特徵值中的所有唯一屬性值, subDataSet = splitDataSet(dataSet, i, value)#對每個特徵劃分一次數據集 prob = len(subDataSet)/float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) #計算數據集的新熵值,並對所有唯一特徵值得到的熵求和 infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy#信息增益 if (infoGain > bestInfoGain): #compare this to the best gain so far#比較所有特徵中的信息增益 bestInfoGain = infoGain #if better than current best, set to best bestFeature = i#返回最好特徵劃分的索引值 return bestFeature #returns an integerdef majorityCnt(classList): classCount={}#創建鍵值為classList中唯一值的數據字典,字典對象存儲了classList中每個類標籤現的頻率 for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#利用operator操作鍵值排序字典 return sortedClassCount[0][0]#返回出現次數最多的分類名稱def createTree(dataSet,labels):#創建樹的函數代碼 classList = [example[-1] for example in dataSet]#創建了名為classList的列表變數,包含所有類標籤 if classList.count(classList[0]) == len(classList): # return classList[0]#stop splitting when all of the classes are equal if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet return majorityCnt(classList) bestFeat = chooseBestFeatureToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel:{}}#存儲了樹的所有信息 del(labels[bestFeat])#當前數據集選取的最好特徵存儲在變數bestFeat中, featValues = [example[bestFeat] for example in dataSet]#得到列表包含的所有屬性值 uniqueVals = set(featValues) for value in uniqueVals:#遍歷當前選擇特徵包含的所有屬性值 subLabels = labels[:] #copy all of labels, so trees dont mess up existing labels為了保證每次調用函數createtree()時不改變原始列表的內容,使用新變數subLabels代替原始列表 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)#遞歸調用函數cteatetree()得到的返回值插入到字典變數mytree中 return myTree def classify(inputTree,featLabels,testVec):#遞歸函數, firstStr = inputTree.keys()[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr)#使用index查找當前列表中第一個匹配firstStr變數的元素 key = testVec[featIndex] valueOfFeat = secondDict[key] if isinstance(valueOfFeat, dict): #如果到達葉子節點 classLabel = classify(valueOfFeat, featLabels, testVec)#返回當前節點的分類標籤 else: classLabel = valueOfFeat return classLabeldef storeTree(inputTree,filename):#決策樹分類器的存儲 import pickle#pickle序列化對象可以在磁碟上保存對象,並在需要時讀出來 fw = open(filename,w) pickle.dump(inputTree,fw) fw.close() def grabTree(filename): import pickle fr = open(filename) return pickle.load(fr)import matplotlib.pyplot as pltdecisionNode = dict(boxstylex="sawtooth", fc="0.8")leafNode = dict(boxstylex="round4", fc="0.8")arrow_args = dict(arrowstylex="<-")def getNumLeafs(myTree):#遍歷整顆樹,累計葉子節點的個數,並返回數值 numLeafs = 0 firstStr = myTree.keys()[0]#第一個關鍵字是以第一次劃分數據集的類別標籤 secondDict = myTree[firstStr]#表示子節點的數值 for key in secondDict.keys():#遍歷整顆樹的所有子節點 if type(secondDict[key]).__name__==dict:#test to see if the nodes are dictonaires, if not they are leaf nodes type()函數判斷子節點是否為字典類型,如果子節點是字典類型 numLeafs += getNumLeafs(secondDict[key])#則該節點是一個判斷節點,遞歸調用getNumLeafs()函數 else: numLeafs +=1 return numLeafsdef getTreeDepth(myTree):#遍歷過程中遇到判斷節點的個數 maxDepth = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__==dict:#test to see if the nodes are dictonaires, if not they are leaf nodes thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepthdef plotNode(nodeTxt, centerPt, parentPt, nodeType):#執行實際的繪圖功能 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=axes fraction, xytext=centerPt, textcoords=axes fraction, va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )#全局繪圖區def plotMidText(cntrPt, parentPt, txtString):#計運算元節點與父節點的中間位置,在父節點間填充文本信息 xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on計算寬與高 numLeafs = getNumLeafs(myTree) #this determines the x width of this tree計算樹的寬 depth = getTreeDepth(myTree)#計算樹的高 firstStr = myTree.keys()[0] #the text label for this node should be this cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt)#標記子節點屬性值 plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__==dict:#test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key],cntrPt,str(key)) #recursion else: #its a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#if you do get a dictonary you know its a tree, and the first element will be another dict def createPlot(inTree):#第一個版本的函數 fig = plt.figure(1, facecolor=white) fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree))#全局變數plotTree.totalW存儲樹的寬度 plotTree.totalD = float(getTreeDepth(inTree))#全局變數plotTree.totalD存儲樹的深度 plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;#使用plotTree.totalW,plotTree.totalD計算樹節點的擺放位置,這樣可以將樹繪製在水平方向和垂直方向的中心位置 plotTree(inTree, (0.5,1.0), ) plt.show()#def createPlot(): #fig = plt.figure(1, facecolor=white) #fig.clf() #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses #plotNode(a decision node, (0.5, 0.1), (0.1, 0.5), decisionNode) #plotNode(a leaf node, (0.8, 0.1), (0.3, 0.8), leafNode) #plt.show()def retrieveTree(i):#輸出預先存儲的樹信息,避免每次測試都要從數據集中創建樹的麻煩,該函數主要用於測試,返回預定義的樹結構 listOfTrees =[{no surfacing: {0: no, 1: {flippers: {0: no, 1: yes}}}}, {no surfacing: {0: no, 1: {flippers: {0: {head: {0: no, 1: yes}}, 1: no}}}} ] return listOfTrees[i]if __name__ == "__main__": myMat=loadDataSet(C:UsersHZFDesktopcredit.txt) labels=[A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15] myTree=createTree(myMat,labels) print myTree #numLeafs=getNumLeafs(myTree) maxDepth=getTreeDepth(myTree) print maxDepth createPlot(myTree) #classLabel=classify(myTree,labels,[1,1]) #labels=[age,prescript,astigmatic,tearRate] #storeTree(myTree,C:/Users/HZF/Desktop/machinelearninginaction/Ch03/classifierStorage.txt) #pickle.load(fr)=grabTree(C:/Users/HZF/Desktop/machinelearninginaction/Ch03/classifierStorage.txt) #print pickle.load(fr) #print numLeafs #print maxDepth #print classLabel #myMat,labels=createDataSet() #myMat[0][-1]=maybe #print myMat #shannonEnt=calcShannonEnt(myMat) #print shannonEnt #retDataSet=splitDataSet(myMat, 0, 0) #print retDataSet #bestFeature=chooseBestFeatureToSplit(myMat) #print bestFeature #myTree=createTree(myMat,labels) #classify(inputTree,featLabels,testVec) #print myTree
上篇就寫這麼多了,中下篇會儘快更新哦!
參考文獻
1、《機器學習實戰》(書)
2、演算法雜貨鋪——分類演算法之決策樹(博客)
3、其他網站資料(略)
推薦閱讀:
※024 Swap Nodes in Pairs[M]
※K-means計算城市聚類
※自動駕駛還在盯著「演算法」和「感測器」么?風河打算為其開發通用底層操作系統了
※替換空格
※【轉】機器學習新手必學十大演算法指南
TAG:演算法 |