Module 4 — Classical Machine Learning intermediate 28 min

Decision Trees

What is a Decision Tree?

A decision tree is exactly what it sounds like — a tree of yes/no questions that lead to a prediction.

                     Is Feature A > 5?
                    /                  \
                Yes                    No
               /                         \
    Is Feature B > 3?               Predict: Class 0
       /           \
    Yes              No
     /                 \
Predict: Class 1    Predict: Class 2  

Every internal node = a question about a feature
Every leaf node = a prediction
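
The example tree above can be read as a set of nested conditionals. Here is a minimal sketch of that idea; the names `predict_example`, `feature_a`, and `feature_b` are hypothetical and chosen only for illustration:

```python
def predict_example(feature_a: float, feature_b: float) -> int:
    """Walk the example tree from the root to a leaf and return the class."""
    if feature_a > 5:        # root node: Is Feature A > 5?
        if feature_b > 3:    # internal node: Is Feature B > 3?
            return 1         # leaf: Predict Class 1
        return 2             # leaf: Predict Class 2
    return 0                 # leaf: Predict Class 0


print(predict_example(6, 4))  # 1  (Yes, then Yes)
print(predict_example(6, 2))  # 2  (Yes, then No)
print(predict_example(3, 9))  # 0  (No)
```

A trained tree works the same way, except that the features and thresholds are chosen from data rather than written by hand.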


How Does It Learn?

The tree learns by finding the best splits that separate classes as cleanly as possible.

It uses Gini Impurity (or Entropy) to measure how mixed a group is:

$$Gini = 1 - \sum_{i=1}^{k} p_i^2$$

where p_i is the fraction of samples in the group that belong to class i, and k is the number of classes.

  • Gini = 0 → perfectly pure (all one class) ✅
  • Gini = 0.5 → maximally impure for two classes (50/50 mix) ❌; with k equally mixed classes the maximum is 1 - 1/k
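
To make the formula concrete, here is a small sketch that computes Gini impurity for a list of class labels. The function name `gini` is an illustrative choice, not part of scikit-learn:

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a group: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((count / n) ** 2 for count in counts.values())


print(gini(["a", "a", "a", "a"]))       # 0.0  (pure group)
print(gini(["a", "a", "b", "b"]))       # 0.5  (50/50 mix of two classes)
print(round(gini(["a", "b", "c"]), 3))  # 0.667 (three balanced classes)
```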

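At each node, the algorithm evaluates candidate splits (a feature plus a threshold) and greedily picks the one that most reduces the weighted impurity of the two resulting child groups. It then repeats the process on each child until a stopping rule, such as a maximum depth, is reached.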
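
The split search can be sketched for a single numeric feature as follows. This is a simplified illustration, not scikit-learn's actual implementation: the helper names `gini` and `best_split` are hypothetical, and trying midpoints between consecutive distinct values as thresholds is an assumption made for this example. A full tree would repeat the search over every feature and then recurse into each child.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a group of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(values, labels):
    """Return (threshold, weighted_gini) for the best `value <= threshold` split.

    Returns (None, parent_gini) if no split improves on the unsplit group.
    """
    n = len(values)
    best_threshold, best_impurity = None, gini(labels)
    unique = sorted(set(values))
    # Candidate thresholds: midpoints between consecutive distinct values.
    for low, high in zip(unique, unique[1:]):
        threshold = (low + high) / 2
        left = [y for x, y in zip(values, labels) if x <= threshold]
        right = [y for x, y in zip(values, labels) if x > threshold]
        # Weight each child's impurity by its share of the samples.
        impurity = (len(left) * gini(left) + len(right) * gini(right)) / n
        if impurity < best_impurity:
            best_threshold, best_impurity = threshold, impurity
    return best_threshold, best_impurity


values = [1.0, 2.0, 3.0, 6.0, 7.0, 8.0]
labels = [0, 0, 0, 1, 1, 1]
print(best_split(values, labels))  # (4.5, 0.0)
```

In this toy data, the threshold 4.5 separates the two classes perfectly, so both children have impurity 0.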


Decision Tree in scikit-learn

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# ── Data ──────────────────────────────────────────────────────
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# ── Train ─────────────────────────────────────────────────────
tree = DecisionTreeClassifier(
    max_depth=3,         # limit depth to prevent overfitting
    min_samples_split=5, # need at least 5 samples to split
    random_state=42,
)
tree.fit(X_train, y_train)

# ── Evaluate ──────────────────────────────────────────────────
y_pred = tree.predict(X_test)
print(f"Test accuracy: {accuracy_score(y_test, y_pred):.2%}")

# ── Visualize the tree ────────────────────────────────────────
plt.figure(figsize=(16, 8))
plot_tree(
    tree,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    fontsize=10,
)
plt.title("Decision Tree (max_depth=3)")
plt.tight_layout()
plt.show()

Feature Importance

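Decision trees tell you which features mattered most. In scikit-learn, feature_importances_ is impurity-based: each value is that feature's share of the total impurity reduction across all splits, normalized so the values sum to 1: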

import pandas as pd

importance = pd.Series(
    tree.feature_importances_,
    index=iris.feature_names
).sort_values(ascending=False)

print("Feature Importances:")
print(importance)
# Example output (exact values may vary with the split and scikit-learn version):
# petal length (cm)    0.56
# petal width (cm)     0.43
# sepal length (cm)    0.01
# sepal width (cm)     0.00

importance.plot(kind="barh", color="steelblue")
plt.title("Feature Importance")
plt.xlabel("Importance")
plt.show()

The Overfitting Problem

A key challenge with decision trees: they overfit easily.

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=500, n_features=10,
                           n_informative=5, random_state=42)
# Fix random_state so the split (and the resulting curves) are reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

depths = range(1, 20)
train_scores, test_scores = [], []

for d in depths:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    dt.fit(X_train, y_train)
    train_scores.append(dt.score(X_train, y_train))
    test_scores.append(dt.score(X_test, y_test))

import matplotlib.pyplot as plt
plt.figure(figsize=(10, 5))
plt.plot(depths, train_scores, label="Train Accuracy", marker="o")
plt.plot(depths, test_scores,  label="Test Accuracy",  marker="s")
plt.xlabel("Max Depth")
plt.ylabel("Accuracy")
plt.title("Overfitting in Decision Trees")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

At high depths:

  • Train accuracy → 100% (memorizes every sample)
  • Test accuracy → stalls or drops (the tree doesn’t generalize)

This is called overfitting — the model learns the training data too well, including the noise.


Controlling Overfitting

# Key hyperparameters to tune
tree = DecisionTreeClassifier(
    max_depth=5,          # max depth of the tree
    min_samples_split=10, # min samples needed to create a split
    min_samples_leaf=5,   # min samples required in each leaf
    max_features=0.8,     # use 80% of features at each split
    max_leaf_nodes=20,    # maximum number of leaf nodes
    random_state=42,
)

Hyperparameter       Effect            To reduce overfitting
max_depth            Tree size         Decrease
min_samples_split    Split threshold   Increase
min_samples_leaf     Leaf size         Increase
max_features         Feature subset    Decrease
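
One common way to choose these values is a cross-validated grid search. The sketch below assumes scikit-learn's `GridSearchCV`. The synthetic dataset mirrors the overfitting example, and the grid values are arbitrary choices for illustration, not recommendations:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=10,
                           n_informative=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Illustrative grid: every combination is scored with 5-fold cross-validation
param_grid = {
    "max_depth": [2, 3, 5, 8],
    "min_samples_leaf": [1, 5, 10],
}
search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring="accuracy",
)
search.fit(X_train, y_train)  # refits the best model on all training data

print("Best params:", search.best_params_)
print(f"CV accuracy: {search.best_score_:.2%}")
test_accuracy = search.score(X_test, y_test)
print(f"Test accuracy: {test_accuracy:.2%}")
```

Because the search only ever sees the training data, the test set still gives an honest estimate of how the tuned tree generalizes.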

Decision Trees for Regression

from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

housing = fetch_california_housing()
X_train, X_test, y_train, y_test = train_test_split(
    housing.data, housing.target, test_size=0.2, random_state=42
)

regressor = DecisionTreeRegressor(max_depth=6, random_state=42)
regressor.fit(X_train, y_train)
y_pred = regressor.predict(X_test)

print(f"R²: {r2_score(y_test, y_pred):.4f}")

Knowledge Check

Your decision tree achieves 99% accuracy on training data but only 65% on test data. What is this called?