隨機森林,gbdt,xgboost的決策樹子類Python講解
來自專欄 常用機器學習演算法實現與講解
1.前言
上一篇文章我們講了最簡單的決策樹,這篇文章主要介紹github上他人封裝好的決策樹代碼,為後面的隨機森林,gbdt,xgboost等演算法做準備,看懂了我上一篇文章的同學們,應該能輕鬆看懂這一篇的講解。附上一篇文章鏈接CART決策樹(Decision Tree)的Python源碼實現
建議閱讀順序:先閱讀源代碼,再來看源碼關鍵方法的講解,源碼地址:RRdmlearning/Machine-Learning-From-Scratch
不知為何知乎上的代碼格式沒有原文章便於理解,大家可在cs229論壇社區|深度學習社區|機器學習社區|人工智慧社區閱讀。
2.源碼講解
2.1 DecisionTee(決策樹)
首先我們來看一下DecisionTree中的變數
class DecisionTree(object): #Super class of RegressionTree and ClassificationTree. def __init__(self, min_samples_split=2, min_impurity=1e-7, max_depth=float("inf"), loss=None): self.root = None # Root node in dec. tree # Minimum n of samples to justify split self.min_samples_split = min_samples_split # The minimum impurity to justify split self.min_impurity = min_impurity # The maximum depth to grow the tree to self.max_depth = max_depth # Function to calculate impurity (classif.=>info gain, regr=>variance reduct.) # 切割樹的方法,gini,方差等 self._impurity_calculation = None # Function to determine prediction of y at leaf # 樹節點取值的方法,分類樹:選取出現最多次數的值,回歸樹:取所有值的平均值 self._leaf_value_calculation = None # If y is one-hot encoded (multi-dim) or not (one-dim) self.one_dim = None # If Gradient Boost self.loss = loss
其中兩個最重要的變數是impurity_calculation和leaf_value_calculation。
impurity_calculation代表你切割樹的標準是什麼,例如若是分類樹則切割的標準是基尼指數,回歸樹則是最小平方殘差
leaf_value_calculation代表你計算節點值的方法是什麼,例如若是分類樹則取切割數據集中數量最多的種類,回歸樹則計算切割數據集中所有的平均值
下面我們可以由DecisionTree這個基類來建議分類樹,回歸樹
2.2 ClassificationTree(分類樹)
分類樹定義如下
class ClassificationTree(DecisionTree): def _calculate_information_gain(self, y, y1, y2): # Calculate information gain p = len(y1) / len(y) entropy = calculate_entropy(y) info_gain = entropy - p * calculate_entropy(y1) - (1 - p) * calculate_entropy(y2) # print("info_gain",info_gain) return info_gain def _majority_vote(self, y): most_common = None max_count = 0 for label in np.unique(y): # Count number of occurences of samples with label count = len(y[y == label]) if count > max_count: most_common = label max_count = count # print("most_common :",most_common) return most_common def fit(self, X, y): self._impurity_calculation = self._calculate_information_gain self._leaf_value_calculation = self._majority_vote super(ClassificationTree, self).fit(X, y)
方法 def _calculate_information_gain():
是切割樹的標準,這裡使用的是交叉熵
方法 def _majority_vote():
則是計運算元節點值的方法,這裡使用的是選取數據集中出現最多的種類
方法 def fit():
將分類樹切割的標準與計運算元節點值的方式傳回給基類DecisionTree
2.3 RegressionTree(回歸樹)
回歸樹定義如下
class RegressionTree(DecisionTree): def _calculate_variance_reduction(self, y, y1, y2): var_tot = calculate_variance(y) var_1 = calculate_variance(y1) var_2 = calculate_variance(y2) frac_1 = len(y1) / len(y) frac_2 = len(y2) / len(y) # Calculate the variance reduction variance_reduction = var_tot - (frac_1 * var_1 + frac_2 * var_2) return sum(variance_reduction) def _mean_of_y(self, y): value = np.mean(y, axis=0) return value if len(value) > 1 else value[0] def fit(self, X, y): self._impurity_calculation = self._calculate_variance_reduction self._leaf_value_calculation = self._mean_of_y super(RegressionTree, self).fit(X, y)
def _calculate_variance_reduction():
是切割樹的標準,這裡使用的是平方殘差
def _mean_of_y():
是計運算元節點值的方法,這裡使用的是取數據集中的平均值
def fit():
方法將回歸樹切割的標準與計運算元節點值的方式傳回給基類DecisionTree
3.運行
直接運行decision_tree_classifier_example.py文件或decision_tree_regressor_example.py即可。
分別是分類樹與決策樹。
此文章為記錄自己一路的學習路程,也希望能給廣大初學者們一點點幫助,如有錯誤,疑惑歡迎一起交流。
推薦閱讀:
※1-4 Supervised Learning
※Building Deep Neural Network from scratch-吳恩達深度學習第一課第四周習題答案(1)
※邁克爾 · 喬丹:我討厭將機器學習稱為AI
※非負矩陣分解(NMF)及一個小實例
※機器學習各種熵:從入門到全面掌握
TAG:決策樹 | 機器學習 | 深度學習DeepLearning |