Vision Transformer (ViT)
Transformer thị giác
Transformer thống trị NLP (text). Nhưng ảnh không phải chuỗi từ — làm sao áp dụng Transformer cho ảnh?
Hãy tưởng tượng bạn có tấm ảnh chụp phố cổ Hội An và cắt thành 9 mảnh ghép (jigsaw). Mỗi mảnh tự hỏi: "Mảnh nào liên quan đến tôi?" — đó là self-attention! Mảnh có đèn lồng sẽ "chú ý" đến mảnh có phố — dù chúng cách xa nhau.
Hình minh họa
Nhấn vào patch để xem attention connections. Mỗi patch "nhìn" tất cả patches khác.
ViT coi ảnh là chuỗi patches, giống Transformer coi văn bản là chuỗi tokens. Patches = tokens! Self-attention cho mỗi patch "nhìn" toàn bộ ảnh ngay từ lớp đầu tiên — CNN phải xếp nhiều lớp mới "nhìn xa" được.
CNN
- Inductive bias: cục bộ + chia sẻ bộ lọc
- Train tốt với ít data (1-10M ảnh)
- Nhìn cục bộ → xếp lớp mới nhìn xa
- Hiệu quả tham số hơn
ViT
- Không inductive bias → cần nhiều data hơn
- Vượt CNN khi data lớn (>100M ảnh)
- Nhìn toàn cục ngay lớp đầu (attention)
- Scale tốt hơn (ViT-22B)
DeiT (2021) cho thấy ViT train được trên ImageNet (1.4M) nhờ data augmentation mạnh + knowledge distillation từ CNN. Không cần 300M ảnh nữa! Ngày nay ViT là lựa chọn hàng đầu cho computer vision.
ViT-Base: 12 lớp Transformer, 768 hidden, 12 heads, patch 16×16, ảnh 224×224. Có bao nhiêu tokens trong sequence?
Giải thích
Vision Transformer (ViT) (Dosovitskiy et al., 2020) chứng minh kiến trúc Transformer thuần túy (không CNN) đạt SOTA trên image classification nhờ self-attention trên các patch ảnh.
= patch flatten, = linear projection, = positional embedding.
import torch
import torch.nn as nn
class ViT(nn.Module):
def __init__(self, img_size=224, patch_size=16, d_model=768,
n_layers=12, n_heads=12, num_classes=1000):
super().__init__()
n_patches = (img_size // patch_size) ** 2 # 196
# Patch embedding: flatten patch → linear projection
self.patch_embed = nn.Conv2d(
3, d_model, kernel_size=patch_size, stride=patch_size
) # (B, 3, 224, 224) → (B, 768, 14, 14)
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads, batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
self.head = nn.Linear(d_model, num_classes)
def forward(self, x):
B = x.shape[0]
# 1. Patch embedding
x = self.patch_embed(x).flatten(2).transpose(1, 2) # (B, 196, 768)
# 2. Prepend [CLS] token
cls = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls, x], dim=1) # (B, 197, 768)
# 3. Add positional encoding
x = x + self.pos_embed
# 4. Transformer
x = self.transformer(x)
# 5. [CLS] output → classification
return self.head(x[:, 0])- ViT: chia ảnh thành patches → flatten → linear projection → Transformer encoder. Patches = tokens!
- [CLS] token attend đến mọi patch → tổng hợp thông tin toàn bộ ảnh → classification head.
- Không có inductive bias của CNN (locality, weight sharing) → cần nhiều data hơn hoặc augmentation mạnh (DeiT).
- Với data đủ lớn (>100M), ViT vượt CNN. Scale rất tốt: ViT-22B có 22 tỷ tham số.
- Biến thể: DeiT (ít data), Swin Transformer (hierarchical + shifted window), DINO (self-supervised).
Kiểm tra hiểu biết
ViT chia ảnh 224×224 thành patches 16×16. Có bao nhiêu patches (tokens)?