Multi-Head Attention
Chú ý đa đầu
Câu "Tôi yêu Việt Nam". Self-attention chỉ dùng 1 bộ trọng số Q, K, V. Nó có thể vừa nắm cú pháp ("Tôi" → chủ ngữ), vừa nắm ngữ nghĩa ("yêu" → "Việt Nam"), vừa nắm vị trí? Có cách nào tốt hơn?
Hãy tưởng tượng bạn xem bóng đá trên TV. Camera 1 quay toàn cảnh (vị trí), Camera 2 zoom vào cầu thủ (chi tiết), Camera 3 quay khán đài (bối cảnh). Đạo diễn kết hợp tất cả → bạn hiểu trận đấu đầy đủ. Mỗi head attention = 1 camera!
Hình minh họa
Chọn head và từ truy vấn. Quan sát cách mỗi head "nhìn" câu theo kiểu khác nhau.
Bạn vừa thấy mỗi head nhìn cùng câu nhưng theo "góc" khác nhau. Head 3 (vị trí) chú ý từ gần nhất, Head 2 (ngữ nghĩa) chú ý từ liên quan nghĩa — dù xa. Kết hợp tất cả → hiểu đầy đủ!
Multi-Head Attention = chạy h bộ attention song song, mỗi bộ trong subspace d_k = d_model/h. Tổng tham số gần bằng single-head attention, nhưng mỗi head chuyên biệt 1 kiểu quan hệ → phong phú hơn rất nhiều!
GPT-3 có 96 heads. Nghiên cứu cho thấy: bỏ 1 số heads, performance gần không đổi → mỗi head có "chuyên môn" riêng, hệ thống rất robust!
d_model = 768, heads = 12. Mỗi head có d_k = ? Tổng Q, K, V projection cho tất cả heads = ?
Giải thích
Multi-Head Attention chạy h phép self-attention song song trong các subspace khác nhau, mỗi phép có bộ trọng số riêng. Đây là khối cốt lõi trong Transformer:
, , .
Single-head: = d² tham số × 3 = 3d². Multi-head h: mỗi head = d²/h. Tổng h heads: h × d²/h = d² × 3 + W° = 4d². Gần bằng!
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads # 512/8 = 64
# Q, K, V projections (gộp tất cả heads)
self.W_q = nn.Linear(d_model, d_model) # 512→512
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model) # Output projection
def forward(self, x):
B, N, D = x.shape # batch, seq_len, d_model
# Project và split thành heads
Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
# Shape: (B, n_heads, N, d_k)
# Attention cho tất cả heads song song
scores = Q @ K.transpose(-2, -1) / (self.d_k ** 0.5)
weights = scores.softmax(dim=-1)
attn = weights @ V # (B, n_heads, N, d_k)
# Concat heads và project
out = attn.transpose(1, 2).reshape(B, N, D)
return self.W_o(out) # (B, N, d_model)- Multi-Head = h bộ attention song song, mỗi bộ chuyên biệt 1 kiểu quan hệ (cú pháp, ngữ nghĩa, vị trí...).
- d_k = d_model/h: mỗi head nhỏ hơn nhưng tổng tham số gần bằng single-head attention.
- Concat h heads rồi nhân W^O → output d_model chiều. W^O trộn thông tin từ tất cả heads.
- GPT-3: 96 heads, BERT-base: 12 heads, GPT-4: 120+ heads. Nhiều heads = đa dạng kiểu quan hệ.
- Mỗi head hoạt động trong subspace riêng → ensemble of attention → robust và đa dạng.
Kiểm tra hiểu biết
d_model = 512, có 8 heads. Mỗi head có d_k bao nhiêu? Tổng tham số attention thay đổi thế nào?