Knowledge Distillation
Chưng cất kiến thức
Mô hình GPT-4 rất giỏi nhưng quá lớn để chạy trên điện thoại. Cách nào tốt nhất để tạo phiên bản nhỏ mà vẫn giữ được phần lớn năng lực?
Hãy tưởng tượng một giáo sư ngôn ngữ học đang chỉ dẫn một sinh viên mới vào nghề. Có hai cách dạy hoàn toàn khác nhau.
Cách cổ điển (hard label) — giáo sư chỉ cho sinh viên xem đáp án đúng của từng bài tập: "câu 1 đáp án là A, câu 2 đáp án là B". Sinh viên thuộc bảng đáp án, nhưng không hiểu vì sao A đúng, cũng không biết B và C khác nhau ở đâu. Khi gặp bài tập mới một chút, sinh viên bị động.
Cách chưng cất (soft label) — giáo sư không chỉ nói "đáp án là A", mà còn thì thầm thêm: "mà B cũng khá hợp lý đấy, chỉ thiếu một chi tiết; còn C thì rõ ràng sai vì vi phạm quy tắc này". Sinh viên học được cấu trúc của bài toán, không chỉ đáp án. Lần sau gặp biến thể mới, sinh viên biết dùng nguyên tắc tương tự để suy ra.
Knowledge Distillation chính là cách thứ hai, áp dụng vào mạng nơ-ron. Teacher (mô hình lớn đã được huấn luyện tốt) không chỉ cung cấp label đúng, mà còn cung cấp toàn bộ phân bố xác suất trên mọi lớp — mỗi con số nhỏ trong phân bố là một gợi ý về cách nhìn thế giới của teacher. Student hấp thụ các gợi ý đó và đạt tới năng lực vượt xa việc học hard label thông thường — dù kích thước chỉ bằng một nửa, một phần ba hay thậm chí một phần mười.
Cái đẹp của ẩn dụ này là ta có thể đẩy nó đi xa hơn: teacher không chỉ "thì thầm đáp án", mà có thể chỉ cả cách làm — ánh mắt họ nhìn vào đâu, họ chú ý đến chi tiết nào (tương tự như attention map), họ đã tinh chỉnh qua những tầng suy nghĩ nào (tương tự hidden state trung gian). Khi student khớp cả quá trình đó — không chỉ output cuối — ta gọi là pattern distillation, một phát triển mạnh mẽ hơn distillation gốc.
Hình minh họa
Kiến trúc Teacher vs Student
Teacher sâu & rộng; Student nông & hẹp. Mũi tên xám chỉ ánh xạ pattern distillation giữa các tầng tương ứng.
Soft Labels thay đổi theo nhiệt độ
Kéo thanh trượt nhiệt độ T và quan sát phân bố xác suất thay đổi cho cả teacher lẫn student.
Entropy teacher
1.355
Càng cao = phân bố càng mềm
KL(student || teacher)
0.0003
Loss distillation chính là đại lượng này
Nhân bù T²
16×
Hinton nhân T² để gradient không bị co lại
Nhiệt độ vừa: Phân bố mềm hơn — student học được mối quan hệ giữa các lớp.
Accuracy: có & không có distillation
Kết quả minh hoạ trên một test set 10.000 mẫu. Click vào từng cột để xem ghi chú chi tiết về cấu hình.
Student + KD feat+logit
Student + KD logit + feature — thêm pattern distillation ở các tầng trung gian. Chỉ thua teacher 2.4 điểm nhưng nhỏ hơn ~4.4 lần.
Teacher cho soft label: [0.8, 0.15, 0.04, 0.01]. Hard label là [1, 0, 0, 0]. Thông tin gì bị mất khi dùng hard label?
Bạn huấn luyện student với α = 0 (chỉ soft loss, không dùng hard label). Sau 10 epoch, student bắt đầu phạm các lỗi 'bịa' mà teacher không hề có. Vấn đề gì?
Giải thích
Knowledge Distillation(Hinton, Vinyals & Dean, 2015) huấn luyện student trên hỗn hợp hard label và soft label từ teacher. Đây là một lựa chọn nén mô hình thay thế cho quantization và pruning, với ưu điểm đặc biệt: student có thể có kiến trúc hoàn toàn khác với teacher.
Hàm loss tổng quát của distillation:
Trong đó là temperature, là logits của teacher và student, cân bằng giữa hard và soft loss. Nhân để gradient không bị thu nhỏ khi T tăng — đây là chi tiết kỹ thuật Hinton đưa ra ngay trong paper gốc.
Quy trình distillation cơ bản gồm ba bước:
- Teacher dự đoán: Mô hình lớn đã huấn luyện tốt chạy inference trên tập dữ liệu, tạo soft labels (logit hoặc probability sau softmax).
- Làm mềm phân bố: Chia logits cho T trước softmax — T cao = phân bố phẳng hơn, tiết lộ thứ hạng tương đối của mọi lớp.
- Student học: Huấn luyện trên cả hard labels (ground truth) và soft labels (teacher) với trọng số .
import torch.nn.functional as F
def distillation_loss(
student_logits,
teacher_logits,
labels,
T: float = 4.0,
alpha: float = 0.5,
):
"""Hàm loss distillation chuẩn Hinton 2015.
Args:
student_logits: output thô của student (chưa softmax), shape (B, C)
teacher_logits: output thô của teacher, shape (B, C)
labels: ground truth hard labels, shape (B,)
T: temperature — càng cao, phân bố càng mềm
alpha: trọng số hard loss vs soft loss
Trả về scalar loss để backprop.
"""
# ---- Soft loss: KL divergence trên phân bố đã làm mềm ----
soft_t = F.softmax(teacher_logits / T, dim=-1)
log_soft_s = F.log_softmax(student_logits / T, dim=-1)
soft_loss = F.kl_div(
log_soft_s,
soft_t,
reduction="batchmean",
) * (T ** 2) # nhân T^2 để bù gradient
# ---- Hard loss: cross-entropy với ground truth ----
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * hard_loss + (1 - alpha) * soft_loss
# ---- Vòng lặp huấn luyện ngắn gọn ----
for batch in loader:
x, y = batch
with torch.no_grad():
t_logits = teacher(x) # teacher đóng băng
s_logits = student(x)
loss = distillation_loss(s_logits, t_logits, y, T=4.0, alpha=0.5)
optimizer.zero_grad()
loss.backward()
optimizer.step()Với các biến thể pattern distillation (TinyBERT, MobileBERT, DistilBERT), ta bổ sung thêm các loss trung gian — ví dụ buộc student khớp hidden state và attention map của teacher tại một số tầng tương ứng:
import torch
import torch.nn.functional as F
class PatternDistiller(torch.nn.Module):
"""Distill cả output cuối lẫn hidden state trung gian.
Các tầng của student và teacher được ánh xạ qua student_to_teacher.
"""
def __init__(self, student, teacher, student_to_teacher: dict[int, int]):
super().__init__()
self.student = student
self.teacher = teacher
self.teacher.eval()
for p in self.teacher.parameters():
p.requires_grad_(False)
self.pattern_map = student_to_teacher
# Projection nếu kích thước hidden khác nhau
self.proj = torch.nn.Linear(
student.hidden_size,
teacher.hidden_size,
bias=False,
)
def forward(self, x, labels, T=4.0, alpha=0.5, beta=0.1):
s_out = self.student(x, output_hidden_states=True)
with torch.no_grad():
t_out = self.teacher(x, output_hidden_states=True)
# 1) Logit loss (Hinton KD)
logit_loss = (
F.kl_div(
F.log_softmax(s_out.logits / T, dim=-1),
F.softmax(t_out.logits / T, dim=-1),
reduction="batchmean",
)
* (T ** 2)
)
ce = F.cross_entropy(s_out.logits, labels)
# 2) Hidden-state loss — MSE giữa các tầng tương ứng
hidden_loss = 0.0
for s_idx, t_idx in self.pattern_map.items():
s_h = self.proj(s_out.hidden_states[s_idx])
t_h = t_out.hidden_states[t_idx]
hidden_loss = hidden_loss + F.mse_loss(s_h, t_h)
return alpha * ce + (1 - alpha) * logit_loss + beta * hidden_lossỨng dụng thực tế — distillation được dùng ở hầu hết quy trình triển khai mô hình lớn ra sản phẩm:
- Serving giá rẻ: Teacher GPT-4 → distill thành mô hình 7B cho các tác vụ cụ thể (dịch, tóm tắt, classify). Giảm chi phí 100-1000 lần.
- On-device AI: DistilBERT và MobileBERT chạy trên điện thoại, TinyStories nhỏ đến mức chạy CPU. Latency vài ms.
- Edge computing: Student 1-3M tham số triển khai trên microcontroller cho IoT — teacher có thể là mô hình cloud.
- Ensemble compression: Distill một ensemble 10 mô hình thành một mô hình đơn — giữ 95% năng lực mà chi phí inference chỉ bằng 1/10.
- Privacy-preserving training: Teacher train trên dữ liệu nhạy cảm; student chỉ học qua soft labels — tránh rò rỉ dữ liệu gốc.
Các pitfall thường gặp khi triển khai:
- Teacher yếu: Nếu teacher chỉ đạt 80% accuracy, student hiếm khi vượt quá teacher. Chọn teacher đủ mạnh trước khi distill.
- Capacity gap quá lớn: Teacher 100B → student 100M quá xa; soft labels trở nên "quá phức tạp" so với dung lượng student. Cách khắc phục: dùng teacher assistant — một mô hình trung gian (ví dụ 10B) làm cầu nối.
- Quên đóng băng teacher: Một lỗi phổ biến — teacher không được set
requires_grad = False, dẫn tới cả hai cùng cập nhật và loss trở nên vô nghĩa. - α sai: α = 0 khiến student bắt chước mọi sai lầm của teacher; α = 1 thì distillation mất tác dụng. Bắt đầu với α = 0.5 rồi tinh chỉnh.
- T không đổi: Một số paper cho thấy dùng T schedule (cao lúc đầu, thấp lúc cuối) cải thiện kết quả, giống curriculum learning.
- Dữ liệu distillation kém chất lượng: Nếu chỉ dùng dữ liệu gốc, dark knowledge bị giới hạn. Dùng thêm dữ liệu không nhãn (unlabeled) — teacher tạo label cho nó — thường cải thiện mạnh.
- Distillation dạy student bắt chước teacher qua soft labels — không chỉ đáp án mà cả quá trình suy luận qua phân bố xác suất.
- Dark knowledge: thông tin ẩn trong xác suất nhỏ (15% chó, 10% thỏ) giúp student tổng quát hoá tốt hơn hard labels.
- Temperature T kiểm soát độ mềm: T thấp (1-3) = sắc nét, T vừa (4-10) = cân bằng tốt nhất, T cao (>10) = tín hiệu yếu.
- Loss = α·CE(hard) + (1-α)·T²·KL(soft_t || soft_s). Nhân T² để bù gradient — chi tiết kỹ thuật quan trọng.
- Pattern distillation (TinyBERT, MobileBERT) khớp cả hidden state và attention — cải thiện 2-4% so với logit-only.
- DistilBERT: nhỏ hơn 40%, nhanh hơn 60%, giữ 97% hiệu suất — minh chứng sức mạnh của distillation ở quy mô công nghiệp.
Kiểm tra hiểu biết
Tại sao student học từ soft labels tốt hơn hard labels?