Interactive Decision Tree Tutorial in Python



Decision trees are widely used supervised models for classification and regression tasks. In this article, we will talk about decision tree classifiers and how to visualize them dynamically. These classifiers build a sequence of simple if/else rules on the training data, through which they predict the target value. Decision trees are simple to interpret thanks to their structure and our ability to visualize the modeled tree.

Using sklearn's export_graphviz function, we can display the tree within a Jupyter notebook. For this demonstration, we will use the sklearn wine data set.

from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn import tree
from sklearn.datasets import load_wine
from IPython.display import SVG
from graphviz import Source
from IPython.display import display
# load dataset
data = load_wine()

# feature matrix
X = data.data

# target vector
y = data.target

# feature names
labels = data.feature_names

# print dataset description
print(data.DESCR)

estimator = DecisionTreeClassifier()
estimator.fit(X, y)

graph = Source(tree.export_graphviz(estimator, out_file=None,
                                    feature_names=labels,
                                    class_names=['0', '1', '2'],
                                    filled=True))

display(SVG(graph.pipe(format='svg')))


Decision Tree with default parameters

In the tree plot, each node contains the condition (if/else rule) that splits the data, along with a series of other metrics for the node. Gini refers to the Gini impurity, a measure of the node's impurity, i.e. how mixed the classes of the samples within the node are. We say that a node is pure when all of its samples belong to the same class; in that case there is no need for a further split, and the node is called a leaf. Samples is the number of instances in the node, while the value array shows the distribution of these instances per class. At the bottom we see the majority class of the node. When the filled option of export_graphviz is set to True, each node gets colored according to its majority class.
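
These per-node metrics can also be read programmatically from the fitted estimator's tree_ attribute. Below is a minimal sketch that inspects the root node of the estimator fitted above; note that the exact contents of value (counts vs. fractions) can differ between sklearn versions.

# inspect the root node (node id 0) of the tree fitted above
t = estimator.tree_

print("gini:", t.impurity[0])             # impurity of the node
print("samples:", t.n_node_samples[0])    # number of instances in the node
print("value:", t.value[0])               # distribution of those instances per class
print("leaf:", t.children_left[0] == -1)  # leaf nodes have no children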

While easy to understand, decision trees tend to over-fit the data by constructing complex models. Over-fitted models will most likely not generalize well to “unseen” data. The two main approaches to prevent over-fitting are pre-pruning and post-pruning. Pre-pruning means restricting the growth of the tree while it is being built, while post-pruning means removing non-informative nodes after the tree has been built.

The sklearn decision tree classifier implements only pre-pruning. Pre-pruning can be controlled through several parameters, such as the maximum depth of the tree, the minimum number of samples required for a node to keep splitting, and the minimum number of instances required at a leaf. Below, we plot a decision tree on the same data, this time setting max_depth = 3.
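
The code for this is the same as in the first example, except that max_depth is passed to the constructor; a minimal sketch reusing X, y and labels from above (random_state is added here only to make the result reproducible):

# pre-pruned tree: stop splitting after three levels
estimator = DecisionTreeClassifier(max_depth=3, random_state=0)
estimator.fit(X, y)

graph = Source(tree.export_graphviz(estimator, out_file=None,
                                    feature_names=labels,
                                    class_names=['0', '1', '2'],
                                    filled=True))
display(SVG(graph.pipe(format='svg')))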


Decision Tree with max_depth = 3

This model is less deep and thus less complex than the one we trained and plotted initially.

Besides the pre-pruning parameters, a decision tree has a series of other parameters that we try to optimize whenever we build a classification model. We usually assess the effect of these parameters by looking at accuracy metrics. To get a grasp of how changes in parameters affect the structure of the tree, we could again visualize a tree at each stage. Instead of plotting a tree each time we make a change, we can make use of Jupyter Widgets (ipywidgets) to build an interactive plot of our tree.

Jupyter widgets are interactive elements that allow us to render controls inside the notebook. There are two options to install ipywidgets: through pip or conda.

With pip

pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension

With conda

conda install -c conda-forge ipywidgets

For this application, we will use the interactive function. First, we define a function that trains and plots a decision tree. Then, we pass this function along with a set of values for each of the parameters of interest to the interactive function. The latter returns a Widget instance that we show with display.

from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn import tree
from sklearn.datasets import load_wine
from IPython.display import SVG
from graphviz import Source
from IPython.display import display
from ipywidgets import interactive

# load dataset
data = load_wine()

# feature matrix
X = data.data

# target vector
y = data.target

# feature names
labels = data.feature_names

def plot_tree(crit, split, depth, min_split, min_leaf=0.2):
    # train a tree with the chosen hyper-parameters
    estimator = DecisionTreeClassifier(random_state=0,
                                       criterion=crit,
                                       splitter=split,
                                       max_depth=depth,
                                       min_samples_split=min_split,
                                       min_samples_leaf=min_leaf)
    estimator.fit(X, y)

    # render the fitted tree as an SVG inside the notebook
    graph = Source(tree.export_graphviz(estimator,
                                        out_file=None,
                                        feature_names=labels,
                                        class_names=['0', '1', '2'],
                                        filled=True))
    display(SVG(graph.pipe(format='svg')))

    return estimator

inter = interactive(plot_tree,
                    crit=["gini", "entropy"],
                    split=["best", "random"],
                    depth=[1, 2, 3, 4],
                    min_split=(0.1, 1),
                    min_leaf=(0.1, 0.5))
display(inter)


Initial view of widget

In this example, we expose the following parameters:

  • criterion: the function that measures the quality of a split at a node
  • splitter: the strategy used to split at each node
  • max_depth: the maximum depth of the tree
  • min_samples_split: the minimum number of instances a node must contain to be considered for splitting
  • min_samples_leaf: the minimum number of instances required at a leaf node

The last two parameters can be set either as integers or as floats. Floats are interpreted as fractions of the total number of instances. For more details on the parameters, you can read the sklearn class documentation.
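
As an illustration, on the 178-sample wine data set a float min_samples_leaf of 0.05 resolves to ceil(0.05 * 178) = 9 samples, so the following two trees should come out the same (a sketch reusing X and y from above):

# min_samples_leaf as an absolute count vs. as a fraction of the training set
clf_count = DecisionTreeClassifier(min_samples_leaf=9, random_state=0).fit(X, y)
clf_frac = DecisionTreeClassifier(min_samples_leaf=0.05, random_state=0).fit(X, y)

# both constraints resolve to 9 samples per leaf, so the two trees match
print(clf_count.tree_.node_count, clf_frac.tree_.node_count)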

Widget demonstration

This interactive widget allows us to modify the tree parameters and see the plot change dynamically. Through this interaction we are able to get a grasp of the effect of each parameter, by revealing the resulting change at each step.

Although this is not a tool for model performance assessment or parameter tuning, it has several benefits. It can serve as a means of assessing the complexity of our model through the inspection of its depth, number of nodes and purity of leaves. It can also give us useful insights into the data, as we see how many and which features the tree has used. In addition, we might be able to discover conditions that clearly distinguish our samples into the different classes.
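
For example, one quick way to list the features a fitted tree actually used is to look at its feature_importances_; a minimal sketch, assuming estimator is a fitted tree such as the one returned by plot_tree:

# features the tree split on have non-zero importance
used = [name for name, imp in zip(labels, estimator.feature_importances_) if imp > 0]
print(len(used), "features used:", used)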

In conclusion, I find this interactive visualization a fun tool for getting a deeper understanding of the abstract process of building a decision tree, detached from a particular data set, and one that will give us a head start the next time we build a decision tree for one of our projects!