Attention Mechanism
Attention - Cơ chế chú ý
Trước khi nhìn vào công thức, hãy thử tưởng tượng bạn là phiên dịch viên. Bạn đang đọc câu "Tôi yêu Việt Nam quá" và chuẩn bị dịch từng từ ra tiếng Anh. Khi bạn chạm tới từ "Nam", mắt bạn lập tức liếc về đâu?
Khi dịch token "Nam", não bạn ưu tiên NHÌN VÀO đâu nhất trong câu nguồn?
Chạm vào một token để xem nó chú ý vào những token nào. Kéo slider để điều chỉnh temperature của softmax.
Hình minh họa
Chọn Query token (hàng trong ma trận)
Ma trận attention 5×5 (hàng = query, cột = key)
Số trong ô = xác suất attention (%). Hàng đang chọn được viền accent. Tổng mỗi hàng = 100%.
Attention của "yêu" trên 5 token
T nhỏ (→ 0): softmax hội tụ về argmax (one-hot, quyết đoán). T lớn (→ ∞): phân phối phẳng (do dự, chú ý đều). Transformer dùng mặc định T = 1 sau khi đã chia √d_k.
Ba bước của scaled dot-product attention
① Điểm thô eij = Qi · Kj^T / √d_k
② αij = softmax(eij) — tổng = 1
③ contexti = Σj αij · Vj
Vector ngữ cảnh của "yêu" là hỗn hợp có trọng số của 5 vector value. Đây chính là output attention sẽ truyền sang lớp kế.
Khi xử lí "yêu", mô hình chú ý nhất vào "yêu" (42.8%). Tổng weights = 1.00.
Attention biến mỗi token thành một người hỏi. Token đặt câu hỏi (query), so sánh với mọi token khác (key) để quyết định mình nên nghe ai. Câu trả lời không phải là một từ, mà là một hỗn hợp có trọng số của giá trị các token khác.
Đây là lí do cơ chế này giải bài toán phụ thuộc xa (long-range dependency) mà RNN vật vã suốt thập kỉ — bất kể hai token cách nhau bao xa, attention chỉ mất một phép nhân ma trận để kết nối chúng.
Một hàng attention weights cộng lại bằng 1.07. Điều gì đã sai?
1. Tính điểm (Score)
So sánh query với mỗi key để ra điểm 'tương thích'.
2. Softmax (Normalize)
Biến điểm thành phân phối xác suất — tổng = 1.
3. Tổng có trọng số giá trị
Kết hợp các value vector theo trọng số α.
Khi attention được dùng trong decoder của Transformer (sinh văn bản từ trái sang phải), ta phải mask các key ở tương lai: điểm eijvới j > i được đặt bằng −∞ trước softmax → α = 0. Nếu không làm vậy, mô hình "gian lận" bằng cách nhìn vào đáp án.
Giải thích
Cơ chế Attention được giới thiệu bởi Bahdanau và đồng nghiệp (2014) cho bài toán dịch máy. Ý tưởng ban đầu: cho decoder của mô hình Seq2Seq"nhìn lại" mọi trạng thái ẩn của encoder thay vì chỉ dùng một context vector. Ba năm sau, Vaswani et al. (2017) trong bài "Attention is All You Need" tổng quát hoá thành self-attention: token nhìn vào chính chuỗi của mình. Đây là nền tảng của Transformer, GPT, BERT, và mọi LLM hiện đại.
Additive (Bahdanau 2014): — học qua mạng MLP nhỏ, linh hoạt nhưng chậm.
Dot-product (Luong 2015): — đơn giản, nhanh, nhưng không ổn định khi d lớn.
Scaled dot-product (Transformer 2017): — chia √d_k giải quyết vấn đề variance bùng nổ.
Attention cổ điển: Q từ decoder, K/V từ encoder. Gọi là cross-attention.
Self-Attention: Q, K, V đều từ cùng một chuỗi — mỗi token nhìn vào tất cả token khác trong chính nó.
Multi-Head: chạy h đầu attention song song, mỗi đầu có W_Q, W_K, W_V riêng → học nhiều loại quan hệ cùng lúc (ngữ pháp, ngữ nghĩa, đồng quy chiếu...).
Ma trận attention có kích thước n × n (với n = độ dài chuỗi) → bộ nhớ và thời gian tính đều O(n²). Đây là lí do LLM bị giới hạn context. Các biến thể như FlashAttention, Linear Attention, Sliding Window Attention (Longformer) giảm xuống O(n log n) hoặc O(n).
Khi mô hình dịch sai, visualize ma trận attention là bước đầu tiên. Nếu thấy mô hình chú ý sai chỗ — ví dụ khi dịch "bank" (ngân hàng) mà attention dồn vào "river" thay vì "money" — bạn biết vấn đề nằm ở chỗ học, không phải ở decoder. BertViz và các công cụ như exbert hay attention-viewercho phép "mổ" từng đầu attention của các mô hình pre-trained.
Một cách nhìn hữu ích: xem attention như một hệ thống truy vấn cơ sở dữ liệu mềm. Query giống như câu truy vấn SQL. Key giống như chỉ mục (index) của mỗi hàng. Value là nội dung. Khác biệt duy nhất: thay vì trả về một hàng duy nhất khớp chính xác, attention trả về một hỗn hợp mờ — trọng số theo độ khớp giữa query và key.
Cụ thể hơn, trong self-attention của Transformer:
- Q = X · W_Q: mỗi token biến thành một câu hỏi.
- K = X · W_K: mỗi token biến thành một "nhãn" mô tả chính nó.
- V = X · W_V: mỗi token biến thành một "nội dung" sẵn sàng được tổng hợp.
- Ba ma trận W_Q, W_K, W_V được học qua backprop — mạng tự quyết định thế nào là "khớp".
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ScaledDotProductAttention(nn.Module):
"""Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""
def forward(self, q, k, v, mask=None):
# q: [B, H, L_q, d_k]
# k: [B, H, L_k, d_k]
# v: [B, H, L_k, d_v]
d_k = q.size(-1)
# 1) Raw scores = Q · K^T / sqrt(d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
# scores: [B, H, L_q, L_k]
# 2) Optional causal mask (decoder)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
# 3) Softmax → attention weights
attn = F.softmax(scores, dim=-1) # tổng theo L_k = 1
# 4) Weighted sum of V
out = torch.matmul(attn, v)
# out: [B, H, L_q, d_v]
return out, attn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.attn = ScaledDotProductAttention()
def forward(self, q, k, v, mask=None):
B = q.size(0)
# 1) Linear projection → split into heads
# [B, L, d_model] → [B, H, L, d_k]
q = self.W_q(q).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
k = self.W_k(k).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
v = self.W_v(v).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
# 2) Scaled dot-product attention trên tất cả các đầu song song
out, attn = self.attn(q, k, v, mask=mask)
# 3) Concat các đầu
out = out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
# 4) Projection cuối
return self.W_o(out), attn
# ─── Ví dụ nhanh trên câu "Tôi yêu Việt Nam quá"
if __name__ == "__main__":
seq_len, d_model, n_heads = 5, 64, 4
mha = MultiHeadAttention(d_model, n_heads)
# Giả lập embedding của 5 token
x = torch.randn(1, seq_len, d_model) # [B=1, L=5, d=64]
# Self-attention: Q, K, V đều = x
context, attn = mha(x, x, x)
print("Context shape:", context.shape) # [1, 5, 64]
print("Attention shape:", attn.shape) # [1, 4, 5, 5]
print("Sum per row (đầu 0):",
attn[0, 0].sum(dim=-1)) # tensor([1., 1., 1., 1., 1.])Một điểm thú vị nữa: trong các mô hình ngôn ngữ lớn hiện đại (GPT-4, Claude, Gemini), attention không chỉ được dùng để kết nối các token trong cùng một câu — nó còn kết nối các đoạn văn bản cách xa hàng nghìn token. Khi model đọc một tài liệu dài và trả lời câu hỏi về đoạn đầu, về mặt cơ học thì attention của token trả lời phải nhìn được về token của đoạn đầu. Đây là lí do kích thước context window (32K, 200K, 1M token) trở thành chỉ số quan trọng của LLM.
Tuy nhiên, chất lượng attention không đều trên mọi khoảng cách. Các thí nghiệm "needle in a haystack" (tìm kim trong đống rơm) cho thấy nhiều mô hình có độ chính xác giảm đáng kể ở giữa context dài — hiện tượng gọi là "lost in the middle" (Liu et al. 2023). Nghiên cứu hiện nay tập trung vào cải thiện positional encoding (RoPE, ALiBi), attention sparsity, và training strategy để attention phân bố đều hơn.
import matplotlib.pyplot as plt
import seaborn as sns
import torch
# Cho ma trận attention đã có từ model (shape [L, L])
def plot_attention(attn_matrix, tokens):
"""Vẽ heatmap attention như visualization ở phía trên."""
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
attn_matrix.cpu().detach(),
xticklabels=tokens,
yticklabels=tokens,
cmap="YlOrBr", # amber scale — giống UI của bài này
annot=True,
fmt=".2f",
cbar_kws={"label": "attention weight"},
ax=ax,
)
ax.set_xlabel("Key (từ được nhìn)")
ax.set_ylabel("Query (từ đang hỏi)")
ax.set_title("Attention matrix")
plt.tight_layout()
return fig
# Pipeline đầy đủ cho 1 câu
def attention_for_sentence(model, tokenizer, sentence):
tokens = tokenizer.tokenize(sentence)
ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
with torch.no_grad():
outputs = model(ids, output_attentions=True)
# outputs.attentions: tuple[L_layers] of [B, H, L, L]
# Lấy layer cuối, trung bình qua các đầu
attn = outputs.attentions[-1][0].mean(dim=0) # [L, L]
return plot_attention(attn, tokens)
# sentence = "Tôi yêu Việt Nam quá"
# attention_for_sentence(model, tokenizer, sentence)
# ─── Bonus: so sánh attention ở các layer khác nhau ───
def attention_across_layers(model, tokenizer, sentence):
"""Tạo subplot cho mỗi layer — xem attention evolving theo chiều sâu."""
tokens = tokenizer.tokenize(sentence)
ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
with torch.no_grad():
outputs = model(ids, output_attentions=True)
n_layers = len(outputs.attentions)
fig, axes = plt.subplots(
1, n_layers, figsize=(3 * n_layers, 3), sharey=True
)
for layer_idx, attn_layer in enumerate(outputs.attentions):
# Trung bình qua các đầu của layer này
attn = attn_layer[0].mean(dim=0).cpu()
ax = axes[layer_idx] if n_layers > 1 else axes
sns.heatmap(
attn,
xticklabels=tokens,
yticklabels=tokens if layer_idx == 0 else False,
cmap="YlOrBr",
cbar=False,
ax=ax,
square=True,
)
ax.set_title(f"Layer {layer_idx}")
fig.suptitle("Attention từ nông đến sâu")
plt.tight_layout()
return fig
# ─── Bonus 2: temperature sweep (giống slider ở phần viz) ───
def temperature_sweep(raw_scores, tokens, temperatures=(0.2, 1.0, 5.0)):
"""Vẽ 3 heatmap với T khác nhau để so sánh độ 'quyết đoán'."""
fig, axes = plt.subplots(1, len(temperatures), figsize=(4 * len(temperatures), 4))
for ax, T in zip(axes, temperatures):
scaled = raw_scores / T
attn = torch.softmax(scaled, dim=-1)
sns.heatmap(
attn,
xticklabels=tokens,
yticklabels=tokens,
cmap="YlOrBr",
annot=True,
fmt=".2f",
cbar=False,
ax=ax,
)
ax.set_title(f"T = {T}")
return figDòng chảy tiến hoá của attention
2014 — Bahdanau Attention
Bổ sung attention vào Seq2Seq RNN. Điểm số dùng MLP nhỏ (additive). Chứng minh attention cải thiện BLEU đáng kể cho bản dịch câu dài.
2015 — Luong Attention
Đơn giản hoá thành dot-product (nhanh hơn). Đề xuất global vs local attention — local chỉ nhìn cửa sổ nhỏ quanh từ hiện tại.
2017 — Transformer (Attention is All You Need)
Xoá RNN hoàn toàn. Chỉ dùng self-attention + multi-head + positional encoding. Mở đường cho kỉ nguyên LLM.
2018-2020 — BERT, GPT-2, T5
Ba kiến trúc chính: encoder-only (BERT), decoder-only (GPT), encoder-decoder (T5). Tất cả chỉ là attention + FFN.
2022 — FlashAttention
Tổ chức lại memory access để giảm I/O giữa HBM và SRAM của GPU. Attention O(n²) vẫn đúng về mặt lý thuyết, nhưng thực tế nhanh gấp 2-4 lần.
2023+ — Linear & Sparse Attention
Mamba (SSM), RWKV, Hyena, Longformer — các kiến trúc thay thế attention bằng cơ chế O(n) hoặc O(n log n) để mở rộng context hàng triệu token.
Mặc dù vậy, tại thời điểm 2026, attention vẫn là trái tim của gần như mọi LLM thương mại. Các cải tiến tập trung vào tối ưu (FlashAttention, paged attention trong vLLM) hoặc lai ghép (attention cho context ngắn, SSM cho context dài — như Jamba, Griffin).
Bạn có 1 head attention và 1 câu 1024 token. Trên GPU với 24GB VRAM, chuyện gì hay xảy ra?
- Mỗi token là một query; nó so khớp với mọi key để chọn 'nên nghe ai'.
- Ba bước: Q·K^T / √d_k → softmax → Σ α·V. Tổng mỗi hàng attention = 1.
- Chia √d_k giữ variance của điểm thô ~ 1 để softmax không bão hoà.
- Softmax temperature nhỏ → quyết đoán (one-hot); lớn → dàn đều.
- Multi-head = h đầu song song; mỗi đầu học một loại quan hệ khác.
- Độ phức tạp O(n²) là rào cản context dài — FlashAttention / sparse attention giải quyết.
Kiểm tra hiểu biết
Attention weights cho 'Nam' trên câu 'Tôi yêu Việt Nam quá' là [0.02, 0.04, 0.55, 0.36, 0.03]. Điều này nói lên điều gì?