Self-Attention
Tự chú ý
Câu "Con mèo ngồi trên bàn". Khi não bạn đọc tới từ "bàn", bạn tự động nghĩ nhiều nhất đến từ nào trong câu?
Bạn vừa chọn bằng trực giác. Ở phần tiếp theo bạn sẽ thấy chính con số đó hiện ra như một hàng trong ma trận — đó chính là cách máy "chú ý".
Hãy tưởng tượng bạn đang đứng trong chợ Bến Thành. Bạn là một "từ", và bạn đang cần hoàn thiện nghĩa của mình bằng cách hỏi các gian hàng xung quanh.
- Query (Q)là câu hỏi bạn thì thào: "Ai có thứ liên quan đến tôi?"
- Key (K) là tấm bảng hiệu của mỗi gian hàng, tóm tắt họ đang bán gì.
- Value (V) là món hàng thực sự họ trao cho bạn nếu bạn để ý đến họ.
Bạn nhìn lướt tất cả bảng hiệu (tính ), cho điểm độ liên quan, rồi phân chia "sự chú ý" của mình (softmax). Cuối cùng bạn gom các món hàng theo tỷ lệ đó — đó là output của attention cho từ "bạn".
Dưới đây là toàn bộ self-attention cho câu 5 từ. Bạn có thể: (1) bật/tắt chia để thấy softmax bão hoà, (2) đổi sang cross-attention để thấy Q đến từ chuỗi khác, và (3) click vào bất kỳ token nào để xem hàng attention của nó hiện ra dưới dạng heatmap.
Hình minh họa
Heatmap ma trận attention (self-attention, 5×5)
Mỗi hàng tổng bằng 1 (softmax). Màu càng đậm → chú ý càng nhiều. Click bất kỳ hàng nào để đổi token đang xem.
Tính toán cho hàng "bàn"
| token j | Q·K | /√d_k | softmax |
|---|---|---|---|
| Con | 0.31 | 0.18 | 0.178 |
| mèo | 0.62 | 0.36 | 0.213 |
| ngồi | 0.59 | 0.34 | 0.209 |
| trên | 0.43 | 0.25 | 0.191 |
| bàn | 0.60 | 0.35 | 0.210 |
Cột 2 là tích vô hướng "thô", cột 3 sau khi chia , cột 4 là attention cuối cùng.
Output cho "bàn" = Σ w · V
Vector output (d_k=3): [0.386, 0.580, 0.484]
Thử tắt Chia √d_k: khi d_k nhỏ (=3) hiệu ứng chưa rõ, nhưng bạn vẫn thấy softmax dịch chuyển về một ô duy nhất. Trong thực tế với d_k=64, chênh lệch ấy đủ để gradient tắt.
Self-attention= mỗi từ "hỏi" tất cả từ khác trong cùng câu: "Bạn quan trọng với tôi bao nhiêu?" Tập hợp các câu trả lời (sau softmax) chính là hàng attention của nó. Output của từ đó = tổng có trọng số của mọi Value.
Khác với RNNphải truyền thông tin tuần tự qua thời gian, self-attention cho mọi từ đường đi O(1) đến mọi từ khác — bất kể khoảng cách. "Tôi" ở đầu câu vẫn có đường kết nối trực tiếp với "tôi" ở cuối câu. Đây là trái tim của Transformer, và khi được nhân bản song song h lần ta có multi-head attention.
Một chuỗi dài 4096 token. Ma trận attention có bao nhiêu ô, và điều đó kéo theo vấn đề gì?
Trong decoder của GPT, khi sinh token thứ 5 của chuỗi, attention của token đó có thể nhìn vào token thứ 6 (tương lai) không?
Giải thích
Định nghĩa
Self-Attention(Scaled Dot-Product Attention) là phép toán cho phép mỗi vị trí trong chuỗi tính ra một vector đầu ra là tổng có trọng số của tất cả các vector Value trong cùng chuỗi, với trọng số được xác định bởi độ "khớp" giữa Query của vị trí đó và Key của các vị trí khác.
Chính thức: cho ma trận Q, K, V cùng có n hàng (mỗi hàng là một token) và d_k, d_v cột, output là:
Q, K ∈ , V ∈ . Ma trận attention , mỗi ô (i, j) là "mức chú ý" token i dành cho token j.
Công thức gốc của Vaswani et al. (Attention Is All You Need, 2017)
Khi lớn, tích vô hướng có phương sai xấp xỉ d_k (giả định q, k được chuẩn hoá với phương sai 1). Giá trị lớn đi qua softmax sẽ tạo phân phối gần one-hot → gradient gần 0 cho các vị trí còn lại. Chia đưa phương sai về 1, giúp softmax mềm hơn và gradient chảy tốt hơn khi train.
Self-attention: Q, K, V cùng xuất phát từ một chuỗi (encoder nhìn câu nguồn, hoặc decoder nhìn câu đích đã sinh). Cross-attention:Q từ chuỗi A, còn K, V từ chuỗi B. Trong Transformer encoder-decoder, decoder dùng cross-attention để "đọc" thông tin từ encoder — đây là nơi thông tin nguồn được đưa vào câu đích.
Bộ nhớ và thời gian tính một lớp self-attention là . Nhân đôi chiều dài context → gấp 4 chi phí. Khi n = 32k thì chỉ riêng ma trận attention đã hàng GB. Các kỹ thuật như Flash Attention, Long Context và sliding window attention ra đời để đánh đổi khéo giữa chính xác và chi phí.
Khác với RNN phải chạy tuần tự theo thời gian, self-attention là một phép nhân ma trận lớn — ánh xạ gần như hoàn hảo lên GPU. Đây là một trong những lý do chính khiến Transformer "ăn" dữ liệu nhanh hơn LSTM cùng kích thước.
Code mẫu: tính self-attention bằng NumPy
import numpy as np
def softmax(x, axis=-1):
x = x - x.max(axis=axis, keepdims=True) # ổn định số học
e = np.exp(x)
return e / e.sum(axis=axis, keepdims=True)
def self_attention(X, Wq, Wk, Wv):
"""
X: (n_tokens, d_model) — embedding đầu vào
Wq: (d_model, d_k) — chiếu sang không gian Query
Wk: (d_model, d_k) — chiếu sang không gian Key
Wv: (d_model, d_v) — chiếu sang không gian Value
"""
Q = X @ Wq # (n, d_k)
K = X @ Wk # (n, d_k)
V = X @ Wv # (n, d_v)
d_k = K.shape[-1]
scores = Q @ K.T / np.sqrt(d_k) # (n, n) — QK^T/√d_k
weights = softmax(scores, axis=-1) # (n, n) — softmax theo hàng
output = weights @ V # (n, d_v) — tổng có trọng số
return output, weights
# Ví dụ: câu 5 từ, d_model = 8, d_k = d_v = 4
rng = np.random.default_rng(42)
X = rng.standard_normal((5, 8))
Wq = rng.standard_normal((8, 4)) * 0.1
Wk = rng.standard_normal((8, 4)) * 0.1
Wv = rng.standard_normal((8, 4)) * 0.1
out, attn = self_attention(X, Wq, Wk, Wv)
print(out.shape) # (5, 4)
print(attn.shape) # (5, 5) — tổng mỗi hàng ≈ 1
print(attn.sum(axis=-1)) # [1. 1. 1. 1. 1.]Code mẫu: masked self-attention & cross-attention
import numpy as np
def scaled_dot_product(Q, K, V, mask=None):
d_k = K.shape[-1]
scores = Q @ K.T / np.sqrt(d_k)
if mask is not None:
scores = np.where(mask, scores, -1e9) # đặt các ô bị che = -inf
weights = softmax(scores, axis=-1)
return weights @ V, weights
# --- causal mask cho decoder-only (GPT) ---
n = 5
causal = np.tril(np.ones((n, n), dtype=bool)) # True ở tam giác dưới
# scores[i, j] chỉ được phép > -inf khi j <= i
# --- cross-attention ---
# Q từ decoder (m token), K, V từ encoder (n token)
# Lưu ý: hình dạng attention là m × n, không bắt buộc vuông.
def cross_attention(dec_h, enc_out, Wq, Wk, Wv):
Q = dec_h @ Wq # (m, d_k)
K = enc_out @ Wk # (n, d_k)
V = enc_out @ Wv # (n, d_v)
out, _ = scaled_dot_product(Q, K, V)
return out # (m, d_v)Ứng dụng tiêu biểu
- Mô hình ngôn ngữ (LLM). GPT, Claude, Llama, Gemini đều dùng self-attention làm nền — mỗi token dự đoán được ngữ cảnh nhờ attention tới mọi token trước đó.
- Dịch máy. Encoder-decoder Transformer dùng self-attention để mã hoá câu nguồn và cross-attention để căn chỉnh từng từ đích với từ nguồn liên quan.
- Thị giác máy tính. Vision Transformer (ViT) chia ảnh thành các patch và dùng self-attention như xử lý token.
- Mô hình đa phương thức. CLIP, VLM dùng cross-attention để "ghép" text và ảnh vào cùng một biểu diễn.
- Sinh học & khoa học. AlphaFold 2 dùng attention-over-residues để dự đoán cấu trúc protein, SE(3)-Transformer cho hoá học, Graphormer cho đồ thị.
Bẫy thường gặp
- Quên scaling. Bỏ khi d_k lớn → softmax bão hoà, loss kẹt ở plateau.
- Thiếu causal mask khi huấn luyện decoder.Dẫn đến "rò rỉ tương lai": mô hình học được cheat sheet, train loss rất thấp nhưng inference sai bét.
- Nhầm d_v với d_k. Nhiều cài đặt đặt d_k = d_v để đơn giản, nhưng bản chất hai con số khác nhau; khi đổi d_v, kích thước W^O phải cập nhật.
- Lẫn mask boolean với -inf. Trong framework, phải thêm lượng lớn âm TRƯỚC softmax, không phải nhân với 0 SAU softmax — nếu không các trọng số không còn tổng bằng 1.
- Quên positional encoding. Self-attention không biết thứ tự; nếu thiếu vị trí, hoán vị chuỗi input cho cùng output.
- O(n²) âm thầm. Context dài gấp 4 → bộ nhớ gấp 16. Luôn kiểm tra peak GPU memory, đừng chỉ nhìn param count.
Code mẫu: gradient tự nhiên qua attention
"""
Ta không tự viết backward cho attention — autograd lo việc đó. Nhưng
hiểu đường đi của gradient giúp bạn debug:
∂L/∂V = weights.T @ ∂L/∂output
∂L/∂W_i,: = softmax_grad(scores_i,:) * (V @ ∂L/∂output_i)
∂L/∂Q,∂K = đi ngược qua QK^T / sqrt(d_k)
Đoạn PyTorch minh hoạ — nếu một trong các tensor không requires_grad,
gradient sẽ chặn tại đó (điều hay xảy ra khi ta freeze embedding).
"""
import torch, torch.nn.functional as F
def attention(Q, K, V, mask=None):
d_k = Q.size(-1)
s = Q @ K.transpose(-2, -1) / d_k ** 0.5
if mask is not None:
s = s.masked_fill(~mask, float("-inf"))
w = F.softmax(s, dim=-1)
return w @ V, w
torch.manual_seed(0)
X = torch.randn(2, 5, 8, requires_grad=True) # (batch, n, d_model)
Wq = torch.randn(8, 4, requires_grad=True) * 0.1
Wk = torch.randn(8, 4, requires_grad=True) * 0.1
Wv = torch.randn(8, 4, requires_grad=True) * 0.1
out, _ = attention(X @ Wq, X @ Wk, X @ Wv)
loss = out.pow(2).mean()
loss.backward()
print("‖∂L/∂X‖ =", X.grad.norm().item())
print("‖∂L/∂Wq‖ =", Wq.grad.norm().item())
# Thay requires_grad của Wq bằng False → gradient không chảy tới Wq.Biến thể attention bạn nên biết
- Additive attention (Bahdanau). Ra đời trước scaled dot-product; tính score qua một MLP nhỏ. Biểu cảm hơn nhưng đắt, hiếm dùng trong Transformer hiện đại.
- Multi-Query Attention (MQA). Nhiều head chia sẻ cùng một K, V — giảm mạnh chi phí KV cache trong inference LLM.
- Grouped-Query Attention (GQA). Dung hoà giữa MHA và MQA: chia head thành G nhóm, mỗi nhóm chia sẻ K, V. Được Llama 2/3 dùng làm mặc định.
- Linear Attention. Khai triển softmax qua kernel feature-map, đưa độ phức tạp về O(n·d²). Đánh đổi precision, phù hợp streaming / real-time.
- Sparse / Local Attention. Chỉ attend vào lân cận hoặc một mẫu rời rạc. Longformer, BigBird, Mistral sliding window đều dùng ý này.
- Mamba / State Space Models. Không phải attention, nhưng cạnh tranh trực tiếp cho ngữ cảnh dài — xem state space models.
Thuật ngữ liên quan
- Multi-head attention — chạy h bản self-attention song song với các tham số khác nhau, rồi concat.
- Positional encoding — thêm thông tin vị trí vào embedding vì self-attention không biết thứ tự. Gồm sinusoidal (gốc) và RoPE / ALiBi (hiện đại).
- Flash Attention — kỹ thuật tile-based giúp tính attention không hiện vật ma trận n×n đầy đủ, tiết kiệm bộ nhớ và nhanh hơn nhiều trên GPU.
- KV cache — lưu K, V của các token đã sinh, tránh tính lại khi generate từng token tiếp theo trong decoder.
- Attention mechanism — bài giới thiệu tổng quát cho attention (Bahdanau, Luong) trước khi vào self-attention.
- Mỗi token sinh ra 3 vector: Query (hỏi gì?), Key (chứa gì?), Value (nội dung). Attention = softmax(QKᵀ/√d_k) · V.
- Mỗi token nối trực tiếp với mọi token khác — path length = 1, nắm bắt phụ thuộc xa tốt hơn RNN.
- Tính được song song trên GPU thông qua một phép nhân ma trận — đây là lý do Transformer train nhanh hơn LSTM.
- Nhược điểm O(n²) về bộ nhớ và tính toán — giới hạn context window; Flash Attention, Sparse Attention giải quyết phần nào.
- Chia √d_k giữ softmax không bão hoà; masked attention chặn decoder nhìn tương lai; cross-attention đổi nguồn cho Q.
- Là trái tim của Transformer (GPT, BERT, Llama, Claude, Gemini, ViT) và bước mở rộng tự nhiên thành multi-head attention.
Kiểm tra hiểu biết
Mỗi token tạo ra 3 vector Q, K, V. Chúng đóng vai trò gì?