Beam Search
Beam Search — tìm kiếm chùm tia
GPT sinh câu từ prompt "Tôi yêu ___". Bước 1 có 5 ứng viên: "em" (42%), "gia đình" (24%), "cà phê" (20%), "Việt Nam" (8%), "mèo" (6%). Nếu chỉ giữ 1 từ tốt nhất thì sao?
Kéo beam width và nhấn qua từng bước để xem thuật toán mở rộng rồi prune cây sinh token. Dòng sáng = được giữ, dòng mờ = bị loại. Thử cùng một câu với length penalty α khác nhau để thấy câu dài/ngắn được ưu tiên ra sao.
Hình minh họa
Prompt
“Tôi yêu ___”
Cùng một prompt, ba chiến lược giải mã cho ra câu khác nhau. Xem bảng dưới — lưu ý rằng beam search đôi khi tìm được xác suất tổng cao hơn greedy, và sampling cho ra câu khác hoàn toàn.
Greedy (k = 1)
“Tôi yêu em nhiều lắm.”
p ≈ 8.78%
Ưu: Cực nhanh, deterministic, chi phí thấp.
Nhược: Hay bỏ lỡ câu tốt hơn do tham ăn từng bước.
Dùng cho: Draft nhanh, khi chất lượng không quá quan trọng.
Beam (k = 3)
“Tôi yêu gia đình vô cùng.”
p ≈ 9.50%
Ưu: Chất lượng cao, ổn định. Tìm được sequence tốt hơn greedy.
Nhược: Chậm hơn k lần. Luôn cùng output (thiếu đa dạng).
Dùng cho: Dịch máy, tóm tắt, captioning, code completion.
Sampling (top-p = 0.9, T = 0.8)
“Tôi yêu cà phê buổi sáng.”
p ≈ 5.76%
Ưu: Đa dạng, sáng tạo, mỗi lần một khác.
Nhược: Kém ổn định, đôi khi sai ngữ nghĩa.
Dùng cho: ChatGPT, kể chuyện, brainstorm.
Kéo α từ 0 đến 2 để xem tác động. Đường đỏ là raw log-prob (giảm tuyến tính theo độ dài → luôn thiên vị câu ngắn). Đường accent là score sau chia cho |Y|^α.
Hình minh họa
Beam Search giữ k con đường tốt nhất song song, không “đặt cược” vào 1 lựa chọn duy nhất. Giống đội thám hiểm chia nhóm — nhóm nào tìm được đường tốt nhất thì thắng. Điều kỳ diệu: một ứng viên token có xác suất thấp ở bước 1 (ví dụ “gia đình” chỉ 24%) vẫn có thể dẫn đến câu tốt nhất, nếu các token tiếp theo của nó có xác suất điều kiện cao.
k = 1 → Greedy (nhanh, kém). k → ∞ → exhaustive (chậm, tối ưu). k = 4–10 thường là sweet spot cho dịch máy, k = 1 + sampling là chuẩn cho chatbot.
Beam search luôn cho kết quả GIỐNG NHAU cùng input. ChatGPT thì mỗi lần trả lời khác. ChatGPT dùng gì?
Bạn chạy beam search với k = 5 cho một câu dài 20 token. Vocab cỡ 50K. Mỗi bước cần xét bao nhiêu ứng viên trước khi prune?
Giải thích
Beam Search giữ k chuỗi ứng viên (beams) tốt nhất ở mỗi bước, mở rộng song song cho đến khi tất cả kết thúc (gặp <eos>) hoặc đạt độ dài tối đa. Đây là một chiến lược giải mã deterministic, trái ngược với các phương pháp stochastic như top-k / top-p sampling hay chỉnh temperature. Cùng input, beam search luôn cho cùng output.
Với = length penalty (0.6–1.0 phổ biến). → không penalty, thiên vị câu ngắn. → trung bình log-prob mỗi token.
Thay vì chọn top-k, top-p chọn tập nhỏ nhất đạt xác suất tích lũy ≥ p:
Top-p = 0.9 → chỉ chọn từ top 90% xác suất tích lũy. Kết hợp với temperature T để điều chỉnh “độ sáng tạo”.
- Mode collapse: k beam thường rất giống nhau (khác 1-2 từ). Giải pháp: diverse beam search.
- Hallucination cho câu dài: khi k lớn, beam có thể chọn câu “chắc chắn” nhưng sai thực tế.
- Không phù hợp cho hội thoại: luôn cùng output → nhàm chán. Dùng sampling.
- Chi phí O(k × V) mỗi bước — với vocab lớn và k lớn là tốn bộ nhớ.
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
prompt = "Tôi yêu "
inputs = tokenizer(prompt, return_tensors="pt")
# 1. Greedy (num_beams = 1, không sample)
greedy = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
)
print("Greedy:", tokenizer.decode(greedy[0], skip_special_tokens=True))
# 2. Beam search (num_beams = 5)
beam = model.generate(
**inputs,
max_new_tokens=20,
num_beams=5,
do_sample=False,
length_penalty=0.8, # 0.6–1.0 là phổ biến
early_stopping=True, # dừng khi đủ k beam đạt <eos>
no_repeat_ngram_size=3, # chặn lặp 3-gram (tránh mode collapse)
num_return_sequences=3, # trả về 3 beam tốt nhất
)
for i, seq in enumerate(beam):
print(f"Beam {i}:", tokenizer.decode(seq, skip_special_tokens=True))
# 3. Nucleus sampling (top_p = 0.9)
sample = model.generate(
**inputs,
max_new_tokens=20,
do_sample=True,
top_p=0.9,
top_k=50,
temperature=0.8,
num_return_sequences=3,
)
for i, seq in enumerate(sample):
print(f"Sample {i}:", tokenizer.decode(seq, skip_special_tokens=True))
# 4. Diverse beam search
diverse = model.generate(
**inputs,
max_new_tokens=20,
num_beams=6,
num_beam_groups=3, # chia k thành 3 nhóm
diversity_penalty=1.0, # hình phạt giữa các nhóm
num_return_sequences=3,
do_sample=False,
)
for i, seq in enumerate(diverse):
print(f"Diverse {i}:", tokenizer.decode(seq, skip_special_tokens=True))import torch
import torch.nn.functional as F
from typing import List, Tuple
def beam_search_decode(
model,
input_ids: torch.Tensor,
beam_width: int = 5,
max_length: int = 30,
length_penalty: float = 0.7,
eos_token_id: int = 2,
) -> List[Tuple[torch.Tensor, float]]:
"""
Beam search giải mã tự viết tay.
Trả về danh sách (sequence, score) đã sắp xếp.
"""
device = input_ids.device
# Mỗi beam: (sequence, cumulative_log_prob, is_done)
beams: List[Tuple[torch.Tensor, float, bool]] = [
(input_ids, 0.0, False)
]
for step in range(max_length):
candidates: List[Tuple[torch.Tensor, float, bool]] = []
for seq, score, done in beams:
# Beam đã gặp <eos> thì không mở rộng nữa.
if done:
candidates.append((seq, score, True))
continue
# Forward pass
with torch.no_grad():
logits = model(seq).logits[:, -1, :] # (1, V)
log_probs = F.log_softmax(logits, dim=-1).squeeze(0)
# Lấy top-k token
top_log_probs, top_indices = log_probs.topk(beam_width)
for lp, idx in zip(top_log_probs.tolist(), top_indices.tolist()):
new_seq = torch.cat(
[seq, torch.tensor([[idx]], device=device)],
dim=1,
)
new_score = score + lp
new_done = (idx == eos_token_id)
candidates.append((new_seq, new_score, new_done))
# Prune: lấy top beam_width theo score sau length penalty
def _rank(item):
seq, sc, _ = item
L = seq.size(1)
return sc / (L ** length_penalty)
candidates.sort(key=_rank, reverse=True)
beams = candidates[:beam_width]
# Nếu tất cả beam đã done thì dừng
if all(done for _, _, done in beams):
break
# Trả về (seq, score đã chia length penalty)
return [
(seq, sc / (seq.size(1) ** length_penalty))
for seq, sc, _ in beams
]- Beam Search giữ k ứng viên tốt nhất song song (k = beam width / num_beams).
- k = 1 → Greedy (nhanh, kém). k = 4–10 → sweet spot cho dịch máy, tóm tắt. k → ∞ → exhaustive (chậm).
- Length penalty chia score cho |Y|^α để tránh thiên vị câu ngắn; α ≈ 0.6–1.0 phổ biến.
- Dịch máy / tóm tắt / captioning → beam search. ChatGPT / kể chuyện → sampling (top-p + temperature).
- Beam search deterministic (cùng input → cùng output); sampling tạo đa dạng, mỗi lần một khác.
- Diverse BS giảm mode collapse; constrained BS cho phép ép/cấm cụm từ cụ thể.
Kiểm tra hiểu biết
Beam width = 1 tương đương thuật toán nào?