Epochs & Batches in GPT Training
Epoch và batch khi huấn luyện GPT
Công ty nào đang ứng dụng Epoch, batch và iteration?
Khi bạn chat với GPT-4, bạn đang nói chuyện với một mô hình đã “đọc” khoảng 13 nghìn tỉ tokenvăn bản trên internet — tương đương khoảng 50 triệu cuốn sách, gấp hàng nghìn lần số chữ một người có thể đọc trong cả đời. Câu hỏi tự nhiên: làm thế nào để nhét ngần ấy dữ liệu vào một mô hình trong vài tháng huấn luyện?
Câu trả lời gọi là epoch (lượt duyệt qua toàn bộ dữ liệu) và batch(lô — nhóm mẫu xử lý cùng lúc). GPT-4 thường chỉ chạy một hoặc hai epoch — vì một lượt đã đủ nhiều. Nhưng mỗi lượt đó được chặt thành hàng triệu batch, mỗi batch là một bước cập nhật trọng số trên hàng ngàn GPU chạy song song. Bài này sẽ cho bạn thấy cách các lab AI thực sự vặn những con số đó, và vì sao một sai lầm nhỏ trong batch size có thể đốt hàng chục triệu đô-la vô ích.
Vấn đề công ty cần giải quyết
Hãy hình dung bạn có 13 nghìn tỉ tokenvăn bản. Không GPU nào trên trái đất có đủ RAM để ôm toàn bộ dữ liệu cùng lúc. Ngay cả một chiếc H100 (card đồ hoạ chuyên dụng cho AI, RAM 80 GB) cũng chỉ chứa nổi vài triệu token. Vậy mô hình “đọc” kiểu gì?
Và đây mới là câu hỏi đắt giá: nên lặp lại dữ liệu nhiều lần (nhiều epoch, dữ liệu ít) hay đọc một lượt thật kỹ (một epoch, dữ liệu khổng lồ)? Batch nên to bao nhiêu để không hết RAM nhưng vẫn đủ ổn định? Mỗi quyết định ảnh hưởng trực tiếp tới chi phí hàng chục đến hàng trăm triệu đô-la cho một đợt huấn luyện — và quyết định chất lượng mô hình cuối.
13 nghìn tỉ token ≈ 50 triệu cuốn sách. Không một GPU nào chứa hết cùng lúc.
H100 có 80 GB. Một batch 4 triệu token đã chiếm ~100 GB → phải chia tiếp.
Mỗi ngày GPU-cluster ≈ hàng trăm nghìn đô. Chọn sai batch = cháy tiền.
Cách Epoch, batch và iteration giải quyết vấn đề
Chia dữ liệu thành batch, mỗi batch là một bước cập nhật. Thay vì nạp cả 13 nghìn tỉ token cùng lúc, lab AI chia dữ liệu thành những khối nhỏ gọi là batch. Với LLaMA 2, mỗi batch chứa khoảng 4 triệu token— tương đương 1.000 chuỗi, mỗi chuỗi 4.096 token. Sau khi xem xong một batch, mô hình cập nhật trọng số một lần, rồi đọc batch kế tiếp.
Quy luật Chinchilla: khoảng 20 token cho mỗi tham số.Năm 2022 DeepMind công bố kết quả chấn động: GPT-3 (175 tỉ tham số) dùng chỉ 300 tỉ token — tức 1,7 token/tham số — “đói” dữ liệu trầm trọng. Chinchilla 70 tỉ tham số được huấn luyện đúng tỉ lệ 20 token cho mỗi tham số và đánh bại GPT-3 trên gần hết bài kiểm tra. Bài học: tăng dữ liệu đúng tỉ lệ thường đáng giá hơn tăng kích thước mô hình.
Hàng trăm nghìn đến vài triệu iteration trong một epoch.Lấy 2 nghìn tỉ token của LLaMA 2 chia cho batch 4 triệu token → khoảng 500.000 iteration(lần lặp — mỗi lần là một forward + backward + cập nhật). Toàn bộ 500.000 bước này gộp thành một epoch. Một đợt huấn luyện thường chạy trong 2 đến 4 tháng trên cluster hàng nghìn GPU.
Gradient accumulation: nhiều batch nhỏ đóng vai một batch lớn. Nếu một GPU chỉ chứa nổi 500 nghìn token, nhưng bạn muốn batch hiệu dụng 4 triệu token, bạn gom gradient của 8 mini-batch rồi mới cập nhật một lần. Đây là mẹo vàng giúp các đội nhỏ huấn luyện được mô hình to — họ “mô phỏng” batch lớn bằng cách lặp lại batch nhỏ trước khi cập nhật.
Lặp dữ liệu quá 4 epoch bắt đầu gây hại.Muennighoff và cộng sự (2023) chỉ ra: lặp dữ liệu 1–2 lần gần như miễn phí, 3–4 lần có ích giảm dần, sau 4 lần thì giá trị biên gần như bằng 0. Lý do: mô hình bắt đầu thuộc lòng thay vì học mẫu tổng quát. Vì vậy các lab lớn đầu tư mạnh vào thu thập dữ liệu mới thay vì chạy nhiều epoch trên bộ dữ liệu cũ.
Scaling laws: loss giảm dần theo compute
Trục hoành là tổng phép tính (compute, thang log). Trục tung là loss (sai số). Mỗi chấm là một mô hình thật đã được công bố. Đường cong cho thấy quy luật: “nhân đôi compute → loss giảm theo một tỉ lệ dự đoán được”.
Mỗi khi tăng gấp 10 lần compute, loss giảm khoảng 15–20%. Các lab lên kế hoạch huấn luyện dựa trên những đường cong như thế này — trước khi chi tiền.
Thử tự tay
So sánh ngân sách huấn luyện của bốn mô hình thật
Bấm chọn từng mô hình để xem dữ liệu, tỉ lệ token/tham số, và biểu đồ dải batch tương ứng. Bạn sẽ thấy GPT-3 “đói” dữ liệu rõ rệt so với Chinchilla.
Nhận định: Vượt mốc Chinchilla một chút — đầu tư dữ liệu cao hơn để đổi lấy khả năng nói tiếng nhiều vùng hơn.
Mô phỏng chia 2 nghìn tỉtoken thành các batch 4 triệu token. Trên thực tế, số ô nhiều hơn rất nhiều — đây chỉ là minh hoạ.
Vặn batch size — nhìn RAM GPU và nhịp gradient thay đổi
Bạn đang ngồi trên một GPU H100 (RAM 80 GB). Kéo thanh để đổi batch size (tính bằng token). Bạn sẽ thấy: batch quá nhỏ → gradient nhiễu; vừa phải → an toàn; quá lớn → GPU hết RAM.
Ngân sách bộ nhớ và tốc độ gradient
Gradient mượt, tận dụng tốt GPU song song, learning rate có thể nâng lên tương ứng với căn bậc hai của batch.
Lên lịch huấn luyện thực tế cho một mô hình 1 tỉ tham số
Giả sử bạn là một đội startup muốn huấn luyện một mô hình 1 tỉ tham số. Dưới đây là từng bước một đội thật sự đi qua — từ tính ngân sách token đến chạy epoch cuối. Bấm “Tiếp tục” để đi qua từng bước.
Mô hình 1 tỉ tham số theo tỉ lệ Chinchilla cần ~20 tỉ token. Bạn chuẩn bị dữ liệu: Common Crawl đã lọc + Wikipedia + sách + code — tổng ~25 tỉ token để dư phòng. Đây là bước quan trọng nhất — thiếu dữ liệu thì mô hình dù to cũng chỉ là con vẹt.
Thử thách: bạn là nhà nghiên cứu với GPU 24 GB
Một tình huống có thật: bạn muốn huấn luyện mô hình 1 tỉ tham số với batch size 1 triệu token, nhưng bạn chỉ có một GPU 24 GB — thiếu ~10 lần RAM so với nhu cầu (một batch 1 triệu token +optimizer states tiêu tốn khoảng 240 GB). Bạn chọn cách nào?
Bạn có GPU 24 GB, muốn batch hiệu dụng 1 triệu token, nhưng thiếu ~10× RAM. Giải pháp nào khả thi và giữ nguyên chất lượng?
Bạn chạy xong 1 epoch trên 20 tỉ token, loss vẫn cao hơn dự đoán. Chọn chiến thuật tiếp theo có lợi nhất:
Cộng sự đề xuất: 'Thay vì 1 epoch trên 20 tỉ token, hãy chạy 4 epoch trên 5 tỉ token để tiết kiệm chi phí thu thập dữ liệu'. Bạn phản biện thế nào?
Con số thật
- Chinchilla 70B (1,4 nghìn tỉ token, ~20 token/tham số) vượt Gopher 280B (300 tỉ token, ~1 token/tham số) — dữ liệu quan trọng hơn kích thước mô hình [1]
- LLaMA 2 huấn luyện trên 2 nghìn tỉ token với global batch size 4 triệu token — khoảng 500.000 iteration trong 1 epoch [3]
- Tỉ lệ tối ưu Chinchilla: khoảng 20 token dữ liệu cho mỗi tham số mô hình [4]
- Lặp dữ liệu quá 4 epoch khiến giá trị biên giảm gần bằng 0 — ưu tiên dữ liệu mới hơn lặp dữ liệu cũ [2]
Nếu không có Epoch, batch và iteration, app sẽ ra sao?
Hãy tưởng tượng OpenAI quên quy luật Chinchilla và huấn luyện GPT-4 với tỉ lệ 1,7 token/tham số như GPT-3. Với 1,8 nghìn tỉ tham số, họ sẽ chỉ dùng ~3 nghìn tỉ token — và kết quả sẽ là một mô hình đói dữ liệu trầm trọng, hiệu năng kém hơn nhiều so với phiên bản thực tế dùng 13 nghìn tỉ token. Toàn bộ khoản đầu tư hạ tầng có thể biến thành một mô hình tầm trung — mất hàng trăm triệu đô và lợi thế cạnh tranh.
Theo chiều ngược lại, nếu không có khái niệm batch và gradient accumulation, các đội nhỏ sẽ không bao giờ huấn luyện được mô hình tỉ tham số — chỉ những công ty có cluster GPU khổng lồ mới tham gia cuộc chơi. Đây chính là lý do batch và epoch không phải chỉ là chi tiết kỹ thuật: chúng là đòn bẩy dân chủ hoá huấn luyện AI. Hiểu đúng epoch/batch là khác biệt giữa đốt hàng chục triệu đô-la và có một mô hình ra hồn.
- Epoch = một lượt duyệt toàn bộ dữ liệu. GPT-4 chỉ chạy 1–2 epoch vì 13 nghìn tỉ token đã quá đủ cho một lượt.
- Batch = một khối dữ liệu xử lý cùng lúc. LLaMA 2 dùng batch 4 triệu token → ~500 nghìn iteration cho 1 epoch.
- Tỉ lệ vàng Chinchilla: ~20 token dữ liệu cho mỗi tham số. Thiếu thì mô hình đói; thừa thì cần lặp và dễ thuộc lòng.
- Hết RAM GPU? Dùng gradient accumulation: chia thành mini-batch, gom gradient rồi mới cập nhật. Toán học tương đương, chỉ chậm hơn.
Muốn hiểu cặn kẽ vì sao một epoch chia thành nhiều batch, và công thức tính số iteration chính xác? Xem bài lý thuyết Epoch, batch và iteration — nơi chúng ta mổ xẻ cơ chế từng bước cho một mạng nơ-ron nhỏ, trước khi áp dụng cho mô hình khổng lồ như GPT.