You're offline — showing cached content
Module 7 — Computer Vision advanced 30 min

Transfer Learning

The Core Idea

Training a CNN from scratch on CIFAR-10 took 30 epochs to reach ~77% accuracy. What if you only have 200 images of rare flower species?

Transfer learning solves this: take a model already trained on 1.2 million images (ImageNet), and adapt it to your task.

ImageNet-1k training: 1,000 classes, ~1.28M images (the full ImageNet has ~14M), weeks of GPU time
→ Model learns universal visual features
→ You take these features and add YOUR classifier on top
→ Fine-tune for a few hours on YOUR data

Available Pretrained Models

import torchvision.models as models

# Every constructor below receives its typed weights enum; torchvision resolves
# the equivalent string form ("IMAGENET1K_V1") to exactly the same entry.

# ResNet family — classic, reliable
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)    # 11M params
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)    # 25M params

# EfficientNet — accuracy/efficiency sweet spot
effnet_b0 = models.efficientnet_b0(
    weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)                    # 5M params
effnet_b4 = models.efficientnet_b4(
    weights=models.EfficientNet_B4_Weights.IMAGENET1K_V1)                    # 19M params

# Vision Transformer — modern, excellent accuracy
vit_b16 = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)     # 86M params

# MobileNet — for mobile/embedded devices
mobilenet = models.mobilenet_v3_small(
    weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)                 # 2.5M params

# Quick ImageNet inference (1000-class)
import json

import torch
from PIL import Image
from torchvision import transforms

model = models.resnet50(weights="IMAGENET1K_V2").eval()
# NOTE: torchvision evaluates the IMAGENET1K_V2 weights with resize 232 -> crop 224;
# `models.ResNet50_Weights.IMAGENET1K_V2.transforms()` returns that exact pipeline.
# The classic 256 -> 224 recipe below is close enough for a quick demo.
preprocess = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# `with` closes the file handle; .convert("RGB") guarantees 3 channels, so
# grayscale, RGBA or palette images don't break the 3-channel Normalize/conv1.
with Image.open("dog.jpg") as raw:
    img = preprocess(raw.convert("RGB")).unsqueeze(0)   # add batch dim -> (1, 3, 224, 224)

with torch.no_grad():
    logits = model(img)
    probs  = torch.softmax(logits, dim=1)
    top5   = torch.topk(probs, 5)

# Label file is assumed to map str(class_index) -> human-readable class name.
with open("imagenet_labels.json", encoding="utf-8") as f:
    labels = json.load(f)
for prob, idx in zip(top5.values[0], top5.indices[0]):
    print(f"{labels[str(idx.item())]:30s} {prob.item():.2%}")
# golden retriever                  94.3%
# Labrador retriever                4.1%
# ...

Strategy 1: Feature Extraction (Freeze All)

Best when: small dataset (<1000 samples) or dataset very similar to ImageNet.

Keep the pretrained layers frozen, add a new classifier head:

import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader

# --- Model: Freeze everything, replace head ---
def make_feature_extractor(num_classes):
    """Build an ImageNet ResNet-18 with a frozen backbone and a new head.

    Every pretrained parameter gets ``requires_grad=False``. The replacement
    ``model.fc`` (Dropout(0.3) + Linear(512, num_classes)) is the only part
    that receives gradients. Prints the trainable/total parameter counts.

    Args:
        num_classes: number of output logits for the new classifier head.

    Returns:
        The modified torchvision ResNet-18 model.
    """
    model = models.resnet18(weights="IMAGENET1K_V1")

    # Freeze ALL pretrained layers
    for param in model.parameters():
        param.requires_grad = False
    # NOTE: requires_grad=False only stops gradient updates. BatchNorm layers
    # still update their running mean/var whenever the model is in train()
    # mode. For a truly frozen backbone, switch the BatchNorm modules back to
    # eval() after calling model.train().

    # Replace final layer only — this is the only thing that trains!
    in_features = model.fc.in_features      # 512 for resnet18
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, num_classes),
    )
    # Newly created layers default to requires_grad=True. Give the optimizer
    # just these, e.g. torch.optim.Adam(model.fc.parameters(), lr=1e-3).

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # Two decimals: the head is a tiny fraction of the network, and ".1%"
    # would round it down to "0.0%".
    print(f"Trainable: {trainable:,} / {total:,} ({trainable/total:.2%})")
    # num_classes=10 -> Trainable: 5,130 / 11,181,642 (0.05%)

    return model

Strategy 2: Full Fine-tuning

Best when: medium-sized dataset (1K–100K+) or domain differs from ImageNet.

Unfreeze all layers, but use a smaller learning rate for pretrained layers:

def make_finetune_model(num_classes):
    """Load ImageNet-pretrained EfficientNet-B0 and attach a fresh head.

    Every layer stays trainable (pair this with differential learning rates).
    The new head is Dropout(0.4) followed by a Linear layer that maps the
    backbone's pooled features to ``num_classes`` logits.
    """
    net = models.efficientnet_b0(weights="IMAGENET1K_V1")

    # Width of the original head's Linear layer = pooled feature size (1280 for B0).
    feature_dim = net.classifier[1].in_features
    head = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(feature_dim, num_classes),
    )
    net.classifier = head
    return net

def get_optimizer_with_differential_lr(model, *, backbone_lr=1e-5, head_lr=1e-3,
                                       weight_decay=1e-4):
    """Build an AdamW optimizer with different learning rates for backbone and head.

    The pretrained backbone gets a very small lr so existing knowledge isn't
    destroyed. The freshly initialized head gets a normal lr because it
    trains from scratch.

    Args:
        model: a module exposing ``.features`` (backbone) and ``.classifier``
            (head) submodules, e.g. a torchvision EfficientNet.
        backbone_lr: learning rate for ``model.features`` parameters.
        head_lr: learning rate for ``model.classifier`` parameters.
        weight_decay: decoupled weight decay applied to both groups.

    Returns:
        A ``torch.optim.AdamW`` with two parameter groups, backbone first.
    """
    params = [
        {"params": model.features.parameters(), "lr": backbone_lr},
        {"params": model.classifier.parameters(), "lr": head_lr},
    ]
    return torch.optim.AdamW(params, weight_decay=weight_decay)

Complete Fine-tuning Example

Let’s classify 5 flower species with only ~100 images per class:

import torch, torchvision, os
import torch.nn as nn
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, Subset, random_split

# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 5      # daisy, dandelion, rose, sunflower, tulip
EPOCHS = 15          # shared by the scheduler and the training loop so they can't drift apart
VAL_FRACTION = 0.2
SEED = 42            # fixed so the train/val split is reproducible across runs

# Data transforms
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.3, 0.3, 0.3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
val_tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Two ImageFolder views of the same directory: same (sorted) sample order,
# different transforms. Both random_split Subsets share ONE underlying dataset,
# so setting `val_ds.dataset.transform` would also switch the TRAIN split to
# the deterministic val pipeline and silently disable augmentation.
train_full = datasets.ImageFolder("flowers/", transform=train_tf)
val_full   = datasets.ImageFolder("flowers/", transform=val_tf)

n_val = int(VAL_FRACTION * len(train_full))
train_idx, val_idx = random_split(
    range(len(train_full)), [len(train_full) - n_val, n_val],
    generator=torch.Generator().manual_seed(SEED),
)
train_ds = Subset(train_full, train_idx.indices)   # augmented
val_ds   = Subset(val_full,   val_idx.indices)     # deterministic

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,  num_workers=4)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, num_workers=4)

# Model: pretrained EfficientNet-B0 with a new head sized from the old one (1280 for B0)
model = models.efficientnet_b0(weights="IMAGENET1K_V1")
model.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(model.classifier[1].in_features, NUM_CLASSES),
)
model = model.to(device)

optimizer = torch.optim.AdamW([
    {"params": model.features.parameters(), "lr": 1e-5},
    {"params": model.classifier.parameters(), "lr": 1e-3},
], weight_decay=1e-4)
# NOTE: OneCycleLR overwrites each group's "lr". It starts at max_lr / div_factor
# (default 25) and peaks at max_lr, so the effective peaks are 1e-4 (backbone)
# and 1e-2 (head). The 1e-5 / 1e-3 above only express the intended 100x ratio.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=[1e-4, 1e-2],
    epochs=EPOCHS, steps_per_epoch=len(train_loader)
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

best_acc = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        scheduler.step()            # OneCycleLR steps once per batch, not per epoch

    model.eval(); correct = total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            correct += (model(imgs).argmax(1) == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    print(f"Epoch {epoch:2d} | Val Acc: {acc:.1%}")
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "best_flowers.pt")

print(f"\nBest accuracy: {best_acc:.1%}")
# Epoch 15 | Val Acc: 91.3%  ← from ~100 images per class!
# Epoch 15 | Val Acc: 91.3%  ← from ~100 images per class!

Strategy Cheat Sheet

| Your Dataset | Domain | Recommended Strategy |
|---|---|---|
| < 100 images | Similar to ImageNet | Feature extraction only |
| < 1K images | Similar to ImageNet | Feature extraction + small-lr fine-tuning |
| 1K–100K images | Similar to ImageNet | Full fine-tuning (differential lr) |
| Any size | Very different (e.g., medical X-rays) | Full fine-tuning, or start from a domain-pretrained model |
| Millions of images | Any | Consider training from scratch |
Knowledge Check

When fine-tuning a pretrained model, why do we use a smaller learning rate for the pretrained layers than for the new classification head?