Module 3 — Data Exploration & Visualization (Beginner · 28 min)
Matplotlib & Seaborn
Why Visualize?
“A picture is worth a thousand numbers.”
Before building any ML model, plot your data. You’ll discover:
- Distributions (is it normal? skewed?)
- Outliers
- Correlations between features
- Class imbalances in classification problems
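A quick numeric summary pairs well with the plots that follow. A minimal sketch with synthetic data (the `score`/`label` columns are made up for illustration):

```python
import numpy as np
import pandas as pd

# Synthetic dataset: a numeric feature and an imbalanced binary label.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "score": rng.normal(50, 15, 1000),
    "label": rng.choice([0, 1], size=1000, p=[0.9, 0.1]),
})

print(df["score"].describe())      # mean, std, quartiles — hints at skew/outliers
print(df["label"].value_counts())  # class balance at a glance
```

`describe()` and `value_counts()` answer the same questions numerically that histograms and count plots answer visually.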
Matplotlib — The Foundation
import matplotlib.pyplot as plt
import numpy as np
# Simple line plot
x = np.linspace(0, 2 * np.pi, 100)
y = np.sin(x)
plt.figure(figsize=(10, 4))
plt.plot(x, y, color="cornflowerblue", linewidth=2, label="sin(x)")
plt.title("Sine Wave", fontsize=14)
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Training Loss / Accuracy Curves
This is the first plot you make in every deep learning project:
# Simulated training history
epochs = range(1, 31)
train_loss = [1.0 * (0.9 ** i) for i in range(30)]
val_loss = [1.0 * (0.88 ** i) + np.random.normal(0, 0.02) for i in range(30)]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# Loss
ax1.plot(epochs, train_loss, label="Train Loss", color="royalblue")
ax1.plot(epochs, val_loss, label="Val Loss", color="tomato", linestyle="--")
ax1.set_title("Training & Validation Loss")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.legend()
ax1.grid(True, alpha=0.3)
# Accuracy (simulated)
train_acc = [1 - l * 0.5 for l in train_loss]
val_acc = [1 - l * 0.48 for l in val_loss]
ax2.plot(epochs, train_acc, label="Train Accuracy", color="royalblue")
ax2.plot(epochs, val_acc, label="Val Accuracy", color="tomato", linestyle="--")
ax2.set_title("Accuracy")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.set_ylim([0, 1])
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("training_curves.png", dpi=150, bbox_inches="tight")
plt.show()
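Beyond eyeballing the curves, you can read off the best epoch numerically. A sketch using the same simulated history as above (a seeded generator stands in for a real training log):

```python
import numpy as np

# Simulated validation loss, as in the plot above (seeded for reproducibility).
rng = np.random.default_rng(0)
val_loss = [1.0 * (0.88 ** i) + rng.normal(0, 0.02) for i in range(30)]

# The epoch with the lowest validation loss is a natural early-stopping point;
# a widening train/val gap after it usually signals overfitting.
best_epoch = int(np.argmin(val_loss)) + 1  # epochs are 1-indexed
print(f"best epoch: {best_epoch}, val loss: {min(val_loss):.3f}")
```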
Histogram — Distribution of Values
data = np.random.normal(loc=50, scale=15, size=1000)
plt.figure(figsize=(8, 5))
plt.hist(data, bins=30, color="mediumpurple", edgecolor="white", alpha=0.8)
plt.axvline(np.mean(data), color="red", linestyle="--", label=f"mean={np.mean(data):.1f}")
plt.axvline(np.median(data), color="orange", linestyle="--", label=f"median={np.median(data):.1f}")
plt.title("Distribution of Scores")
plt.xlabel("Score")
plt.ylabel("Frequency")
plt.legend()
plt.show()
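The mean/median lines on the histogram hint at skew; you can quantify it too. A sketch using Pearson's second skewness coefficient (one simple option among several):

```python
import numpy as np

# Seeded version of the same synthetic scores.
rng = np.random.default_rng(0)
data = rng.normal(loc=50, scale=15, size=1000)

# For symmetric data the mean and median roughly coincide;
# a large gap relative to the spread indicates skew.
mean, median = np.mean(data), np.median(data)
skew = 3 * (mean - median) / np.std(data)  # Pearson's second skewness coefficient
print(f"mean={mean:.1f}, median={median:.1f}, skew≈{skew:.2f}")
```

Values near 0 mean roughly symmetric; strongly positive (or negative) values mean a long right (or left) tail.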
Scatter Plot — Relationships Between Variables
# Generate correlated data
np.random.seed(42)
area = np.random.normal(1500, 400, 200)
price = area * 200 + np.random.normal(0, 50000, 200)
plt.figure(figsize=(8, 6))
plt.scatter(area, price, alpha=0.5, color="teal", edgecolors="white", linewidths=0.5)
plt.title("House Area vs Price")
plt.xlabel("Area (sqft)")
plt.ylabel("Price ($)")
plt.tight_layout()
plt.show()
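If the scatter looks roughly linear, a least-squares line summarizes the relationship. A sketch reusing the same synthetic data, with `np.polyfit` for the fit:

```python
import numpy as np

# Same correlated data as the scatter plot above.
np.random.seed(42)
area = np.random.normal(1500, 400, 200)
price = area * 200 + np.random.normal(0, 50000, 200)

# Degree-1 least-squares fit; the slope estimates price per square foot.
slope, intercept = np.polyfit(area, price, deg=1)
print(f"price ≈ {slope:.0f} * area + {intercept:.0f}")
```

You can overlay the line on the scatter with `plt.plot(area, slope * area + intercept)`.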
Seaborn — Beautiful Statistical Plots
Seaborn wraps matplotlib with a higher-level API and better defaults.
import seaborn as sns
import pandas as pd
# Use the built-in Titanic dataset
df = sns.load_dataset("titanic")
Distribution Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Distribution of ages
sns.histplot(df["age"].dropna(), kde=True, bins=30, ax=axes[0], color="steelblue")
axes[0].set_title("Age Distribution")
# Box plot — compare distributions by group
sns.boxplot(data=df, x="class", y="age", hue="class", ax=axes[1], palette="viridis", legend=False)
axes[1].set_title("Age by Passenger Class")
plt.tight_layout()
plt.show()
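The box plot's whiskers follow the 1.5 × IQR rule, and you can apply the same rule numerically to flag outliers. A sketch on synthetic ages with a few extreme values injected (not the Titanic data, to keep it self-contained):

```python
import numpy as np
import pandas as pd

# Synthetic ages with three extreme values appended.
rng = np.random.default_rng(1)
age = pd.Series(np.concatenate([rng.normal(30, 10, 500), [95, 102, 110]]))

# Same 1.5 * IQR rule the box-plot whiskers use.
q1, q3 = age.quantile([0.25, 0.75])
iqr = q3 - q1
low, high = q1 - 1.5 * iqr, q3 + 1.5 * iqr
outliers = age[(age < low) | (age > high)]
print(f"{len(outliers)} outliers beyond [{low:.1f}, {high:.1f}]")
```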
Count Plot — Categorical Frequency
plt.figure(figsize=(8, 5))
sns.countplot(data=df, x="class", hue="survived", palette=["tomato", "mediumseagreen"])
plt.title("Survival Count by Class")
plt.xlabel("Passenger Class")
plt.ylabel("Count")
plt.legend(title="Survived", labels=["No", "Yes"])
plt.show()
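A count plot shows imbalance visually; `value_counts(normalize=True)` gives the exact proportions. A sketch with hypothetical labels (95% / 5%, mirroring the question at the end of this module):

```python
import pandas as pd

# Hypothetical imbalanced labels: 95% negative, 5% positive.
labels = pd.Series([0] * 950 + [1] * 50)

# Proportions behind the bars of a count plot.
proportions = labels.value_counts(normalize=True)
print(proportions)
```

A strong imbalance like this is a cue to consider stratified splits, class weights, or resampling before training.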
Correlation Heatmap ⭐ (Very Important in ML)
# Select numeric columns
numeric_df = df[["age", "fare", "sibsp", "parch"]].dropna()
corr_matrix = numeric_df.corr()
plt.figure(figsize=(8, 6))
sns.heatmap(
corr_matrix,
annot=True, # show numbers
fmt=".2f",
cmap="coolwarm",
center=0,
square=True,
linewidths=0.5,
)
plt.title("Feature Correlation Matrix")
plt.show()
Values close to +1 indicate a strong positive correlation, values close to -1 a strong negative one, and values near 0 little to no linear relationship. Keep in mind that correlation only captures linear relationships.
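You can see these values directly with `np.corrcoef` on synthetic data built to correlate (or not):

```python
import numpy as np

np.random.seed(0)
x = np.random.normal(size=500)
y = 0.8 * x + 0.6 * np.random.normal(size=500)  # constructed to correlate with x
z = np.random.normal(size=500)                  # independent of x

print(np.corrcoef(x, y)[0, 1])  # strongly positive, close to +0.8
print(np.corrcoef(x, z)[0, 1])  # near 0
```

This is the same computation `DataFrame.corr()` performs for every pair of columns in the heatmap above.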
Pair Plot — All Pairwise Relationships
iris = sns.load_dataset("iris")
g = sns.pairplot(
iris,
hue="species",
palette="Set2",
diag_kind="kde",
plot_kws={"alpha": 0.6},
)
g.figure.suptitle("Iris Dataset — Pairwise Feature Relationships", y=1.02)
plt.show()
Key Plots for ML Projects
| Plot | When to use |
|---|---|
| Line plot | Training/validation loss curves |
| Histogram | Feature distributions |
| Box plot | Spotting outliers, comparing groups |
| Scatter plot | Feature vs target, clusters |
| Heatmap | Feature correlations |
| Count plot | Class balance in classification |
| Pair plot | Explore all feature pairs at once |
Quick Style Settings
# Set a clean style globally
plt.style.use("dark_background") # dark theme
# plt.style.use("seaborn-v0_8") # seaborn light theme
# Seaborn context (scales fonts)
sns.set_context("notebook", font_scale=1.2)
# Figure defaults
plt.rcParams.update({
"figure.figsize": (10, 6),
"figure.dpi": 100,
"font.size": 12,
})
You're building a classification model and want to check if your classes are imbalanced (e.g., 95% 'no cancer', 5% 'cancer'). Which chart is best for this?