You're offline — showing cached content
Module 5 — Neural Networks & Deep Learning intermediate 35 min

Building Your First Neural Network

The Project: Classify Handwritten Digits

The MNIST dataset is the “Hello World” of deep learning:

  • 70,000 grayscale images (28×28 pixels)
  • 10 classes: digits 0–9
  • Goal: identify which digit is in each image

Complete Code: MNIST Classifier

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# ── 0. Device ──────────────────────────────────────────────────
# Prefer the GPU when one is visible to PyTorch; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using: {device}")

# ── 1. Load Data ───────────────────────────────────────────────
# Pipeline applied to every image: PIL → float tensor in [0, 1], then
# standardize with MNIST's published channel statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Train/test splits share everything except the `train` flag.
mnist_kwargs = dict(root="./data", download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset  = datasets.MNIST(train=False, **mnist_kwargs)

# Shuffle only the training stream; evaluation order doesn't matter.
loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader  = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Training samples: {len(train_dataset)}")  # 60,000
print(f"Test samples:     {len(test_dataset)}")   # 10,000

# ── 2. Visualize Samples ───────────────────────────────────────
# Show the first 16 images of one shuffled batch in a 2×8 grid.
sample_images, sample_labels = next(iter(train_loader))
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
# zip stops at the shorter sequence (16 axes < 64 batch images).
for image, label, ax in zip(sample_images, sample_labels, axes.flat):
    ax.imshow(image.squeeze(), cmap="gray")  # drop the channel dim for imshow
    ax.set_title(f"Label: {label.item()}")
    ax.axis("off")
plt.tight_layout()
plt.show()

# ── 3. Build the Model ─────────────────────────────────────────
class MNISTClassifier(nn.Module):
    """Fully-connected classifier for 28×28 MNIST digits.

    Architecture: 784 → 256 → 128 → 10. Each hidden layer is followed by
    batch normalization, ReLU, and dropout. The output is raw logits —
    no softmax, since CrossEntropyLoss applies it internally.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()       # [batch, 1, 28, 28] → [batch, 784]
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)    # stabilizes/speeds up training
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 10)     # one logit per digit class
        self.dropout = nn.Dropout(0.3)    # single shared dropout module, p=0.3

    def forward(self, x):
        out = self.flatten(x)
        # Two identical hidden blocks: linear → batch norm → ReLU → dropout.
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            out = self.dropout(F.relu(norm(linear(out))))
        # Raw logits; loss function handles the softmax.
        return self.fc3(out)

# Instantiate the network and move its parameters to the chosen device.
model = MNISTClassifier().to(device)
print(model)  # layer-by-layer summary in registration order
print(f"\nParameters: {sum(p.numel() for p in model.parameters()):,}")  # total trainable parameter count (~236K)

# ── 4. Loss and Optimizer ──────────────────────────────────────
# CrossEntropyLoss expects raw logits: it applies log-softmax + NLL itself.
criterion = nn.CrossEntropyLoss()
# Adam adapts per-parameter learning rates; 1e-3 is its conventional default.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# ── 5. Training Loop ───────────────────────────────────────────
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one full training pass over `loader`.

    Returns (mean batch loss, accuracy over all samples seen).
    """
    model.train()  # enable dropout / batch-norm training behavior
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()                         # clear stale gradients
        logits = model(batch_images)                  # forward pass
        batch_loss = criterion(logits, batch_labels)  # scalar loss
        batch_loss.backward()                         # backprop
        optimizer.step()                              # weight update

        running_loss += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    return running_loss / len(loader), n_correct / n_seen

def evaluate(model, loader, criterion, device):
    """Score `model` on `loader` without updating any weights.

    Returns (mean batch loss, accuracy over all samples seen).
    """
    model.eval()  # eval mode: dropout off, batch norm uses running stats
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            running_loss += criterion(logits, batch_labels).item()
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return running_loss / len(loader), n_correct / n_seen

# ── 6. Run Training ────────────────────────────────────────────
NUM_EPOCHS = 15
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = evaluate(model, test_loader, criterion, device)
    scheduler.step()  # advance the LR schedule once per epoch

    # Record metrics; tuple order matches the history keys' insertion order.
    for key, value in zip(history, (train_loss, train_acc, val_loss, val_acc)):
        history[key].append(value)

    print(f"Epoch {epoch:2d}/{NUM_EPOCHS} | "
          f"Train: loss={train_loss:.4f}, acc={train_acc:.2%} | "
          f"Val: loss={val_loss:.4f}, acc={val_acc:.2%}")

# ── 7. Plot Training Curves ────────────────────────────────────
epochs = range(1, NUM_EPOCHS + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Both panels share one layout: train vs. validation curve per epoch.
for ax, train_key, val_key, title in (
    (ax1, "train_loss", "val_loss", "Loss"),
    (ax2, "train_acc", "val_acc", "Accuracy"),
):
    ax.plot(epochs, history[train_key], label="Train", marker="o")
    ax.plot(epochs, history[val_key], label="Val", marker="s")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig("training_curves.png", dpi=150)
plt.show()

# ── 8. Save Model ──────────────────────────────────────────────
# Persist only the learned parameters (state_dict), not the whole module
# object — smaller on disk and robust to later code refactors on reload.
torch.save(model.state_dict(), "mnist_model.pth")
print("Model saved!")

# ── 9. Load and Inference ──────────────────────────────────────
# Rebuild the architecture, then load the learned weights into it.
model_loaded = MNISTClassifier().to(device)
# map_location remaps stored tensors to the current device, so a checkpoint
# saved on a GPU machine still loads on a CPU-only one (without it,
# torch.load tries to restore tensors to their original CUDA device and fails).
model_loaded.load_state_dict(torch.load("mnist_model.pth", map_location=device))
model_loaded.eval()  # dropout off, batch norm uses running statistics

# Predict on a single image
single_img, single_label = test_dataset[42]  # one (image, label) test sample
with torch.no_grad():  # inference only — no gradient bookkeeping
    logit = model_loaded(single_img.unsqueeze(0).to(device))  # add batch dim
    prob  = F.softmax(logit, dim=1)   # logits → class probabilities
    pred  = prob.argmax(dim=1).item()

print(f"True label: {single_label}  |  Predicted: {pred}  |  Confidence: {prob.max():.2%}")

What Each Component Does

| Component | Role |
| --- | --- |
| `nn.Flatten()` | Converts 28×28 image → 784-vector |
| `nn.Linear(in, out)` | Fully connected layer |
| `nn.BatchNorm1d` | Normalizes activations → faster, more stable training |
| `nn.Dropout(p)` | Randomly zeros p fraction of neurons → prevents overfitting |
| `nn.CrossEntropyLoss()` | Loss for multi-class classification (includes softmax) |
| `torch.optim.Adam` | Adaptive learning rate optimizer |
| `model.train()` | Enable dropout / batchnorm in training mode |
| `model.eval()` | Disable dropout / batchnorm in evaluation mode |
| `torch.no_grad()` | Skip gradient computation during inference (saves memory) |

Expected Results

With 15 epochs:

  • Train accuracy: ~99.5%
  • Validation accuracy: ~98.5%

This is close to the best a simple fully-connected network can achieve on MNIST. CNNs (covered in Module 7) can push this above 99.7%.

Knowledge Check

Why do we call `model.eval()` before evaluating on the test set?