Flash Attention
Flash Attention
Attention matrix cho 128K tokens cần 128K × 128K × 4 bytes ≈ 64GB bộ nhớ. GPU A100 chỉ có 80GB. Gần hết! Làm sao tính attention mà không cần lưu toàn bộ ma trận N×N?
Hãy tưởng tượng bạn cần so sánh 1000 hồ sơ Shopee với nhau. Cách cũ (self-attention gốc): trải hết 1000 hồ sơ ra sàn nhà khổng lồ, so từng cặp — mất đất và mất thời gian vì phải chạy ra chạy vô. Cách Flash: lấy ra 50 hồ sơ, so sánh trên bàn nhỏ (SRAM), ghi kết quả, cất lại, lấy 50 hồ sơ tiếp. Bàn nhỏ hơn nhiều mà kết quả chính xác — vì bạn tích lũy tổng chứ không bỏ hồ sơ nào!
Hình minh họa
Với N=128K, fp16: Standard 32 GB · Flash ~0.5 MB.
FA3 trên H100 đạt ~75% peak FP16, hoặc 1.2 PFLOPS với FP8.
Sequence 8K — Flash nhanh ~3×; bộ nhớ giảm 60×.
Flash Attention không thay đổi toán học — kết quả CHÍNH XÁC giống standard attention. Nó chỉ thay đổi cách tính: chia thành blocks nhỏ, tính trong SRAM (nhanh), dùng online softmax để tích lũy. Giảm IO = giảm thời gian thật sự!
Insight sâu hơn: GPU hiện đại tắc ở bộ nhớ (memory-bound) chứ không phải tính toán (compute-bound). Tỷ lệ FLOPS/bandwidth của A100 là ~150; tức GPU có thể làm 150 phép nhân cho mỗi byte đọc. Attention chuẩn có arithmetic intensity thấp hơn giá trị đó → bỏ phí compute. Flash Attention nâng arithmetic intensity bằng cách “tái dùng” data trong SRAM nhiều lần — gần đến peak compute của GPU.
Registers
1 chu kỳNhanh nhất, nằm ngay trong lõi CUDA. Giữ các biến tạm trong kernel.
Lưu ý: Không thể cấp phát lớn — compiler quyết định phân bổ.
SRAM (Shared Memory)
~20 chu kỳBộ nhớ chia sẻ trong mỗi SM. Nhanh hơn HBM ~10 lần. Flash Attention giữ block Q, K, V ở đây.
Lưu ý: Chỉ 192 KB / SM — vì vậy block size phải nhỏ.
L2 Cache
~200 chu kỳCache chung cho toàn GPU. Tự động quản lý, không lập trình trực tiếp.
Lưu ý: Hữu ích cho các tensor dùng lặp.
HBM (Device Memory)
~400 chu kỳBộ nhớ chính GPU. Lớn, nhưng 'xa' — phần lớn thời gian GPU chờ HBM.
Lưu ý: Mọi tensor nằm đây giữa các kernel call.
DRAM (Host)
hàng µsRAM hệ thống. Đi qua PCIe/NVLink — rất chậm so với GPU memory.
Lưu ý: Cần tránh data transfer host↔device trong hot path.
Flash Attention không bỏ phần tử nào, không xấp xỉ. Nó chỉ thay đổi thứ tự tính toán: thay vì tính toàn bộ hàng softmax → tính từng block + cập nhật online softmax. Toán học tương đương, IO ít hơn → nhanh hơn. Điều này giống việc bạn cộng số trong Excel theo từng cột rồi tổng lại — kết quả cuối giống hệt cộng tất cả cùng lúc.
Arithmetic intensity = FLOPs / bytes loaded. Một operation có AI cao nghĩa là mỗi byte load được tái dùng nhiều lần → gần peak compute. Attention chuẩn có AI thấp vì nó chỉ load data một lần, compute, write back. Flash giữ data trong SRAM → tái dùng qua nhiều inner loop → AI tăng lên, memory-bound biến thành compute-bound.
Chuỗi 128K tokens. Standard attention cần 128K × 128K × 4 bytes ≈ 64GB cho attention matrix. Flash Attention cần bao nhiêu?
Bạn train Llama với torch.compile + F.scaled_dot_product_attention nhưng GPU utilization chỉ 35%, profiler thấy nhiều _aten::attention native. Lỗi đâu?
Giải thích
Flash Attention (Dao et al., 2022) là thuật toán IO-aware exact attention. Không thay đổi toán học, chỉ tối ưu cách tính trên hardware. Ý tưởng cốt lõi: phần lớn thời gian GPU không làm toán — nó đợi data đi lại giữa HBM và SRAM. Tránh round-trip đó là thắng.
Standard Attention — IO bottleneck
3 lần ghi/đọc HBM cho ma trận N×N. Memory = O(N²). Đây là HBM trip thứ 1 (ghi S), thứ 2 (đọc S + ghi P), thứ 3 (đọc P + ghi O) — cộng dồn rất đáng kể khi N lớn.
Flash Attention — tiling + online softmax
Mỗi block tính hoàn toàn trong SRAM. Online softmax cập nhật running max/sum. Memory = O(N). Không có ma trận P cỡ N×N bao giờ xuất hiện trong HBM!
Online softmax — kỹ thuật cốt lõi
Để tính softmax ổn định, ta cần rồi . Online softmax cho phép tính dần khi chỉ thấy một phần của x mỗi lần. Khi gặp giá trị mới :
Output cũng được rescale tương ứng khi m thay đổi. Kết quả cuối cùng hoàn toàn chính xác — không sai số so với tính một lần với toàn bộ hàng.
FA2 (Dao 2023): đảo loop order (outer loop theo Q block), song song hóa theo sequence length chứ không chỉ theo head/batch, giảm non-matmul FLOPs → khoảng 2× nhanh hơn FA1. Đạt 50–70% peak A100 FP16 TFLOPs.
FA3(Shah & Dao 2024): thiết kế cho Hopper/H100. Tận dụng Tensor Memory Accelerator (TMA) cho async load, WGMMA warp-group matmul, FP8 cho bandwidth 2× nữa. Kết quả: ~75% peak FP16 hoặc 1.2 PFLOPS FP8 — gần giới hạn vật lý của GPU. Mọi LLM hiện đại (GPT-4, Claude, Llama, Gemini) đều dùng Flash Attention, thường kết hợp với KV cache và Transformer architecture.
Flash Attention KHÁC với Linear Attention, Performer, Linformer. Những cái đó là xấp xỉ — dùng kernel trick hoặc low-rank projection để giảm O(N²) → O(N). Flash Attention là exact — kết quả đồng nhất với standard. Bạn có thể dùng Flash làm drop-in replacement mà không mất độ chính xác của mô hình.
Mặc định luôn dùng. PyTorch 2.0+ tự động chọn Flash khi có thể. Bạn chỉ cần đảm bảo: (1) dtype là fp16 hoặc bf16; (2) head_dim ≤ 128 (FA1/FA2) hoặc ≤ 256 (FA3); (3) mask nằm trong các pattern hỗ trợ (causal, padding, sliding window) — nếu custom, dùng FlexAttention.
Flash Attention là một trong những optimization có tác động lớn nhất trong hệ sinh thái LLM. Nó cho phép training và serving context dài (32K–1M tokens) trở nên kinh tế. Không có Flash, Retrieval-Augmented Generation trên tài liệu dài, long-form writing, hoặc analyzing codebase đều sẽ đắt hơn hàng chục lần.
# Cách dùng Flash Attention trong PyTorch
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
# PyTorch 2.0+ có built-in Flash Attention!
# Tự động chọn Flash Attention khi có thể.
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16) # (B, heads, N, d_k)
k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
# Tự động dùng Flash Attention nếu available
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Memory: O(N) thay vì O(N²)
# Speed: 2–4× nhanh hơn naive implementation
# Ép dùng Flash backend (báo lỗi nếu không được) để debug
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Hoặc dùng thư viện flash-attn trực tiếp (chi tiết hơn)
# pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True, dropout_p=0.0, softmax_scale=1.0 / (64 ** 0.5))
# So sánh memory:
# N=4096: Standard ~64MB, Flash ~0.5MB (giảm 128×)
# N=128K: Standard ~64GB, Flash ~0.5MB (giảm 130,000×)
# Pseudocode Flash Attention forward pass
# Tham khảo: Dao et al. 2022, Algorithm 1
# N: seq length, d: head dim, Br/Bc: block row/col sizes
import math
import torch
def flash_attention_forward(Q, K, V, Br=64, Bc=64, causal=False):
"""
Q, K, V: (N, d) tensors trên HBM
Trả về O: (N, d) — tương đương softmax(QK^T / sqrt(d)) V
Bộ nhớ thêm: O(N) (running stats), KHÔNG phải O(N^2)
"""
N, d = Q.shape
scale = 1.0 / math.sqrt(d)
# Output + running stats, chỉ O(N) bộ nhớ
O = torch.zeros_like(Q) # (N, d)
l = torch.zeros(N, device=Q.device) # running sum of exp
m = torch.full((N,), -float("inf"), device=Q.device) # running max
Tr = (N + Br - 1) // Br # số block của Q
Tc = (N + Bc - 1) // Bc # số block của K, V
# OUTER loop trên Q blocks (FA2 đổi thứ tự: outer là Q)
for i in range(Tr):
qi_s, qi_e = i * Br, min((i + 1) * Br, N)
Qi = Q[qi_s:qi_e] # (Br, d) — load SRAM
Oi = torch.zeros(qi_e - qi_s, d, device=Q.device)
li = torch.zeros(qi_e - qi_s, device=Q.device)
mi = torch.full((qi_e - qi_s,), -float("inf"), device=Q.device)
# INNER loop trên K,V blocks
for j in range(Tc):
if causal and j * Bc > qi_e - 1:
break
kj_s, kj_e = j * Bc, min((j + 1) * Bc, N)
Kj = K[kj_s:kj_e] # (Bc, d)
Vj = V[kj_s:kj_e] # (Bc, d)
# Attention scores cho tile này — hoàn toàn trong SRAM
Sij = Qi @ Kj.transpose(0, 1) * scale # (Br, Bc)
if causal:
mask = torch.arange(kj_s, kj_e, device=Q.device) > torch.arange(qi_s, qi_e, device=Q.device)[:, None]
Sij.masked_fill_(mask, -float("inf"))
# ── Online softmax ───────────────────────────────
mij_local = Sij.max(dim=-1).values # (Br,)
mi_new = torch.maximum(mi, mij_local)
Pij = torch.exp(Sij - mi_new[:, None]) # (Br, Bc)
lij_local = Pij.sum(dim=-1) # (Br,)
alpha = torch.exp(mi - mi_new) # rescale factor cho O cũ
li = alpha * li + lij_local
Oi = alpha[:, None] * Oi + Pij @ Vj
mi = mi_new
# ─────────────────────────────────────────────────
# Ghi kết quả block Q này ra HBM
O[qi_s:qi_e] = Oi / li[:, None]
l[qi_s:qi_e] = li
m[qi_s:qi_e] = mi
return O, l, m # l, m dùng lại ở backward
- LLM context 128K+ tokens: GPT-4 Turbo, Claude 200K, Gemini 1.5 (1M tokens) đều dùng Flash Attention hoặc biến thể. Không có Flash, KV cache + attention matrix sẽ vượt VRAM.
- Huấn luyện Llama / Mistral / Qwen: Flash Attention 2 tăng throughput training ~2×, tiết kiệm ~40% memory — nghĩa là có thể fit batch lớn hơn hoặc sequence dài hơn trên cùng một cluster.
- Serving vLLM, TGI, SGLang: Các engine serving đều dùng Flash Attention cho prefill (encode prompt dài) kết hợp paged attention cho decode. Giảm tail latency đáng kể.
- Vision Transformer trên ảnh lớn: Ảnh 1024×1024 patch 16×16 → 4K tokens. Flash giúp train ViT-Huge trên ảnh HD không OOM.
- Protein folding / AlphaFold-like: Sequence amino acid có thể dài hàng ngàn. Memory-efficient attention là bắt buộc.
- Whisper / ASR long-form: Xử lý audio dài 30 phút = chuỗi token dài. Flash Attention giữ memory dưới ngưỡng GPU tiêu dùng.
- Kỳ vọng Flash thần kỳ với chuỗi ngắn: Dưới 1K tokens, Flash Attention có thể không nhanh hơn — compute đủ rẻ, overhead tile management chiếm tỷ trọng lớn. Speedup thực sự lộ ra khi N ≥ 4K.
- Dùng sai dtype: Flash Attention chính thức yêu cầu fp16 hoặc bf16. Nếu vô tình chạy fp32, PyTorch sẽ fallback về kernel thường — tưởng đã dùng Flash nhưng không. Luôn kiểm tra với profiler.
- Head dim > 128 không hỗ trợ: FA1/FA2 yêu cầu head_dim ≤ 128 (phải vừa SRAM). FA3 mở rộng lên 256 trên H100. Nếu mô hình có head lớn, phải fallback hoặc chia.
- Mask phức tạp: Custom attention mask (không phải causal/padding) khó tận dụng Flash. Dùng FlexAttention (PyTorch 2.5+) — compile mask thành kernel Flash tối ưu.
- Flash Attention = exact attention (không xấp xỉ) nhưng IO-efficient: chia thành blocks, tính trong SRAM nhanh gấp 10× HBM.
- Giảm bộ nhớ O(N²) → O(N): không lưu ma trận attention N×N đầy đủ, chỉ giữ running stats cho online softmax.
- Nhanh hơn 2–4× nhờ giảm HBM IO trips. Insight: GPU tắc ở memory bandwidth, không phải compute — giảm IO quan trọng hơn giảm FLOPs.
- Online softmax: cập nhật running max + running sum qua mỗi block, rescale output tương ứng → kết quả đồng nhất với softmax một lần.
- FA1 (2022) → FA2 (2023, 2× nhanh hơn, song song tốt hơn) → FA3 (2024, tận dụng H100 TMA/WGMMA/FP8, ~75% peak).
- Mọi LLM hiện đại dùng Flash Attention — GPT-4, Claude, Llama, Mistral, Gemini. Cho phép context 128K–1M tokens kinh tế.
Kiểm tra hiểu biết
Standard attention tốn O(N²) bộ nhớ. Flash Attention giảm xuống bao nhiêu?