Gradient Descent in GPT-4 Training
Gradient descent chạy suốt 34 ngày để huấn luyện GPT-3
Công ty nào đang ứng dụng Gradient Descent?
Tháng 6 năm 2020, OpenAI công bố GPT-3, mô hình ngôn ngữ 175 tỉ tham số. Theo Lambda Labs, chi phí tính toán để huấn luyện một lần rơi vào khoảng 12 triệu USD. Phần lớn số tiền đó không chảy vào kiến trúc sang trọng. Nó chảy vào một việc duy nhất, lặp lại hàng triệu lần trên 1.024 GPU V100.
Việc duy nhất đó được gọi là gradient descent. Thuật toán hỏi “nếu ta nhích trọng số này xuống một chút, loss giảm bao nhiêu?” rồi trả lời bằng một bước rất nhỏ về phía đáy thung lũng mất mát. Lặp vài trăm triệu lần, mô hình đi từ ngẫu nhiên hoàn toàn tới trả lời được câu hỏi tiếng Việt. Tại sao OpenAI chọn thuật toán mộc mạc này thay vì một thứ kỳ diệu hơn?
Chi phí của một lần huấn luyện gradient descent (ước lượng GPT-3)
175 tỉ
tham số cần cập nhật
1.024
GPU V100 song song
~34 ngày
chạy liên tục
~12M USD
chi phí tính toán
Gradient descent sống sót vì ba lý do: rẻ trên mỗi bước, mở rộng được với bất kỳ số chiều (kể cả 175 tỉ), và có nhiều biến thể (AdamW, momentum, learning-rate schedule) giúp ổn định khi scale. Bài này bạn sẽ tự vặn learning rate, momentum, batch size của một mô hình nhỏ và xem loss, độ chính xác, phân bố trọng số đổi ra sao.
Vấn đề công ty cần giải quyết
Bề mặt mất mát (loss surface) của GPT-3 sống trong không gian 175 tỉ chiều. Bạn không thể “nhìn” nó. Vậy mà OpenAI phải tìm một điểm thấp trong đó, bằng cách chỉ tính gradient (hướng dốc nhất) tại vị trí hiện tại.
Saddle point
Điểm yên ngựa: gradient gần 0 nhưng không phải đáy. SGD thuần dễ kẹt. Momentum giúp trượt qua.
Noisy gradient
Mỗi batch chỉ vài triệu token. Gradient trên batch lệch trung bình thật. Batch lớn thì ít nhiễu nhưng đắt.
Exploding loss
Một batch dị thường có thể đẩy gradient lên 1000 lần. Không clip, 34 ngày công sức tan theo một bước.
Bài toán của OpenAI không phải “tìm thuật toán mới”. Bài toán là giữ gradient descent ổn định qua 300 tỉ token dữ liệu, trên 1.024 GPU, trong 34 ngày, với chi phí 12 triệu USD. Sai learning rate 3 lần có thể tốn hàng triệu USD. Vì thế mỗi hyperparameter (η, β₁, β₂, batch size, clip threshold) đều được chọn từ hàng trăm thí nghiệm trên mô hình nhỏ.
Cách Gradient Descent giải quyết vấn đề
Lấy một mini-batch ngẫu nhiên. Dùng toàn bộ 300 tỉ token cho mỗi bước là không thể. Thay vào đó, OpenAI chia dữ liệu thành các mini-batch cỡ 3,2 triệu token. Mỗi GPU xử lý một phần, gradient cục bộ được tổng hợp qua all-reduce.
Tính gradient qua backprop. Với mini-batch hiện tại, tính loss rồi lan truyền ngược qua 96 lớp Transformer để có gradient của mọi tham số. Đây là bước đắt nhất, chiếm phần lớn chi phí GPU của lần huấn luyện 12 triệu USD.
AdamW cập nhật trọng số.OpenAI không dùng SGD thuần. Họ dùng AdamW, vốn giữ trung bình động của gradient (momentum) và của gradient bình phương (variance). Mỗi trọng số có bước thích ứng riêng: trọng số nào gradient lớn thì bị rút bước, nhỏ thì được đẩy nhanh. Weight decay tách riêng khỏi gradient giúp mô hình không “bơm” trọng số lên quá lớn.
Clip gradient, rồi cập nhật. Trước khi áp bước, cắt gradient sao cho norm không vượt 1,0. Nếu một batch dị thường tạo gradient khổng lồ, clip kéo nó về cỡ bình thường. Đây là phanh cứu cả lần huấn luyện khỏi phân kỳ.
Lặp lại cho đến khi loss không giảm nữa. Vòng lặp trên chạy khoảng 500 tỉ lần cho GPT-3. Theo dõi loss, độ chính xác và phân bố trọng số sau vài nghìn bước. Nếu đúng hướng, giữ nguyên. Nếu lệch, rollback checkpoint và điều chỉnh η.
Simulator: huấn luyện mô hình nhỏ phân loại MNIST
Mô hình nhỏ (một trọng số w, mục tiêu w* ≈ 1,5) học phân loại nhị phân có nhiễu. Vặn learning rate, momentum và batch size để xem loss, độ chính xác, phân bố trọng số đổi qua 50 epoch giả lập.
Hyperparameter: thả tay vặn thử
Loss
0.06Accuracy
47%Weight histogram (phân bố trọng số qua thời gian)
chấm = trung tâm, vùng bo = spreadHội tụ êm. Lý tưởng
η = 0.080, β = 0.85, batch = 64Learning rate, momentum và batch size phối hợp tốt. Loss giảm đều, độ chính xác tăng đều, phân bố trọng số ổn định. Đây là tín hiệu OpenAI theo dõi suốt 34 ngày.
0.080.0011.50.8500.99644256Ba biểu đồ, một câu chuyệnLoss cho biết mô hình có đi xuống đáy. Độ chính xác cho biết nó có học phân loại (đôi khi loss giảm nhưng độ chính xác không tăng, đó là dấu hiệu overfitting). Weight histogram cho biết trọng số có đang bùng nổ hoặc co về 0. Đó là chỉ số sức khoẻ mà kỹ sư của OpenAI theo dõi suốt 34 ngày.Ba mẹo thực tế giữ gradient descent ổn định
Ba quirk bất kỳ team huấn luyện mô hình lớn đều biết. Bấm “Tiếp tục” để lật từng mẹo kèm hình mô tả.
1 · WarmupWarmup. Khởi động chậm vài nghìn bước đầu
η tăng tuyến tính từ gần 0 tới đỉnh trong khoảng 375 triệu token đầu
Đầu huấn luyện, trọng số là số ngẫu nhiên, gradient có thể lệch cỡ lớn. Nếu cho η chạm đỉnh ngay bước 1, bước cập nhật đầu có thể làm hỏng hệ thống trước cả khi nó ổn định. Warmup nâng dần η qua vài nghìn bước đầu, giống như để xe ấm máy trước khi đạp ga.
Chẩn đoán ba loss curve kỳ lạ
Team huấn luyện nhìn ba biểu đồ dưới đây xuất hiện trên TensorBoard. Bạn chọn nguyên nhân nào là hợp lý nhất?
Biểu đồ 1, loss jagged (răng cưa): giảm được vài bước rồi nhảy tung lên cao, rồi lại giảm, lặp đi lặp lại suốt 20 epoch. Độ chính xác cũng nhấp nhô. Nguyên nhân khả dĩ nhất?
Biểu đồ 2, loss plateau (nằm ngang): giảm rất nhanh trong 5 epoch đầu rồi phẳng lì 30 epoch tiếp theo, không nhúc nhích. Độ chính xác kẹt ở 62%. Bước nào nên thử TRƯỚC?
Biểu đồ 3, loss = NaN sau epoch 12: loss giảm đẹp tới epoch 11, sang epoch 12 thì toàn biểu đồ đầy NaN. Team nên làm gì?
Bảng tra nhanh: loss curve → biện pháp
Loss giảm êm
Giữ nguyên. Theo dõi accuracy và weight histogram để phát hiện sớm vấn đề. Đừng vội giảm η cho tới khi plateau rõ ràng.
Loss răng cưa, dao động
Giảm η 1,5×–2×, hoặc tăng batch size 2× để giảm noise. Nếu vẫn dao động, tăng momentum để bước đi “nặng” hơn.
Loss plateau, không giảm
Có thể saddle point: bật / tăng momentum. Có thể η quá nhỏ: nhân η cho 2–3. Có thể dataset học hết: lúc dừng huấn luyện.
Loss tăng dần (chưa NaN)
η hơi lớn, mỗi bước làm loss tăng. Giảm η ngay, hoặc rollback checkpoint. Nếu mới tăng vài bước, bật warmup có thể xử lý.
Loss nổ / NaN đột ngột
Rollback checkpoint trước khi nổ. Giảm η 3–5×. Bật / siết gradient clipping. Kiểm tra batch gần nhất có bất thường không.
Accuracy tăng, loss phẳng
Cross-entropy và độ chính xác đo hai thứ khác nhau. Mô hình vẫn học đúng, chỉ “tự tin” chưa đủ. Thường là tín hiệu ổn.
- Bốn điều rút ra từ bài này
- Huấn luyện GPT-3 là một vòng lặp gradient descent chạy 500 tỉ lần. 12 triệu USD phần lớn tiêu cho bước backprop tính gradient.
- Ba hyperparameter quyết định sống chết: learning rate (quá to → nổ, quá nhỏ → lãng phí), momentum (giúp qua saddle point), batch size (batch to → noise thấp nhưng đắt).
- Ba mẹo cứu cánh: warmup (khởi động chậm), learning-rate decay (rón rén cuối run), gradient clipping (phanh chống nổ).
- Loss curve là bảng điều khiển: jagged → giảm η, plateau → bật momentum, NaN → rollback + siết clip. Đọc curve là kỹ năng cốt lõi của kỹ sư huấn luyện.
Quay lại lý thuyếtNếu bạn muốn thấy gradient descent chạy trên bản đồ contour 2D để cảm nhận “lăn xuống dốc”, quay lại bài Gradient Descent.
Con số thật
- GPT-3 huấn luyện trên 300 tỉ token, dùng AdamW với β₁ = 0,9, β₂ = 0,95, weight decay = 0,1 [1]
- Peak learning rate cho GPT-3 là 6×10⁻⁵. Warmup tuyến tính trong 375 triệu token đầu, rồi decay cosine tới 10% giá trị đỉnh [1]
- Chi phí tính toán huấn luyện GPT-3 ước lượng ~12 triệu USD, chạy trên 1.024 GPU V100 trong khoảng 34 ngày [5]
- AdamW (Loshchilov & Hutter 2019) cải thiện generalization so với Adam chuẩn bằng cách tách weight decay khỏi gradient. Đây là tiêu chuẩn vàng cho Transformer [2]
- Warmup ~2.000 bước đầu giảm đáng kể rủi ro phân kỳ; nghiên cứu 2024 chỉ ra có thể rút ngắn nếu khởi tạo trọng số cẩn thận [3]
- Gradient clipping (Pascanu et al. 2013) là biện pháp ngăn gradient nổ; threshold 1,0 được dùng phổ biến cho LLM lớn [4]
Nếu không có Gradient Descent, app sẽ ra sao?
Không có gradient descent, không có cách nào điều chỉnh 175 tỉ trọng số theo một tín hiệu loss duy nhất. Mọi kỹ thuật thay thế (grid search, đạo hàm số, tiến hoá) đều có độ phức tạp tăng theo số chiều, tức là bất khả thi ở quy mô tỉ tham số.
Không có AdamW, mô hình sẽ cần tinh chỉnh learning rate thủ công cho từng nhóm tham số. Không có warmup, bước đầu dễ làm gradient nổ. Không có learning-rate decay, mô hình dao động mãi quanh đáy. Không có gradient clipping, một batch xấu huỷ cả 34 ngày huấn luyện.
Bài học rút ra: gradient descent không phải công thức kỳ diệu. Nó đơn giản đến nỗi một thuật toán 70 năm tuổi vẫn làm nền tảng cho GPT-3, GPT-4 và mọi mô hình sau đó. Cái khó nằm ở việc giữ nó ổn định suốt 34 ngày trên 1.024 GPU. Đó là chỗ AdamW, warmup, decay, clipping cộng dồn thành 12 triệu USD tiêu đúng chỗ.