Python: A Simple Decision Tree and Random Forest Example


Decision trees are a popular tool in machine learning. They take the form of a tree of sequential questions, where each answer leads down a particular branch toward a prediction.

Tree models where the target variable takes discrete values are called classification trees; when the target variable takes continuous values, they are called regression trees.

The model works through “if this, then that” conditions to reach a specific outcome. Tree depth is an important concept: it represents how many questions are asked before we reach a result.
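To make the “if this, then that” idea concrete, here is a toy hand-written rule in the shape of a depth-2 tree. The split variables and thresholds below are invented for illustration, not learned from the data:

# Hypothetical depth-2 decision tree for the kyphosis data (illustrative thresholds only)
def toy_tree(age, start):
    if start >= 9:          # question 1: where on the spine did the surgery start?
        return 'absent'
    elif age < 100:         # question 2: how old is the patient (in months)?
        return 'absent'
    else:
        return 'present'

toy_tree(age=71, start=5)   # 'absent'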

Random forests are a collection of decision trees whose results are aggregated into one final result. They are a powerful tool because they limit over-fitting without substantially increasing the error due to bias, a common problem with individual decision trees.
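Conceptually, the aggregation step for classification is just a majority vote over the trees' individual predictions. A minimal sketch of that idea (not scikit-learn's internal implementation):

from collections import Counter

def forest_predict(trees, x):
    # Each fitted tree casts a vote; the most common class wins
    votes = [tree(x) for tree in trees]
    return Counter(votes).most_common(1)[0][0]

# Toy usage: three "trees" as simple callables
trees = [lambda x: 'absent', lambda x: 'present', lambda x: 'absent']
forest_predict(trees, x=None)   # 'absent'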

We will apply both decision tree and random forest models to the famous “kyphosis” dataset. The objective is to determine important risk factors for kyphosis following surgery.

The data was collected on 83 patients undergoing corrective spinal surgery:

  • Kyphosis: “absent” or “present”, indicating whether a kyphosis was present after the operation

  • Age: the age in months

  • Number: the number of vertebrae involved

  • Start: the number of the first vertebra operated on

We will use Python in this post; here is the R version. So let's dive in :).

Preparing the Data

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# Load the dataset and take a first look
k_data = pd.read_csv('kyphosis.csv')
k_data.head()
  Kyphosis  Age  Number  Start
0   absent   71       3      5
1   absent  158       3     14
2  present  128       4      5
3   absent    2       5      1
4   absent    1       4     15
k_data.describe()
              Age     Number      Start
count   81.000000  81.000000  81.000000
mean    83.654321   4.049383  11.493827
std     58.104251   1.619423   4.883962
min      1.000000   2.000000   1.000000
25%     26.000000   3.000000   9.000000
50%     87.000000   4.000000  13.000000
75%    130.000000   5.000000  16.000000
max    206.000000  10.000000  18.000000

We can see that the count is 81, two short of the 83 patients mentioned above: cases 15 and 28 were removed from the original data. We can also confirm that no missing values remain:

k_data.isnull().values.any()
False

Exploring the Data

To get a quick overview of the data, we can use “pairplot” from the “seaborn” library to plot pairwise relationships between the columns, colored by the Kyphosis outcome.

import seaborn as sns
sns.pairplot(k_data,hue='Kyphosis')
<seaborn.axisgrid.PairGrid at 0x7f43f4aaeb00>

[Figure: seaborn pairplot of Age, Number, and Start, colored by Kyphosis]

Training and Testing the Data

# Target: the Kyphosis column
Y = k_data['Kyphosis']
# Features: everything except the target
X = k_data.drop('Kyphosis', axis=1)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33)
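Note that no random_state is passed, so the split (and the scores below) will differ from run to run. Since the classes are quite imbalanced, one could also stratify the split; a minimal optional variant:

# Reproducible, class-balanced split (optional variant; seed 42 is an arbitrary choice)
X_train, X_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.33, random_state=42, stratify=Y)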

Decision Trees

# Training the model
from sklearn.tree import DecisionTreeClassifier
k_tree = DecisionTreeClassifier()
k_tree.fit(X_train,y_train)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')
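The defaults shown above include max_depth=None, which lets the tree grow until its leaves are pure and is exactly how single trees overfit. If we wanted to cap the tree depth discussed earlier, a sketch (a depth of 3 is an arbitrary choice for illustration):

# Optional: limit how many questions the tree may ask before it must answer
shallow_tree = DecisionTreeClassifier(max_depth=3)
shallow_tree.fit(X_train, y_train)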
# Testing and Evaluating the model
k_predictions = k_tree.predict(X_test)
from sklearn.metrics import classification_report,confusion_matrix
print(classification_report(y_test,k_predictions))
             precision    recall  f1-score   support

     absent       0.81      0.77      0.79        22
    present       0.17      0.20      0.18         5

avg / total       0.69      0.67      0.68        27
print(confusion_matrix(y_test,k_predictions))
[[17  5]
 [ 4  1]]

Rows are the true classes and columns the predicted ones (in the order absent, present): 17 “absent” cases were classified correctly, while only 1 of the 5 “present” cases was caught.
# Visualize the model
import pydot
from io import StringIO          # sklearn.externals.six was removed in recent scikit-learn versions
from IPython.display import Image
from sklearn.tree import export_graphviz

nms = list(k_data.columns[1:])   # feature names: Age, Number, Start
dot_data = StringIO()
export_graphviz(k_tree, out_file=dot_data, feature_names=nms, filled=True, rounded=True)

# pydot returns a list of graphs; render the first one as a PNG
graph = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph[0].create_png())

[Figure: the fitted decision tree rendered with graphviz]

Random Forests

from sklearn.ensemble import RandomForestClassifier
# Build a forest of 200 trees
k_forest = RandomForestClassifier(n_estimators=200)
k_forest.fit(X_train, y_train)
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=200, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)
kf_predictions = k_forest.predict(X_test)
print(confusion_matrix(y_test,kf_predictions))
[[18  4]
 [ 4  1]]
print(classification_report(y_test,kf_predictions))
             precision    recall  f1-score   support

     absent       0.82      0.82      0.82        22
    present       0.20      0.20      0.20         5

avg / total       0.70      0.70      0.70        27

The forest does marginally better than the single tree overall (0.70 vs 0.68 average f1-score), but with only 27 test cases, 5 of them “present”, the difference is not meaningful, and both models struggle on the minority class. Keep in mind that the exact numbers will change between runs, since neither the split nor the models were seeded.
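Since the stated objective was to determine important risk factors, one natural final step is to inspect the fitted forest's feature importances. A minimal sketch (feature_importances_ is a standard scikit-learn attribute; the exact values will vary by run):

# Which features mattered most in the fitted forest?
importances = pd.Series(k_forest.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False))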