修剪决策树

user9238790:

大家好,下面是决策树的摘要,因为它非常庞大。

在此处输入图片说明

如何使树木停止生长时的最低中的一个节点是5岁以下这里是产生决策树的代码。SciKit-决策树上,我们可以看到这样做的唯一方法是通过min_impurity_decrease,但是我不确定它的具体工作方式。

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier


X, y = make_classification(n_samples=1000,
                           n_features=6,
                           n_informative=3,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)

# Creating a dataFrame
df = pd.DataFrame({'Feature 1':X[:,0],
                                  'Feature 2':X[:,1],
                                  'Feature 3':X[:,2],
                                  'Feature 4':X[:,3],
                                  'Feature 5':X[:,4],
                                  'Feature 6':X[:,5],
                                  'Class':y})


y_train = df['Class']
X_train = df.drop('Class',axis = 1)

dt = DecisionTreeClassifier( random_state=42)                
dt.fit(X_train, y_train)

from IPython.display import display, Image
import pydotplus
from sklearn import tree
from sklearn.tree import _tree
from sklearn import tree
import collections
import drawtree
import os  

os.environ["PATH"] += os.pathsep + 'C:\\Anaconda3\\Library\\bin\\graphviz'

dot_data = tree.export_graphviz(dt, out_file = 'thisIsTheImagetree.dot',
                                 feature_names=X_train.columns, filled   = True
                                    , rounded  = True
                                    , special_characters = True)

graph = pydotplus.graph_from_dot_file('thisIsTheImagetree.dot')  

thisIsTheImage = Image(graph.create_png())
display(thisIsTheImage)
#print(dt.tree_.feature)

from subprocess import check_call
check_call(['dot','-Tpng','thisIsTheImagetree.dot','-o','thisIsTheImagetree.png'])

更新资料

我认为min_impurity_decrease可以以某种方式帮助实现目标。调整min_impurity_decrease实际上会修剪树。谁能解释一下min_impurity_decrease。

我试图理解scikit学习中的等式,但是我不确定right_impurity和left_impurity的值是什么。

N = 256
N_t = 256
impurity = ??
N_t_R = 242
N_t_L = 14
right_impurity = ??
left_impurity = ??

New_Value = N_t / N * (impurity - ((N_t_R / N_t) * right_impurity)
                    - ((N_t_L / N_t) * left_impurity))
New_Value

更新2

在一定条件下修剪,而不是修剪成一定的值。例如我们确实以6/4和5/5分割,但不以6000/4或5000/5分割。假设某个值与其节点中的相邻值相比是否在某个百分比以下,而不是某个值。

      11/9
   /       \
  6/4       5/5
 /   \     /   \
6/0  0/4  2/2  3/3
大卫·戴尔(David Dale):

使用min_impurity_decrease或任何其他内置的停止标准无法直接限制叶子的最小值(特定类的出现次数)。

我认为,不更改scikit-learn的源代码即可完成此操作的唯一方法是对进行后修剪为此,您可以遍历树并删除最小类数小于5(或您想到的任何其他条件)的节点的所有子级。我将继续您的示例:

from sklearn.tree._tree import TREE_LEAF

def prune_index(inner_tree, index, threshold):
    if inner_tree.value[index].min() < threshold:
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
    # if there are shildren, visit them as well
    if inner_tree.children_left[index] != TREE_LEAF:
        prune_index(inner_tree, inner_tree.children_left[index], threshold)
        prune_index(inner_tree, inner_tree.children_right[index], threshold)

print(sum(dt.tree_.children_left < 0))
# start pruning from the root
prune_index(dt.tree_, 0, 5)
sum(dt.tree_.children_left < 0)

此代码将首先打印74,然后91这意味着该代码创建了17个新的叶子节点(实际上是删除了到其祖先的链接)。这棵树,以前看起来像

在此处输入图片说明

现在看起来像

在此处输入图片说明

因此您可以看到确实减少了很多。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章