Module 7 — Computer Vision (Advanced) · 30 min
Transfer Learning
The Core Idea
Training a CNN from scratch on CIFAR-10 took 30 epochs to reach ~77% accuracy. What if you only have 200 images of rare flower species?
Transfer learning solves this: take a model already trained on 1.2 million images (ImageNet), and adapt it to your task.
ImageNet pretraining: 1,000 classes, ~1.28M images, weeks of GPU time
→ Model learns universal visual features
→ You take these features and add YOUR classifier on top
→ Fine-tune for a few hours on YOUR data
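In code, the whole recipe is only a few lines. A minimal sketch (the head attribute `fc` is specific to torchvision's ResNet family; other architectures name their classifier differently):
import torch.nn as nn
from torchvision import models
# 1. Load a model pretrained on ImageNet
model = models.resnet18(weights="IMAGENET1K_V1")
# 2. Swap the 1000-class ImageNet head for YOUR classifier (e.g. 5 flower species)
model.fc = nn.Linear(model.fc.in_features, 5)
# 3. Fine-tune on YOUR data (see the training loops later in this module)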
Available Pretrained Models
import torchvision.models as models
# ResNet family — classic, reliable
resnet18 = models.resnet18(weights="IMAGENET1K_V1") # 11M params
resnet50 = models.resnet50(weights="IMAGENET1K_V2") # 25M params
# EfficientNet — accuracy/efficiency sweet spot
effnet_b0 = models.efficientnet_b0(weights="IMAGENET1K_V1") # 5M params
effnet_b4 = models.efficientnet_b4(weights="IMAGENET1K_V1") # 19M params
# Vision Transformer — modern, excellent accuracy
vit_b16 = models.vit_b_16(weights="IMAGENET1K_V1") # 86M params
# MobileNet — for mobile/embedded devices
mobilenet = models.mobilenet_v3_small(weights="IMAGENET1K_V1") # 2.5M params
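# Quick, illustrative sanity check of the parameter counts quoted above
def count_params(m):
    return sum(p.numel() for p in m.parameters()) / 1e6
for name, m in [("resnet18", resnet18), ("effnet_b0", effnet_b0),
                ("vit_b16", vit_b16), ("mobilenet", mobilenet)]:
    print(f"{name:10s} {count_params(m):5.1f}M params")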
# Quick ImageNet inference (1000-class)
import torch
from PIL import Image
from torchvision import transforms
model = models.resnet50(weights="IMAGENET1K_V2").eval()
preprocess = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = preprocess(Image.open("dog.jpg")).unsqueeze(0)
with torch.no_grad():
    logits = model(img)
probs = torch.softmax(logits, dim=1)
top5 = torch.topk(probs, 5)
import json
labels = json.load(open("imagenet_labels.json"))  # maps string indices "0".."999" to class names
for prob, idx in zip(top5.values[0], top5.indices[0]):
    print(f"{labels[str(idx.item())]:30s} {prob:.2%}")
# golden retriever 94.3%
# Labrador retriever 4.1%
# ...
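If you don't have an imagenet_labels.json on disk, torchvision (0.13+) ships the class names and the matching preprocessing with each weights enum, so no lookup file is needed. A sketch of the same inference via that route, reusing the torch and Image imports above:
from torchvision.models import resnet50, ResNet50_Weights
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()         # bundled resize/crop/normalize preset
categories = weights.meta["categories"]   # list of the 1,000 ImageNet class names
img = preprocess(Image.open("dog.jpg")).unsqueeze(0)
with torch.no_grad():
    probs = torch.softmax(model(img), dim=1)
top5 = torch.topk(probs, 5)
for prob, idx in zip(top5.values[0], top5.indices[0]):
    print(f"{categories[idx.item()]:30s} {prob:.2%}")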
Strategy 1: Feature Extraction (Freeze All)
Best when: small dataset (<1000 samples) or dataset very similar to ImageNet.
Keep the pretrained layers frozen, add a new classifier head:
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
# --- Model: Freeze everything, replace head ---
def make_feature_extractor(num_classes):
    model = models.resnet18(weights="IMAGENET1K_V1")
    # Freeze ALL layers
    for param in model.parameters():
        param.requires_grad = False
    # Replace final layer only — this is the only thing that trains!
    in_features = model.fc.in_features  # 512 for resnet18
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, num_classes)
    )
    # model.fc parameters have requires_grad=True by default (new layer)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({trainable/total:.1%})")
    # e.g. with num_classes=5: Trainable: 2,565 / 11,179,077 (0.0%)
    return model
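Training the frozen model is then the standard loop; the only wrinkle is passing just the trainable head parameters to the optimizer. A minimal sketch, assuming a train_loader and device set up as in the complete example later in this module:
model = make_feature_extractor(num_classes=5).to(device)
# Only hand the optimizer the parameters that actually require gradients
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
criterion = nn.CrossEntropyLoss()
for imgs, labels in train_loader:
    imgs, labels = imgs.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(imgs), labels)
    loss.backward()
    optimizer.step()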
Strategy 2: Full Fine-tuning
Best when: medium-sized dataset (1K–100K+) or domain differs from ImageNet.
Unfreeze all layers, but use a smaller learning rate for pretrained layers:
def make_finetune_model(num_classes):
    model = models.efficientnet_b0(weights="IMAGENET1K_V1")
    # Replace classifier head
    in_features = model.classifier[1].in_features  # 1280
    model.classifier = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(in_features, num_classes)
    )
    return model

def get_optimizer_with_differential_lr(model):
    """Different learning rates for different parts of the model."""
    # Pretrained backbone: very small lr (don't destroy existing knowledge)
    # New head: normal lr (train from scratch)
    params = [
        {"params": model.features.parameters(), "lr": 1e-5},
        {"params": model.classifier.parameters(), "lr": 1e-3},
    ]
    return torch.optim.AdamW(params, weight_decay=1e-4)
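A quick way to see what these two helpers produce is to build the model and inspect the optimizer's parameter groups (illustrative only; num_classes=5 matches the flower example below):
model = make_finetune_model(num_classes=5)
optimizer = get_optimizer_with_differential_lr(model)
for i, group in enumerate(optimizer.param_groups):
    n = sum(p.numel() for p in group["params"])
    print(f"group {i}: lr={group['lr']:.0e}, {n:,} parameters")
# The backbone group carries roughly 4M parameters at lr 1e-05;
# the new 5-class head has only 6,405 parameters at lr 1e-03.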
Complete Fine-tuning Example
Let’s classify 5 flower species with only ~100 images per class:
import torch
import torch.nn as nn
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, Subset
# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 5 # daisy, dandelion, rose, sunflower, tulip
# Data transforms
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.3, 0.3, 0.3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Two ImageFolder views of the same directory, one per transform, so the
# validation split does not inherit the training-time augmentations
train_full = datasets.ImageFolder("flowers/", transform=train_tf)
val_full = datasets.ImageFolder("flowers/", transform=val_tf)
n_val = int(0.2 * len(train_full))
perm = torch.randperm(len(train_full)).tolist()
train_ds = Subset(train_full, perm[n_val:])
val_ds = Subset(val_full, perm[:n_val])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4)
# Model
model = models.efficientnet_b0(weights="IMAGENET1K_V1")
model.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(1280, NUM_CLASSES)
)
model = model.to(device)
optimizer = torch.optim.AdamW([
    {"params": model.features.parameters(), "lr": 1e-5},
    {"params": model.classifier.parameters(), "lr": 1e-3},
], weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=[1e-4, 1e-2],
    epochs=15, steps_per_epoch=len(train_loader)
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
best_acc = 0.0
for epoch in range(1, 16):
    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR steps once per batch
    model.eval(); correct = total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            correct += (model(imgs).argmax(1) == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    print(f"Epoch {epoch:2d} | Val Acc: {acc:.1%}")
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "best_flowers.pt")
print(f"\nBest accuracy: {best_acc:.1%}")
# Epoch 15 | Val Acc: 91.3% ← from ~100 images per class!
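To use the saved checkpoint later, rebuild the same architecture, load the weights, and apply the validation preprocessing to a new image. A sketch (my_photo.jpg is a placeholder filename; the class names follow ImageFolder's alphabetical folder order):
from PIL import Image
# Rebuild the exact architecture that was saved, then load the fine-tuned weights
model = models.efficientnet_b0(weights=None)
model.classifier = nn.Sequential(nn.Dropout(0.4), nn.Linear(1280, NUM_CLASSES))
model.load_state_dict(torch.load("best_flowers.pt", map_location=device))
model = model.to(device).eval()
classes = ["daisy", "dandelion", "rose", "sunflower", "tulip"]  # ImageFolder order
img = val_tf(Image.open("my_photo.jpg").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    pred = model(img).argmax(1).item()
print(f"Predicted: {classes[pred]}")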
Strategy Cheat Sheet
| Your Dataset | Domain | Recommended Strategy |
|---|---|---|
| < 100 images | Similar to ImageNet | Feature extraction only |
| < 1K images | Similar to ImageNet | Feature extraction + small-lr fine-tune (see the sketch below the table) |
| 1K–100K images | Similar to ImageNet | Full fine-tuning (differential lr) |
| Any size | Very different (e.g., medical X-rays) | Full fine-tuning, ideally from a domain-pretrained model (or train from scratch) |
| Millions of images | Any | Consider training from scratch |
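The "feature extraction + small-lr fine-tune" row is a two-stage recipe: first train only the new head on the frozen backbone, then unfreeze everything and fine-tune briefly at a much smaller learning rate. A sketch reusing make_feature_extractor from Strategy 1 (the train() helper and epoch counts are illustrative, not a fixed prescription):
model = make_feature_extractor(num_classes=5).to(device)
# Stage 1: backbone frozen, train only the new head
head_opt = torch.optim.AdamW(model.fc.parameters(), lr=1e-3)
train(model, head_opt, epochs=5)   # hypothetical helper wrapping the usual loop
# Stage 2: unfreeze everything, fine-tune at a much smaller lr
for p in model.parameters():
    p.requires_grad = True
full_opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
train(model, full_opt, epochs=3)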
When fine-tuning a pretrained model, why do we use a smaller learning rate for the pretrained layers than for the new classification head?