Transfer Learning
Học chuyển giao
Bạn cần phân loại 10 loại bệnh lá cây. Chỉ có 500 ảnh. Train CNN từ đầu cần hàng triệu ảnh. Làm sao?
Hình minh họa
Transfer learning trên một CNN pretrained ImageNet
Chọn chiến lược để xem lớp nào đóng băng (xám) và lớp nào được huấn luyện (cam).
Kết hợp tốt giữa tốc độ và linh hoạt. Dùng learning rate rất nhỏ cho block pretrained để không phá trọng số.
Đường cong accuracy (mô phỏng 10 epochs)
Feature extraction (xanh dương) đạt cao nguyên sớm vì backbone cố định. Fine-tune 3 lớp cuối (xanh lá) linh hoạt và thường là cân bằng tốt nhất. Full fine-tune (cam) cần nhiều dữ liệu hơn nhưng trần cao nhất.
So sánh nhanh
| Chiến lược | Dữ liệu cần | Thời gian | Rủi ro overfit | Val acc (mô phỏng) |
|---|---|---|---|---|
| Feature extraction | Rất ít (vài trăm – 1K ảnh) | Vài phút trên 1 GPU | Thấp | 80% (loss 0.58) |
| Fine-tune 3 lớp cuối | Trung bình (1K – 10K ảnh) | 1–2 giờ trên 1 GPU | Trung bình | 91% (loss 0.30) |
| Fine-tune toàn bộ | Nhiều (≥10K ảnh khuyến nghị) | Nhiều giờ đến vài ngày | Cao | 94% (loss 0.22) |
Transfer Learning = tận dụng kiến thức đã học. Lớp nông (cạnh, kết cấu) là universal — dùng được cho mọi bài toán ảnh. Chỉ lớp sâu (vật thể cụ thể) cần thay đổi. Tiết kiệm 99% thời gian và dữ liệu! Cùng ý tưởng áp dụng cho Transformer (BERT, GPT) và mọi kiến trúc hiện đại.
Ví dụ: phân loại chó/mèo, hoa, đồ vật đời thường. Dùng Feature Extraction: đóng băng toàn bộ backbone, chỉ train classifier head. Vài trăm ảnh đã đủ.
Ví dụ: ảnh sản phẩm bán hàng, biển báo Việt Nam. Dùng Fine-tune 2–3 lớp cuối với LR rất nhỏ (1e-5 đến 1e-4) để điều chỉnh nhẹ đặc trưng sâu mà không phá lớp nông.
Ví dụ: ảnh y tế, ảnh vệ tinh, ảnh kính hiển vi. Nên Full fine-tune với LR nhỏ và schedule cosine, cộng augmentation mạnh. Nếu data cực lớn, có thể train từ đầu nhưng pretrained vẫn là điểm khởi tạo tốt.
Val loss tăng trong khi train loss giảm → overfit. Giảm LR, tăng weight decay, thêm augmentation, đóng băng thêm lớp, hoặc dùng early stopping. Không tin cao nhất chỉ sau 1 epoch — luôn xác minh bằng nhiều seed.
Fine-tune BERT cho phân loại cảm xúc tiếng Việt. Nên dùng learning rate bao nhiêu?
Giải thích
Transfer Learning tận dụng kiến thức từ mô hình pretrained trên dataset lớn (ImageNet, WebText, Common Crawl) để giải bài toán mới với ít dữ liệu hơn. Thay vì khởi tạo trọng số ngẫu nhiên, ta khởi tạo từ một điểm đã rất gần tối ưu cho các đặc trưng chung.
Lớp nông CNN học đặc trưng tổng quát (cạnh, kết cấu, màu sắc) — dùng được cho MỌI bài toán ảnh. LLM pretrain học "hiểu ngôn ngữ" tổng quát. Kiến thức nền tảng này là universal, chỉ lớp cuối cần task-specific.
Về toán, gọi là trọng số tối ưu cho bài toán nguồn (ImageNet). Khi chuyển sang bài toán đích, ta khởi tạo và tối ưu với regularizer mềm kéo về gần :
Hệ số lớn tương đương đóng băng; nhỏ tương đương fine-tune tự do. Learning rate nhỏ là một cách ngầm để đạt hiệu ứng tương tự — gradient nhỏ giữ gần .
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
NUM_CLASSES = 10 # bài toán mới
# ──────────────────────────────────────────────────────────────
# 1. Feature extraction — đóng băng toàn bộ backbone
# ──────────────────────────────────────────────────────────────
def build_feature_extractor() -> nn.Module:
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
for param in model.parameters():
param.requires_grad = False # đóng băng
# Thay classifier head cho bài toán mới
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, NUM_CLASSES)
return model
# ──────────────────────────────────────────────────────────────
# 2. Fine-tune 3 lớp cuối — mở khoá layer4 + fc
# ──────────────────────────────────────────────────────────────
def build_last3_finetune() -> nn.Module:
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
for param in model.parameters():
param.requires_grad = False
for param in model.layer4.parameters():
param.requires_grad = True
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, NUM_CLASSES) # trainable theo mặc định
return model
# ──────────────────────────────────────────────────────────────
# 3. Full fine-tune — mở khoá toàn bộ
# ──────────────────────────────────────────────────────────────
def build_full_finetune() -> nn.Module:
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, NUM_CLASSES)
# Mặc định tất cả requires_grad=True
return model
# ──────────────────────────────────────────────────────────────
# 4. Optimizer với LR theo nhóm tham số
# - Lớp mới (head): LR lớn hơn
# - Backbone pretrained: LR rất nhỏ
# ──────────────────────────────────────────────────────────────
def make_optimizer(model: nn.Module, strategy: str):
if strategy == "feature":
# chỉ head được cập nhật
return torch.optim.AdamW(model.fc.parameters(), lr=1e-3, weight_decay=1e-4)
if strategy == "last3":
return torch.optim.AdamW(
[
{"params": model.layer4.parameters(), "lr": 1e-5},
{"params": model.fc.parameters(), "lr": 1e-3},
],
weight_decay=1e-4,
)
# full
return torch.optim.AdamW(
[
{"params": [p for n, p in model.named_parameters()
if not n.startswith("fc")], "lr": 1e-5},
{"params": model.fc.parameters(), "lr": 1e-3},
],
weight_decay=1e-4,
)
# ──────────────────────────────────────────────────────────────
# 5. Vòng lặp huấn luyện rút gọn
# ──────────────────────────────────────────────────────────────
def train(model, loader: DataLoader, optim, epochs: int = 10, device="cuda"):
criterion = nn.CrossEntropyLoss()
model.to(device).train()
for epoch in range(epochs):
running = 0.0
for x, y in loader:
x, y = x.to(device), y.to(device)
optim.zero_grad(set_to_none=True)
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optim.step()
running += loss.item() * x.size(0)
print(f"epoch {epoch+1}: loss={running / len(loader.dataset):.4f}")
return modelBạn fine-tune ResNet50 cho 500 ảnh X-quang phổi, nhưng val accuracy chỉ đạt 62% trong khi train accuracy đã 99%. Can thiệp nào KHÔNG nên thử đầu tiên?
Giải thích
Trong thực tế, kỹ thuật quan trọng nhất khi fine-tune là differential learning rate — chia mạng thành nhiều nhóm, mỗi nhóm có LR riêng. Lớp càng nông, LR càng nhỏ.
import torch
from torch.optim import AdamW
def discriminative_param_groups(model, base_lr: float = 1e-3, decay: float = 0.3):
"""
Chia ResNet thành 4 nhóm theo độ sâu và gán LR giảm dần vào lớp nông.
decay = 0.3 nghĩa là mỗi nhóm sâu hơn có LR gấp 1/0.3 ≈ 3.3 lần lớp trước.
"""
groups = [
{"params": list(model.conv1.parameters())
+ list(model.bn1.parameters())
+ list(model.layer1.parameters()),
"lr": base_lr * decay ** 3},
{"params": list(model.layer2.parameters()),
"lr": base_lr * decay ** 2},
{"params": list(model.layer3.parameters()),
"lr": base_lr * decay},
{"params": list(model.layer4.parameters())
+ list(model.fc.parameters()),
"lr": base_lr},
]
return groups
optim = AdamW(discriminative_param_groups(model, base_lr=1e-3),
weight_decay=1e-4)
# Kết hợp với cosine schedule và warmup
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optim,
max_lr=[g["lr"] for g in optim.param_groups],
total_steps=len(train_loader) * EPOCHS,
pct_start=0.1,
div_factor=10,
)- Chọn pretrained phù hợp domain (ImageNet cho ảnh tự nhiên, MedCLIP/CheXNet cho y tế, v.v.).
- Thay classifier head trước khi load weights vào, hoặc sau khi load rồi khởi tạo lại.
- Quyết định chiến lược: feature / last-k / full dựa vào lượng data và độ lệch domain.
- Dùng differential LR + cosine schedule + warmup 5–10% bước đầu.
- Data augmentation phù hợp: không lật ảnh y tế, nhưng rotate được; mixup/cutmix cho ảnh tự nhiên.
- Early stopping trên val loss, lưu checkpoint theo val metric (F1/AUC), không theo loss nếu mất cân bằng.
- Đừng quên freeze BatchNorm stats (eval mode trên các block đóng băng) — đây là bug cực kỳ phổ biến.
- Tận dụng pretrained model → tiết kiệm 99% dữ liệu và thời gian. Paradigm 'pretrain once, fine-tune many'.
- Feature extraction: đóng băng toàn bộ backbone, chỉ train head — phù hợp khi data rất ít và domain gần ImageNet.
- Fine-tune 2–3 lớp cuối: cân bằng tốt giữa linh hoạt và ổn định — lựa chọn mặc định cho đa số project thực tế.
- Full fine-tune: linh hoạt nhất nhưng tốn data và dễ overfit — dùng khi domain khác xa pretrain hoặc data lớn.
- Lớp nông = đặc trưng universal (cạnh, kết cấu); lớp sâu = task-specific. Luôn giữ LR cho backbone rất nhỏ (1e-5 đến 1e-4).
- Paradigm chủ đạo của AI hiện đại: ResNet/ViT cho ảnh, BERT/GPT/LLaMA cho NLP, CLIP cho multimodal — đều là foundation model + fine-tune.
Kiểm tra hiểu biết
Bạn có 200 ảnh chó/mèo. Nên dùng chiến lược transfer learning nào?