Module 3 — Data Exploration & Visualization (beginner · 28 min)

Matplotlib & Seaborn

Why Visualize?

“A picture is worth a thousand numbers.”

Before building any ML model, plot your data. You’ll discover:

  • Distributions (is it normal? skewed?)
  • Outliers
  • Correlations between features
  • Class imbalances in classification problems
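Alongside the plots, a quick numeric sanity check surfaces several of these at once. A minimal sketch, assuming a stand-in DataFrame with one right-skewed feature and an imbalanced label (both generated here for illustration):

```python
import numpy as np
import pandas as pd

# Stand-in dataset: a right-skewed feature and an imbalanced label
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "income": rng.lognormal(mean=10, sigma=1, size=1000),   # skewed
    "label":  rng.choice([0, 1], size=1000, p=[0.95, 0.05]),  # imbalanced
})

print(df["income"].skew())                       # > 0 means right-skewed
print(df["label"].value_counts(normalize=True))  # class proportions
```

`skew()` and `value_counts(normalize=True)` flag the problems numerically; the plots below show you their shape.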

Matplotlib — The Foundation

import matplotlib.pyplot as plt
import numpy as np

# Simple line plot
x = np.linspace(0, 2 * np.pi, 100)
y = np.sin(x)

plt.figure(figsize=(10, 4))
plt.plot(x, y, color="cornflowerblue", linewidth=2, label="sin(x)")
plt.title("Sine Wave", fontsize=14)
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Training Loss / Accuracy Curves

This is usually the first plot you make in a deep learning project:

# Simulated training history (seeded so the noise is reproducible)
np.random.seed(0)
epochs = range(1, 31)
train_loss = [0.9 ** i for i in range(30)]
val_loss   = [0.88 ** i + np.random.normal(0, 0.02) for i in range(30)]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss
ax1.plot(epochs, train_loss, label="Train Loss", color="royalblue")
ax1.plot(epochs, val_loss,   label="Val Loss",   color="tomato", linestyle="--")
ax1.set_title("Training & Validation Loss")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy (simulated)
train_acc = [1 - l * 0.5 for l in train_loss]
val_acc   = [1 - l * 0.48 for l in val_loss]

ax2.plot(epochs, train_acc, label="Train Accuracy", color="royalblue")
ax2.plot(epochs, val_acc,   label="Val Accuracy",   color="tomato", linestyle="--")
ax2.set_title("Accuracy")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.set_ylim([0, 1])
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("training_curves.png", dpi=150, bbox_inches="tight")
plt.show()
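Once you have a validation-loss curve, a common follow-up is to find the epoch where it bottoms out, which is the usual early-stopping point. A minimal sketch, re-creating the simulated history above with a fixed seed:

```python
import numpy as np

np.random.seed(0)
train_loss = [0.9 ** i for i in range(30)]
val_loss   = [0.88 ** i + np.random.normal(0, 0.02) for i in range(30)]

# Epoch with the lowest validation loss (epochs are 1-indexed)
best_epoch = int(np.argmin(val_loss)) + 1
print(best_epoch)
```

In a real project you would read this off the model's recorded history rather than simulated lists.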

Histogram — Distribution of Values

data = np.random.normal(loc=50, scale=15, size=1000)

plt.figure(figsize=(8, 5))
plt.hist(data, bins=30, color="mediumpurple", edgecolor="white", alpha=0.8)
plt.axvline(np.mean(data), color="red", linestyle="--", label=f"mean={np.mean(data):.1f}")
plt.axvline(np.median(data), color="orange", linestyle="--", label=f"median={np.median(data):.1f}")
plt.title("Distribution of Scores")
plt.xlabel("Score")
plt.ylabel("Frequency")
plt.legend()
plt.show()
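The `bins=30` above is a guess. NumPy can pick a data-driven bin count instead; a minimal sketch using the Freedman–Diaconis rule via `np.histogram_bin_edges`:

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=50, scale=15, size=1000)

# Freedman–Diaconis rule: bin width = 2 * IQR * n^(-1/3)
edges = np.histogram_bin_edges(data, bins="fd")
n_bins = len(edges) - 1
print(n_bins)
```

`plt.hist(data, bins="fd")` accepts the same string and delegates to this rule.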

Scatter Plot — Relationships Between Variables

# Generate correlated data
np.random.seed(42)
area   = np.random.normal(1500, 400, 200)
price  = area * 200 + np.random.normal(0, 50000, 200)

plt.figure(figsize=(8, 6))
plt.scatter(area, price, alpha=0.5, color="teal", edgecolors="white", linewidths=0.5)
plt.title("House Area vs Price")
plt.xlabel("Area (sqft)")
plt.ylabel("Price ($)")
plt.tight_layout()
plt.show()
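To put a number on what the scatter shows, compute the Pearson correlation coefficient. A minimal sketch on the same simulated data:

```python
import numpy as np

np.random.seed(42)
area  = np.random.normal(1500, 400, 200)
price = area * 200 + np.random.normal(0, 50000, 200)

# Pearson r is the off-diagonal entry of the 2x2 correlation matrix
r = np.corrcoef(area, price)[0, 1]
print(f"r = {r:.2f}")  # strongly positive for this data
```

Values well above 0 confirm the upward trend visible in the plot.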

Seaborn — Beautiful Statistical Plots

Seaborn wraps matplotlib with a higher-level API and better defaults.

import seaborn as sns
import pandas as pd

# Use the built-in Titanic dataset
df = sns.load_dataset("titanic")

Distribution Plot

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Distribution of ages
sns.histplot(df["age"].dropna(), kde=True, bins=30, ax=axes[0], color="steelblue")
axes[0].set_title("Age Distribution")

# Box plot — compare distributions by group
# hue=x with legend=False avoids the palette-without-hue deprecation in seaborn >= 0.13
sns.boxplot(data=df, x="class", y="age", hue="class", palette="viridis", legend=False, ax=axes[1])
axes[1].set_title("Age by Passenger Class")

plt.tight_layout()
plt.show()
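The whiskers and flier points in a box plot follow Tukey's 1.5×IQR rule, and the same rule can flag outliers numerically. A minimal sketch on synthetic ages (a stand-in for a real column):

```python
import numpy as np

rng = np.random.default_rng(0)
ages = rng.normal(30, 10, 500)

# Tukey's rule: points beyond 1.5 * IQR from the quartiles are outliers
q1, q3 = np.percentile(ages, [25, 75])
iqr = q3 - q1
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
outliers = ages[(ages < lower) | (ages > upper)]
print(len(outliers))
```

These fences are exactly where the box plot draws its whiskers, so the count matches the fliers you see.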

Count Plot — Categorical Frequency

plt.figure(figsize=(8, 5))
sns.countplot(data=df, x="class", hue="survived", palette=["tomato", "mediumseagreen"])
plt.title("Survival Count by Class")
plt.xlabel("Passenger Class")
plt.ylabel("Count")
plt.legend(title="Survived", labels=["No", "Yes"])
plt.show()

Correlation Heatmap ⭐ (Very Important in ML)

# Select numeric columns
numeric_df = df[["age", "fare", "sibsp", "parch"]].dropna()
corr_matrix = numeric_df.corr()

plt.figure(figsize=(8, 6))
sns.heatmap(
    corr_matrix,
    annot=True,      # show numbers
    fmt=".2f",
    cmap="coolwarm",
    center=0,
    square=True,
    linewidths=0.5,
)
plt.title("Feature Correlation Matrix")
plt.show()

Values near 1 indicate a strong positive correlation, near -1 a strong negative one, and near 0 no linear relationship. Note the "linear": Pearson correlation can miss a nonlinear dependence entirely.
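A minimal numeric sketch of what Pearson's r does and doesn't capture; note the last case, where a perfect nonlinear dependence still yields r ≈ 0:

```python
import numpy as np

x = np.arange(1000, dtype=float)
print(np.corrcoef(x, 2 * x + 1)[0, 1])   # perfect linear: r is 1 (up to rounding)
print(np.corrcoef(x, -3 * x)[0, 1])      # perfect inverse: r is -1

xs = np.linspace(-1, 1, 1001)
print(np.corrcoef(xs, xs ** 2)[0, 1])    # y = x^2 is symmetric: r is ~0
```

This is why a heatmap full of near-zero values doesn't prove the features are unrelated, only that no linear relationship was found.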

Pair Plot — All Pairwise Relationships

iris = sns.load_dataset("iris")

g = sns.pairplot(
    iris,
    hue="species",
    palette="Set2",
    diag_kind="kde",
    plot_kws={"alpha": 0.6},
)
g.fig.suptitle("Iris Dataset — Pairwise Feature Relationships", y=1.02)
plt.show()

Key Plots for ML Projects

Plot           When to use
-------------  -----------------------------------
Line plot      Training/validation loss curves
Histogram      Feature distributions
Box plot       Spotting outliers, comparing groups
Scatter plot   Feature vs target, clusters
Heatmap        Feature correlations
Count plot     Class balance in classification
Pair plot      Explore all feature pairs at once

Quick Style Settings

# Set a clean style globally
plt.style.use("dark_background")   # dark theme
# plt.style.use("seaborn-v0_8")    # seaborn light theme

# Seaborn context (scales fonts)
sns.set_context("notebook", font_scale=1.2)

# Figure defaults
plt.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 100,
    "font.size": 12,
})
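For one-off tweaks, rc settings can also be applied temporarily with `plt.rc_context`, which restores the global defaults on exit. A minimal sketch (the `Agg` backend line is only needed when running headless, e.g. outside a notebook):

```python
import matplotlib
matplotlib.use("Agg")  # headless rendering; skip this line in notebooks
import matplotlib.pyplot as plt

# Settings apply only inside the with-block
with plt.rc_context({"figure.figsize": (4, 3), "font.size": 8}):
    fig, ax = plt.subplots()
    print(fig.get_size_inches())  # figure created at 4x3 inches

plt.close(fig)  # global rcParams are back to their previous values here
```

Use `rc_context` for a single figure and `rcParams.update` for project-wide defaults.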
Knowledge Check

You're building a classification model and want to check if your classes are imbalanced (e.g., 95% 'no cancer', 5% 'cancer'). Which chart is best for this?