Module 5 — Neural Networks & Deep Learning intermediate 35 min
Building Your First Neural Network
The Project: Classify Handwritten Digits
The MNIST dataset is the “Hello World” of deep learning:
- 70,000 grayscale images (28×28 pixels)
- 10 classes: digits 0–9
- Goal: identify which digit is in each image
Complete Code: MNIST Classifier
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
# ── 0. Device ──────────────────────────────────────────────────
# Train on the GPU when one is present, otherwise fall back to CPU.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda") if has_gpu else torch.device("cpu")
print(f"Using: {device}")
# ── 1. Load Data ───────────────────────────────────────────────
# ToTensor scales pixel values into [0, 1]; Normalize then re-centers them
# using MNIST's dataset-wide mean and standard deviation.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, download=True, transform=transform)

# Shuffle only the training set; evaluation order doesn't matter.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")  # 60,000
print(f"Test samples: {len(test_dataset)}")  # 10,000
# ── 2. Visualize Samples ───────────────────────────────────────
# Grab one batch and display the first 16 digits with their labels.
sample_images, sample_labels = next(iter(train_loader))

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for idx, ax in enumerate(axes.flat):
    digit = sample_images[idx].squeeze()  # drop the channel dim for imshow
    ax.imshow(digit, cmap="gray")
    ax.set_title(f"Label: {sample_labels[idx].item()}")
    ax.axis("off")
plt.tight_layout()
plt.show()
# ── 3. Build the Model ─────────────────────────────────────────
class MNISTClassifier(nn.Module):
    """Fully connected classifier for 28×28 MNIST digits.

    Two hidden stages (784 → 256 → 128), each linear → batchnorm →
    ReLU → dropout, followed by a final linear layer that emits raw
    logits for the 10 digit classes.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # [batch, 1, 28, 28] → [batch, 784]
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)  # batch normalization
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 10)  # 10 classes
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = self.flatten(x)
        # Both hidden stages share the same shape: linear → norm → ReLU → dropout.
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            x = self.dropout(F.relu(norm(linear(x))))
        # Raw logits — CrossEntropyLoss applies log-softmax internally.
        return self.fc3(x)
model = MNISTClassifier().to(device)
print(model)
param_count = sum(p.numel() for p in model.parameters())
print(f"\nParameters: {param_count:,}")

# ── 4. Loss and Optimizer ──────────────────────────────────────
criterion = nn.CrossEntropyLoss()  # combines softmax + log + NLL
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# ── 5. Training Loop ───────────────────────────────────────────
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one full training pass over ``loader``.

    Returns ``(mean_batch_loss, accuracy)`` for the epoch.
    """
    model.train()
    running_loss = 0.0
    num_correct = 0
    num_seen = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)  # compute loss
        loss.backward()                    # backward pass
        optimizer.step()                   # update weights

        running_loss += loss.item()
        num_correct += (outputs.argmax(dim=1) == labels).sum().item()
        num_seen += labels.size(0)
    return running_loss / len(loader), num_correct / num_seen
def evaluate(model, loader, criterion, device):
    """Compute ``(mean_batch_loss, accuracy)`` over ``loader`` without training."""
    model.eval()  # dropout off; batchnorm uses running statistics
    running_loss = 0.0
    num_correct = 0
    num_seen = 0
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_seen += labels.size(0)
    return running_loss / len(loader), num_correct / num_seen
# ── 6. Run Training ────────────────────────────────────────────
NUM_EPOCHS = 15
history = {key: [] for key in ("train_loss", "train_acc", "val_loss", "val_acc")}

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()  # decay the learning rate on schedule

    # Record this epoch's metrics for plotting later.
    for key, value in (
        ("train_loss", train_loss),
        ("train_acc", train_acc),
        ("val_loss", val_loss),
        ("val_acc", val_acc),
    ):
        history[key].append(value)

    print(f"Epoch {epoch:2d}/{NUM_EPOCHS} | "
          f"Train: loss={train_loss:.4f}, acc={train_acc:.2%} | "
          f"Val: loss={val_loss:.4f}, acc={val_acc:.2%}")
# ── 7. Plot Training Curves ────────────────────────────────────
epochs = range(1, NUM_EPOCHS + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: loss curves. Right panel: accuracy curves.
for axis, metric, title in ((ax1, "loss", "Loss"), (ax2, "acc", "Accuracy")):
    axis.plot(epochs, history[f"train_{metric}"], label="Train", marker="o")
    axis.plot(epochs, history[f"val_{metric}"], label="Val", marker="s")
    axis.set_title(title)
    axis.set_xlabel("Epoch")
    axis.legend()
    axis.grid(alpha=0.3)

plt.tight_layout()
plt.savefig("training_curves.png", dpi=150)
plt.show()
# ── 8. Save Model ──────────────────────────────────────────────
# Save only the learned parameters (state_dict) — smaller and more portable
# than pickling the whole model object.
torch.save(model.state_dict(), "mnist_model.pth")
print("Model saved!")
# ── 9. Load and Inference ──────────────────────────────────────
model_loaded = MNISTClassifier().to(device)
# map_location remaps the checkpoint's tensors onto the current device, so
# a GPU-trained checkpoint can still be loaded on a CPU-only machine.
# NOTE(review): on PyTorch >= 1.13, consider weights_only=True as well —
# torch.load unpickles arbitrary objects from untrusted files otherwise.
state_dict = torch.load("mnist_model.pth", map_location=device)
model_loaded.load_state_dict(state_dict)
model_loaded.eval()  # disable dropout; batchnorm uses running stats

# Predict on a single image.
single_img, single_label = test_dataset[42]  # grab one test sample
with torch.no_grad():
    # unsqueeze(0) adds the batch dimension the model expects.
    logit = model_loaded(single_img.unsqueeze(0).to(device))
    prob = F.softmax(logit, dim=1)
    pred = prob.argmax(dim=1).item()
print(f"True label: {single_label} | Predicted: {pred} | Confidence: {prob.max():.2%}")
What Each Component Does
| Component | Role |
|---|---|
nn.Flatten() | Converts 28×28 image → 784-vector |
nn.Linear(in, out) | Fully connected layer |
nn.BatchNorm1d | Normalizes activations → faster, more stable training |
nn.Dropout(p) | During training, zeroes each activation independently with probability p (and rescales the rest) → prevents overfitting |
nn.CrossEntropyLoss() | Loss for multi-class classification (includes softmax) |
torch.optim.Adam | Adaptive learning rate optimizer |
model.train() | Enable dropout / batchnorm in training mode |
model.eval() | Disable dropout / batchnorm in evaluation mode |
torch.no_grad() | Skip gradient computation during inference (saves memory) |
Expected Results
With 15 epochs:
- Train accuracy: ~99.5%
- Validation accuracy: ~98.5%
This is close to the practical ceiling for a simple fully-connected network on MNIST. CNNs (covered in Module 7), which exploit the 2-D structure of the images, can push this above 99.7%.
Why do we call `model.eval()` before evaluating on the test set?