如何畫XGBoost裡面的決策樹(decision tree)

最近用XGBoost很多, 訓練完模型後, 一般只是看看特徵重要性(feature importance score). 我對這種黑箱模型一般是不放心的, 所以喜歡把結果儘可能的畫出來看看. XGBoost是一種Boosting Tree方法, 模型中每個決策樹是可以畫出來看看的. 以為這是個很簡單問題, 後來發現其實坑還挺多的, 這裡簡單總結一下.

XGBoost有個plot_tree 函數, 訓練好模型後, 直接調用這個函數就可以了:

from xgboost import XGBClassifierfrom xgboost import plot_treeimport matplotlib.pyplot as pltmodel = XGBClassifier()model.fit(X, y)plot_tree(model)plt.show()

可以得到類似下面這個的圖, plot_tree有些參數可以調整, 比如num_trees=0表示畫第一棵樹, rankdir=LR表示圖片是從左到右(Left to Right). 圖片來自這裡.

下面問題就來了:

1. f1,f2是feature ID, 我的變數名跑哪裡去了? 怎麼加上去?

2. 怎麼調整圖片大小? 我常用的plt.figure(figsize=(10,10))怎麼不管用

3. 怎麼改變圖中字體大小? 字太小看著傷眼睛啊

4. 怎麼把圖存成pdf或者其它格式?

下面就一個個問題來解決.

如何改圖中的feature ID? 這估計是XGBoost裡面最大的一個坑了. XGBoost很多函數會用的一個參數fmap (也就是feature map),但是文檔裡面基本沒解釋這個fmap是怎麼產生的. 花了九牛二虎之力之後, 發現Kaggle上有好心人提供了解決方案.

def ceate_feature_map(features): outfile = open(xgb.fmap, w) i = 0 for feat in features: outfile.write({0} {1} q
.format(i, feat)) i = i + 1 outfile.close()ceate_feature_map(train_data.columns)

這個函數就是根據給定的特徵名字(我直接使用了數據的列名稱), 按照特定格式生成一個xgb.fmap文件, 這個文件就是XGBoost文檔裡面多次提到的fmap, 注意使用的時候, 直接提供文件名, 比如fmap=xgb.fmap.

有了fmap, 在調用plot_tree函數的時候, 直接指定fmap文件即可:

plot_tree(fmap=xgb.fmap)

這裡又有個坑. 雖然使用了fmap函數, 畫出來的圖仍然是feature ID. 我查看了一下本機上plot_tree函數, 發現並沒有fmap這個參數. 去XGBoost github上看了一下相關的函數, fmap這個函數是存在的. 用pip把XGBoost重裝一下, 問題仍然存在. 然後去github上下載了最新版本, 重新編譯安裝, 發現還是不行. 我猜測是本機已經有XGBoost而且版本號是最新的0.6(但有些函數其實在github上被更新了), 安裝的時候發現版本號一樣, 所以實際並沒有覆蓋老的版本. 所以嘗試卸載老版本重新安裝github版. 可是畫圖還是使用feature ID! 最後發現時import的問題, 需要使用reload重新import xgboost.

如何改變圖中字體大小? 我發現XGBoost裡面的tree如果超過3層, 基本字會很小, 很難看清楚. 所以想調大字體. 研究了很久XGBoost的源代碼, 發現XGBoost是使用了graphviz做圖, 可是XGBoost本身的wrapper只使用了graphviz裡面的一個參數graph_attr, 還有另外兩個參node_attr, edge_attr 都沒有用到, 直接後果就是屬於node_attr, edge_attr 的字體大小屬性不能更改.

我索性把XGBoost的源碼拷貝到我的程序, 然後做了相應的修改.

import re_NODEPAT = re.compile(r(d+):[(.+)])_LEAFPAT = re.compile(r(d+):(leaf=.+))_EDGEPAT = re.compile(ryes=(d+),no=(d+),missing=(d+))_EDGEPAT2 = re.compile(ryes=(d+),no=(d+))def _parse_node(graph, text): """parse dumped node""" match = _NODEPAT.match(text) if match is not None: node = match.group(1) graph.node(node, label=match.group(2), shape=plaintext) return node match = _LEAFPAT.match(text) if match is not None: node = match.group(1) graph.node(node, label=match.group(2).replace(leaf=,), shape=plaintext) return node raise ValueError(Unable to parse node: {0}.format(text))def _parse_edge(graph, node, text, yes_color=#0000FF, no_color=#FF0000): """parse dumped edge""" try: match = _EDGEPAT.match(text) if match is not None: yes, no, missing = match.groups() if yes == missing: graph.edge(node, yes, label=yes, missing, color=yes_color) graph.edge(node, no, label=no, color=no_color) else: graph.edge(node, yes, label=yes, color=yes_color) graph.edge(node, no, label=no, missing, color=no_color) return except ValueError: pass match = _EDGEPAT2.match(text) if match is not None: yes, no = match.groups() graph.edge(node, yes, label=yes, color=yes_color) graph.edge(node, no, label=no, color=no_color) return raise ValueError(Unable to parse edge: {0}.format(text))from graphviz import Digraphbooster = xgboost_model.get_booster()tree = booster.get_dump(fmap=xgb.fmap)[0]tree = tree.split()kwargs = { #label: A Fancy Graph, fontsize: 10, #fontcolor: white, #bgcolor: #333333, #rankdir: BT }kwargs = kwargs.copy()#kwargs.update({rankdir: rankdir})graph = Digraph(format=pdf, node_attr=kwargs,edge_attr=kwargs,engine=dot)#,edge_attr=kwargs,graph_attr=kwargs,#graph.attr(bgcolor=purple:pink, label=agraph, fontcolor=white)yes_color=#0000FFno_color=#FF0000for i, text in enumerate(tree): if text[0].isdigit(): node = _parse_node(graph, text) else: if i == 0: # 1st string must be node raise ValueError(Unable to parse given string as tree) _parse_edge(graph, node, text, yes_color=yes_color,no_color=no_color)graph.render(XGBoost_tree.pdf)graph

這裡有幾點要說明:

程序裡面有幾處shape=plaintext,在XGBoost源碼裡面是shape=circle或者shape=box, 我改成shape=plaintext是想是圖更緊湊一些,這樣看得更清楚. 不好的地方是, 圓圈大小代表了樣本多少, 這也是很重要的信息. (有朋友留言提到這裡圓圈大小並沒有樣本多少的信息, 我研究了一下XGBoost dump file, 自己試驗了一下, 發現圓圈大小的確和樣本大小無關, 只跟圓圈裡面變數名長度有關係, 這裡更正一下)

程序裡面label=match.group(2).replace(leaf=,)是因為XGBoost原圖的葉節點會有leaf=XXX, 我覺得很占空間, 所以也去掉了.

這裡 tree = booster.get_dump(fmap=xgb.fmap)[0], fmap就是前面生成的fmap文件, [0]表示第一棵樹, 如果你想要其他樹, 修改這個數字即可.

這裡graph = Digraph(format=pdf, node_attr=kwargs,edge_attr=kwargs,engine=dot)就是控制畫圖的主要參數, 格式是PDF,你可以改成PNG等. 這裡的graph_attr, node_attr, edge_attr 分別控制圖片不同部分的屬性, 這裡我修改了node和edge的字體大小. XGBoost源碼裡面只有graph_attr, 也就是說只能控制graph屬性.

基本就這樣了, 你可以直接複製我以上的程序使用. 看來有必要提交一個PR了.


推薦閱讀:

如何在Python上安裝xgboost?
記 XGBoost 的一個坑
機器學習競賽大殺器XGBoost--原理篇
LightGBM 中文文檔發布,持續 Update 中,歡迎各位大佬前來裝逼 | ApacheCN

TAG:xgboost | 机器学习 |