Module 4 — Classical Machine Learning intermediate 22 min

Clustering with K-Means

What is Clustering?

Clustering groups similar data points together — without any labels. The algorithm finds the natural structure in your data.

Use clustering for:

  • Customer segmentation
  • Document topic discovery
  • Anomaly detection
  • Gene grouping in biology
  • Image compression

K-Means Algorithm

K-Means is one of the most widely used clustering algorithms. Here’s how it works:

1. Choose K (number of clusters)
2. Randomly place K centroids
3. Assign each point to the nearest centroid
4. Move each centroid to the mean of its assigned points
5. Repeat steps 3-4 until centroids stop moving
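
The five steps above can be sketched in plain NumPy. This is a teaching sketch under stated assumptions, not how scikit-learn implements K-Means: the function name `kmeans_sketch` and the demo data are hypothetical, and step 2 is interpreted here as picking K random data points as the starting centroids.

```python
import numpy as np

def kmeans_sketch(X, k, max_iter=100, seed=0):
    """Basic K-Means loop; a teaching sketch, not a replacement for sklearn."""
    rng = np.random.default_rng(seed)
    # Step 2 (assumption): use K distinct data points as the starting centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # Step 3: assign each point to its nearest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 4: move each centroid to the mean of its assigned points
        # (an empty cluster keeps its previous centroid)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 5: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Tiny demo: two well-separated groups of 2-D points
rng = np.random.default_rng(1)
X_demo = np.vstack([
    rng.normal(loc=0.0, scale=0.5, size=(50, 2)),
    rng.normal(loc=5.0, scale=0.5, size=(50, 2)),
])
demo_labels, demo_centroids = kmeans_sketch(X_demo, k=2)
print(demo_centroids.round(2))
```

The `max_iter` cap is a safety net for data where the centroids keep shifting by tiny amounts; scikit-learn exposes the same idea through its `max_iter` parameter.
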

# ── Imports ───────────────────────────────────────────────────
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

# ── Generate sample data ──────────────────────────────────────
X, y_true = make_blobs(
    n_samples=300,
    centers=4,
    cluster_std=0.7,
    random_state=42,
)

# ── Fit K-Means ───────────────────────────────────────────────
kmeans = KMeans(
    n_clusters=4,
    init="k-means++",  # spread the starting centroids apart (scikit-learn's default)
    n_init=10,         # run 10 initializations, keep the fit with the lowest inertia
    max_iter=300,
    random_state=42,
)
kmeans.fit(X)

labels = kmeans.labels_            # cluster assignment for each point
centers = kmeans.cluster_centers_  # coordinates of centroids
inertia = kmeans.inertia_          # sum of squared distances to centroids

print(f"Inertia: {inertia:.2f}")

# ── Visualize ─────────────────────────────────────────────────
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.scatter(X[:, 0], X[:, 1], c="gray", alpha=0.5, s=20)
plt.title("Before Clustering")

plt.subplot(1, 2, 2)
scatter = plt.scatter(X[:, 0], X[:, 1], c=labels, cmap="tab10", alpha=0.7, s=20)
plt.scatter(centers[:, 0], centers[:, 1], c="red", marker="X", s=200,
            zorder=5, label="Centroids")
plt.title("After K-Means (K=4)")
plt.legend()

plt.tight_layout()
plt.show()
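
To make the meaning of inertia concrete, the sketch below recomputes it by hand from the fitted labels and centroids and compares it with `kmeans.inertia_`. It regenerates the same sample data so that it runs on its own; `assigned_centers` and `manual_inertia` are names introduced here for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Same synthetic data and cluster count as the example above
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.7, random_state=42)
kmeans = KMeans(n_clusters=4, n_init=10, random_state=42).fit(X)

# Look up, for every point, the centroid of the cluster it was assigned to
assigned_centers = kmeans.cluster_centers_[kmeans.labels_]
# Inertia = sum over all points of the squared distance to that centroid
manual_inertia = np.sum((X - assigned_centers) ** 2)

print(f"Manual inertia:  {manual_inertia:.2f}")
print(f"kmeans.inertia_: {kmeans.inertia_:.2f}")
```

The two numbers should agree up to floating-point rounding.
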

The Elbow Method — Choosing K

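The hardest part of K-Means is choosing the number of clusters. Inertia alone cannot decide it, because inertia keeps falling as K grows, so plot inertia against K and look at the shape of the curve.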

# Try K from 1 to 10 and plot inertia
inertias = []
K_range = range(1, 11)

for k in K_range:
    km = KMeans(n_clusters=k, n_init=10, random_state=42)
    km.fit(X)
    inertias.append(km.inertia_)

plt.figure(figsize=(8, 5))
plt.plot(K_range, inertias, marker="o", color="royalblue", linewidth=2)
# K=4 is marked because the sample data was generated with 4 centers
plt.axvline(4, color="red", linestyle="--", alpha=0.6, label="Optimal K=4")
plt.xlabel("Number of Clusters (K)")
plt.ylabel("Inertia")
plt.title("Elbow Method for Optimal K")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Look for the “elbow” — the point where adding more clusters gives little improvement. At that point, the inertia stops dropping steeply.
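
Reading the elbow by eye is subjective. One rough heuristic, shown here as a sketch rather than a standard scikit-learn feature, is to pick the K where the inertia curve bends most sharply, measured by its second difference. The names `bend` and `elbow_k` are introduced for illustration, and the heuristic can disagree with a visual reading, so treat its answer as a hint.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.7, random_state=42)

ks = list(range(1, 11))
inertias = [
    KMeans(n_clusters=k, n_init=10, random_state=42).fit(X).inertia_
    for k in ks
]

# Second difference: large positive values mean a big drop followed by a small one
bend = np.diff(inertias, n=2)
# bend[i] describes the bend at ks[i + 1], so K=1 and K=10 are never candidates
elbow_k = ks[int(np.argmax(bend)) + 1]

for k, inertia in zip(ks, inertias):
    print(f"K={k:2d}  inertia={inertia:10.2f}")
print(f"Sharpest bend at K={elbow_k}")
```

When the first drop (from K=1 to K=2) is much larger than all later ones, this heuristic tends to favor small K, which is one reason to cross-check it against the plot.
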


Silhouette Score — Quality Metric

from sklearn.metrics import silhouette_score

# Score = how well each point fits its own cluster vs. the nearest other cluster
# Range: -1 (likely in the wrong cluster) through 0 (on a boundary) to 1 (well separated)
score = silhouette_score(X, labels)
print(f"Silhouette Score: {score:.4f}")  # ~0.70+

# Compare across different K values
for k in range(2, 8):
    km = KMeans(n_clusters=k, n_init=10, random_state=42)
    sil = silhouette_score(X, km.fit_predict(X))
    print(f"K={k}: Silhouette={sil:.4f}")
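
To show what `silhouette_score` measures, the sketch below computes the score by hand. For each point, `a` is the mean distance to the other points in its own cluster, `b` is the mean distance to the points of the nearest other cluster, and the point's score is `(b - a) / max(a, b)`. The variable names are introduced here for illustration, and the loop is written for readability rather than speed.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances, silhouette_score

X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.7, random_state=42)
labels = KMeans(n_clusters=4, n_init=10, random_state=42).fit_predict(X)

D = pairwise_distances(X)  # Euclidean distance between every pair of points
cluster_ids = np.unique(labels)
scores = np.zeros(len(X))

for i in range(len(X)):
    own = labels == labels[i]
    n_own = own.sum()
    if n_own == 1:
        continue  # convention: a point alone in its cluster scores 0
    # a: mean distance to the other members of the same cluster
    # (the sum includes the zero distance from the point to itself)
    a = D[i, own].sum() / (n_own - 1)
    # b: smallest mean distance to the members of any other cluster
    b = min(D[i, labels == c].mean() for c in cluster_ids if c != labels[i])
    scores[i] = (b - a) / max(a, b)

print(f"Manual silhouette:  {scores.mean():.4f}")
print(f"sklearn silhouette: {silhouette_score(X, labels):.4f}")
```

The overall score is the mean of the per-point scores, so the two printed values should match.
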

Real-World Example: Customer Segmentation

import pandas as pd

# Simulated customer data
np.random.seed(42)
n = 500

customers = pd.DataFrame({
    "recency":   np.abs(np.random.normal(30, 20, n)),   # days since last purchase
    "frequency": np.abs(np.random.normal(5, 3, n)),     # purchases per month
    "monetary":  np.abs(np.random.normal(200, 100, n)), # avg spend
})

# Scale features (important! K-Means is distance-based)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(customers)

# Cluster into 4 customer segments
km = KMeans(n_clusters=4, n_init=10, random_state=42)
customers["segment"] = km.fit_predict(X_scaled)

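# Analyze segments: mean of each feature per segment
print(customers.groupby("segment").mean().round(1))
# Illustrative output only. The features above are drawn independently at random,
# so your segment means will differ and will be less cleanly separated.
#           recency  frequency  monetary
# segment
# 0           8.3       9.1      315.2  ← Champions (recent, frequent, big spenders)
# 1          55.2       2.2       95.8  ← At Risk (not recent, low frequency)
# 2          25.6       5.4      198.7  ← Loyal customers
# 3          40.1       1.2       42.3  ← Lost customers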
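
Because the model is fitted on scaled features, its centroids are in standard-deviation units. A possible way to read them in days, purchases, and spend is to map them back with `scaler.inverse_transform`, as sketched below. The `features` list and the `n_customers` column are hypothetical names added for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Rebuild the simulated customers from the example above
np.random.seed(42)
n = 500
customers = pd.DataFrame({
    "recency":   np.abs(np.random.normal(30, 20, n)),
    "frequency": np.abs(np.random.normal(5, 3, n)),
    "monetary":  np.abs(np.random.normal(200, 100, n)),
})
features = ["recency", "frequency", "monetary"]  # hypothetical helper list

scaler = StandardScaler()
X_scaled = scaler.fit_transform(customers[features])
km = KMeans(n_clusters=4, n_init=10, random_state=42).fit(X_scaled)

# Centroids live in scaled space; convert them back to the original units
centroids = pd.DataFrame(
    scaler.inverse_transform(km.cluster_centers_),
    columns=features,
).round(1)
# Segment sizes help judge whether a segment is large enough to act on
centroids["n_customers"] = np.bincount(km.labels_, minlength=4)
print(centroids)
```

Selecting the feature columns explicitly also keeps a previously added `segment` column from leaking into the scaler if the cell is re-run.
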

Limitations of K-Means

  • You must specify K in advance
  • Assumes clusters are roughly spherical
  • Sensitive to outliers
  • Doesn’t work well with very different cluster densities

Alternatives: DBSCAN (finds arbitrary shapes), Hierarchical clustering, Gaussian Mixture Models
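
The spherical-cluster assumption can be seen on a toy dataset of two interleaved half-moons. The sketch below compares K-Means with DBSCAN using the adjusted Rand index against the known grouping. The `eps` and `min_samples` values are hand-picked assumptions for this data, not general recommendations.

```python
from sklearn.cluster import DBSCAN, KMeans
from sklearn.datasets import make_moons
from sklearn.metrics import adjusted_rand_score

# Two interleaved half-circles: clearly two groups, but not spherical ones
X_moons, y_moons = make_moons(n_samples=300, noise=0.05, random_state=42)

km_labels = KMeans(n_clusters=2, n_init=10, random_state=42).fit_predict(X_moons)
# eps and min_samples are tuned by hand for this toy data (assumption)
db_labels = DBSCAN(eps=0.2, min_samples=5).fit_predict(X_moons)

# Adjusted Rand index: 1.0 means the clustering matches the true grouping
ari_kmeans = adjusted_rand_score(y_moons, km_labels)
ari_dbscan = adjusted_rand_score(y_moons, db_labels)
print(f"K-Means ARI: {ari_kmeans:.2f}")
print(f"DBSCAN ARI:  {ari_dbscan:.2f}")
```

K-Means splits the plane with a straight boundary between its two centroids, so it cuts across both moons, while DBSCAN follows the dense curve of each one.
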

Knowledge Check

You run K-Means with K=3 on a dataset and get a silhouette score of 0.65. You then try K=4 and get 0.81. What should you do?