Calculus for Backpropagation in Model Training
Giải tích trong huấn luyện mô hình
Công ty nào đang ứng dụng Giải tích cho backprop?
Tháng 1 năm 2024, Meta bắt đầu huấn luyện LLaMA 3.1 405B — mô hình ngôn ngữ có 405 tỉ tham số. 16.384 GPU H100 chạy liên tục 54 ngày. Mỗi giây trôi qua, mỗi tham số đều được hỏi một câu duy nhất: “Nếu tôi nhích một tí, loss thay đổi bao nhiêu?”
Đó chính là đạo hàm. Và để trả lời câu hỏi đó cho cả 405 tỉ tham số trong vài trăm mili-giây, Meta dùng một vòng lặp 5 bước không thay đổi từ thời Rosenblatt:
Di chuột qua một bước để xem vai trò của nó trong vòng lặp:
Bài này cho bạn nhìn từng bước trong vòng lặp đó, với một mô hình nhỏ xíu — chỉ một weight — để bạn thấy đạo hàm thực sự làm gì khi learning rate quá to, quá nhỏ, và vừa đủ.
Vấn đề công ty cần giải quyết
LLaMA 3.1 405B phải học từ 15,6 nghìn tỉ token (đơn vị văn bản). Mỗi lần đi qua một batch, mô hình cần điều chỉnh 405 tỉ tham số — cộng lại là con số khổng lồ.
Vấn đề gốc rễ: không thể thử từng tham số một. Nếu mỗi lần muốn biết gradient của một tham số, ta phải chạy lại forward pass, thì sẽ cần 405 tỉ forward pass cho một bước cập nhật duy nhất — không thể nào xong trong một đời người. Chain rule biến bài toán đó từ không khả thi thành khả thi.
405 tỉ tham số
Mỗi tham số là một weight cần được điều chỉnh theo gradient của nó.
126 lớp nối tiếp
Mạng sâu → chain rule nhân hàng trăm đạo hàm cục bộ xuyên qua các lớp.
3,8 × 10²⁵ FLOP
Tổng phép tính. Không có chain rule + song song hoá, bài toán là bất khả thi.
Cách Giải tích cho backprop giải quyết vấn đề
Forward pass — dự đoán từ weights hiện tại. Đưa một batch dữ liệu (hàng triệu token) qua 126 lớp Transformer. Mỗi lớp thực hiện phép nhân ma trận rồi hàm kích hoạt. Cuối cùng, cross-entropy ra một con số duy nhất L — sai lệch hiện tại.
Forward pass qua các lớp Transformer. Batch (tokens) L1 L2 L3 L4 L5 L6 → ŷ 126 lớp Transformer nối tiếp, mỗi lớp nhân ma trận + activation Backward pass — chain rule trôi ngược. Đây là nơi đạo hàm vào cuộc. Thuật toán backprop đi ngược từ L qua từng lớp: ∂L/∂w = ∂L/∂aₙ · ∂aₙ/∂aₙ₋₁ · … · ∂a₁/∂w. Mỗi lớp chỉ cần tính đạo hàm cục bộ của mình, rồi nhân với gradient truyền từ lớp sau.
AdamW — bộ tối ưu thông minh hơn gradient descent cơ bản. Thay vì trừ thẳng gradient, AdamW theo dõi trung bình động của gradient và gradient bình phương, rồi điều chỉnh bước cập nhật cho từng tham số riêng. Kết quả: hội tụ ổn định hơn trên không gian 405 tỉ chiều, ít nhạy với lựa chọn learning rate.
Gradient checkpointing — đánh đổi tính toán lấy bộ nhớ. Mỗi GPU H100 chỉ có 80 GB HBM3 — không đủ lưu toàn bộ giá trị trung gian cho 126 lớp. Giải pháp: chỉ lưu tại một vài lớp “mốc”, phần còn lại tính lại khi backprop cần. Giảm bộ nhớ từ O(n) xuống O(√n) nhưng tăng khoảng 30% tính toán.
Song song hoá 16.384 GPU. Meta kết hợp ba chiến lược: tensor parallelism (chia từng lớp ra nhiều GPU), pipeline parallelism (chia các lớp thành nhóm nối tiếp), data parallelism (mỗi nhóm GPU xử lý batch riêng rồi đồng bộ gradient). Tất cả gradient vẫn được tính bằng cùng công thức chain rule — chỉ khác là được phân tán ra nghìn máy.
Thí nghiệm: kéo learning rate, xem loss hạ — hoặc nổ
Mô hình thu nhỏ: 1 weight, loss L(w) = (w − 2)² + 0.2. Đáp án: w = 2.
epoch loss L = 0.20 Hội tụ êm
Learning rate vừa phải: mỗi bước rút ngắn khoảng cách tới đáy thung lũng, loss giảm đều và dừng yên ở minimum.
0.10.0051.26-48Thực tế ở MetaCho LLaMA 3.1 405B, Meta dùng learning rate đỉnh khoảng 8 × 10⁻⁵ với lịch cosine warmup/decay. Không ai chọn số đó bằng tay — nó đến từ hàng trăm run thử nghiệm trên mô hình nhỏ, rồi scale theo công thức đã kiểm chứng. Một learning rate sai lệch 10× có thể làm hư toàn bộ run 54 ngày.Bên trong MỘT vòng lặp — dữ liệu di chuyển thế nào
Bấm “Tiếp tục” để thấy từng giai đoạn với dữ liệu thật chảy qua.
1 · ForwardForward pass — input chảy xuôi
Một batch (ví dụ 4 triệu token của LLaMA) chảy qua các lớp. Mỗi lớp nhận vector từ lớp trước, nhân ma trận, kích hoạt, đẩy sang lớp sau. Cuối chuỗi: đầu ra ŷ.
Input ŷ chảy xuôi qua 126 lớp Trong thí nghiệm ở trên, bạn kéo η lên 0.9. Loss đầu tiên giảm, nhưng sau đó loss NHẢY TUNG lên rồi phân kỳ. Tại sao?
Ba dấu hiệu loss đang “kêu cứu” trong quá trình huấn luyện
Loss = NaN
Gradient phát nổ. Giảm η, bật gradient clipping, kiểm tra khởi tạo weight.
Loss đứng yên
η quá nhỏ hoặc vanishing gradient. Tăng η, đổi activation sang ReLU, thêm residual.
Loss giảm → tăng lại
η đột ngột quá lớn (thường do dữ liệu batch dị thường). Thêm gradient clipping, kiểm tra pipeline.
- Điều bạn thấy rõ hôm nay
- Vòng lặp huấn luyện là 5 bước lặp đi lặp lại: init → forward → loss → backward → update.
- Backward là nơi ĐẠO HÀM và CHAIN RULE thực sự làm việc — tính gradient cho 405 tỉ tham số trong vài trăm mili-giây.
- Learning rate là tham số hiệu chỉnh quan trọng nhất: quá to → nổ, quá nhỏ → không bao giờ tới đáy, vừa phải → hội tụ.
- Mọi trick tiên tiến (AdamW, gradient checkpointing, mixed precision, gradient clipping, warmup) đều phục vụ một mục tiêu duy nhất: giữ vòng lặp trên ổn định khi scale lên tỉ tham số.
Trở lại lý thuyếtMuốn hiểu vì sao chain rule lại biến bài toán không khả thi thành khả thi, quay lại bài Giải tích cho backprop. Muốn cảm nhận gradient trên mặt 2D tương tác, xem Gradient — mũi tên chỉ đường xuống dốc.
Con số thật
- 405 tỉ tham số, 3,8 × 10²⁵ FLOP tổng cộng để huấn luyện LLaMA 3.1 [1]
- 16.384 GPU H100, 54 ngày huấn luyện liên tục, 30,84 triệu giờ GPU [1]
- 15,6 nghìn tỉ token dữ liệu huấn luyện — mỗi token tham gia ít nhất một lần vào backprop [1]
- Megatron-LM đạt 52% hiệu suất đỉnh GPU nhờ song song hoá gradient qua tensor + pipeline parallelism [3]
- LLaMA-13B (13 tỉ tham số) vượt GPT-3 175B trên nhiều benchmark — cùng thuật toán backprop + chọn lọc kỹ dữ liệu và tối ưu [2]
Nếu không có Giải tích cho backprop, app sẽ ra sao?
Không có chain rule, không có cách nào tính gradient cho 405 tỉ tham số qua 126 lớp. Mỗi tham số sẽ phải được thử sai riêng lẻ — cần hàng trăm tỉ lần forward pass cho một bước cập nhật duy nhất, biến bài toán 54 ngày thành hàng triệu năm.
Backpropagation biến chi phí tính gradient từ O(n) forward pass xuống O(1) — chỉ cần một lần lan truyền ngược. Kết hợp với AdamW, gradient checkpointing, song song hoá 16.384 GPU và mixed precision BF16, giải tích là nền tảng toán học duy nhất khiến việc huấn luyện mô hình hàng trăm tỉ tham số trở nên khả thi.
Bài học rút ra:learning rate không phải hyperparameter “tùy hứng”. Nó là nút chỉnh cỡ bước của gradient descent, và sai 10× có thể hỏng 54 ngày GPU. Đó là lý do các đội huấn luyện mô hình lớn chi hàng triệu đô-la để tinh chỉnh đúng giá trị η trước khi bấm nút “start”.