TP2 - exercice 1 : le BA-BA des arbres de décisions avec scikit-learn

Dans cet exercice nous allons apprendre à manipuler la classe DecisionTreeClassifier du module tree qui permet de réaliser de la classification par la méthode des arbres de décision.

Nous travaillerons pour cela sur le jeu de données Iris que l'on peut charger à partir de scikit-learn et dont on trouve un descriptif sur Wikipedia.

Nous verrons comment construire un classifieur (l'appliquer pour obtenir des prédictions) et visualiser l'arbre de décision correspondant.

Question 1. Charger le jeu de données en utilisant la fonction load_iris du module datasets. En extraire le nombre d'observations, de descripteurs ainsi que le nombre de classes.

In [1]:
# generic imports #
#-----------------#
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
In [2]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target
class_ids = iris.target_names

print('the Iris dataset is made of %d instances and %d features. There are %d classes : %s' 
      % (X.shape[0], X.shape[1], len(class_ids), " ; ".join(class_ids)))
the Iris dataset is made of 150 instances and 4 features. There are 3 classes : setosa ; versicolor ; virginica

Question 2. Découper le jeu de données en ensemble d'apprentissage (80%) et ensemble de test (20%), de manière stratifiée.

  • on utilisera la fonction train_test_split du module model_selection
  • le découpage est stratifié si on retrouve les mêmes proportions des différentes catégories dans les jeux d'apprentisage et de test : se référer à la documentation de la fonction train_test_split pour voir comment faire.
In [3]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify = y, test_size = 0.2)

Question 3. Construire un arbre de décision à partir des données d'apprentissage en utilisant la classe DecisionTreeClassifier. On conservera tous les paramètres par défaut à l'exception de la profondeur maximale de l'arbre, qu'on fixera à 3.

  • Rappel: pour apprendre/construire un modèle dans scikit-learn il faut :
    1. l'instancier via le constructeur de la classe correspondante (en lui spécifiant la valeur des hyperparamètres, si besoin)
    2. appeler la méthode fit avec les données d'apprentissage $(X,y)$ en argument
In [4]:
tree_clf = DecisionTreeClassifier(max_depth=3)
tree_clf.fit(X_train,y_train)
Out[4]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')

Question 4. Comparer les performances obtenues sur le jeu de test et sur le jeu d'apprentissage.

  • Rappel: pour réaliser la prédiction, il suffit d'appeler la méthode predict du classifier, avec les données de test comme argument.
  • On pourra simplement représenter les matrices de confusion correspondantes.
In [5]:
preds_test = tree_clf.predict(X_test)
preds_train = tree_clf.predict(X_train)

from sklearn.metrics import confusion_matrix

print("*** training data : confusion matrix ***")
print(confusion_matrix(y_train, preds_train))
print("*** test data: confusion matrix ***")
print(confusion_matrix(y_test,preds_test))
*** training data : confusion matrix ***
[[40  0  0]
 [ 0 38  2]
 [ 0  2 38]]
*** test data: confusion matrix ***
[[10  0  0]
 [ 0 10  0]
 [ 0  1  9]]

Question 5. Enfin, utiliser le code ci-dessous pour visualiser l'arbre obtenu.

  • NB : la sortie graphique proposée par scikit-learn s'appuie sur les outils GraphViz qui ne sont pas toujours bien compatibles avec Windows. Ce code devrait néanmoins fonctionner sous Linux ou Mac.
  • Dans ce code, l'objet tree_clf est l'objet abre de décision construit précédemment.
In [6]:
# re-fit on entire dataset 
tree_clf.fit(X,y)
# show tree
import graphviz 
dot_data = tree.export_graphviz(tree_clf, out_file=None) 
graph = graphviz.Source(dot_data) 
graph.render("iris")
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-6-fd113526a6a7> in <module>()
      2 tree_clf.fit(X,y)
      3 # show tree
----> 4 import graphviz
      5 dot_data = tree.export_graphviz(tree_clf, out_file=None)
      6 graph = graphviz.Source(dot_data)

ModuleNotFoundError: No module named 'graphviz'
In [ ]: