Batch Normalization
Chuẩn hoá batch
Lớp 1 chấm thang 100, lớp 2 thang 10, lớp 3 dùng chữ A–F. Muốn so sánh điểm công bằng, bạn làm gì?
BN gồm 3 bước: (1) tính μ, σ² của batch, (2) chuẩn hóa về (0, 1), (3) scale bằng γ rồi shift bằng β. Dùng nút dưới để xem từng bước, kéo các slider để thay đổi phân phối gốc và các tham số học được.
Hình minh họa
Activation gốc có μ = 5.48 và σ = 3.66. Lớp sau khó học nếu phân phối này thay đổi mỗi batch.
Kéo batch size ở trên xuống 2 rồi lên lại 128 để thấy phân phối sau BN ổn định ra sao. Batch quá nhỏ → ước lượng μ, σ² quá ồn → BN phản tác dụng.
So sánh loss qua các epoch trên cùng một kiến trúc (mô phỏng toy). BN giúp hội tụ nhanh hơn và loss mượt hơn. Kéo số epoch để xem xu hướng dài hạn.
Hình minh họa
Có BN
Loss cuối ≈ 0.285. Hội tụ nhanh. Có thể dùng learning rate cao hơn mà không diverge.
Không BN
Loss cuối ≈ 1.031. Chậm hơn, loss dao động mạnh hơn, nhạy cảm với khởi tạo.
Batch Normalization không chỉ là phép “trừ trung bình, chia độ lệch” — nó còn có γ và β học được. Nếu chuẩn hóa là tối ưu, mạng giữ γ=1, β=0. Nếu không, mạng có thể học γ=σ, β=μ để triệt tiêu BN. Nghĩa là: BN không bao giờ hạn chế khả năng biểu diễn — nó chỉ mở thêm lựa chọn dễ học hơn.
Đó là lý do BN (và các biến thể LN, GN, IN) gần như luôn có mặt trong kiến trúc hiện đại: lợi ích rõ ràng, chi phí cực nhỏ, và không hy sinh expressivity.
Bạn huấn luyện CNN với batch size 4 trên GPU nhỏ. Training bị diverge khi thêm BatchNorm2d. Lý do có khả năng cao nhất?
Tại inference bạn gửi từng ảnh một (batch=1). Nếu BN dùng batch stats lúc này, chuyện gì xảy ra?
Giải thích
Ý tưởng cốt lõi: trong một mini-batch , tính thống kê của chính batch và dùng chúng để chuẩn hóa mỗi activation. Sau đó scale bằng γ và shift bằng β — hai tham số học được cho phép mạng tự quyết định muốn chuẩn hóa đến đâu.
1. Trung bình batch:
2. Phương sai batch:
3. Chuẩn hóa:
4. Scale & shift:
Khi training, ngoài dùng μ, σ² của batch hiện tại, BN còn cập nhật EMA (exponential moving average) để dành cho inference:
Với = momentum (PyTorch mặc định 0.1). Khi model.eval(), BN chuyển sang dùng .
- Quên gọi
model.eval()trước inference → BN dùng batch stats của 1 ảnh → kết quả lộn xộn. - Dùng BN với batch rất nhỏ → μ, σ² ồn → diverge. Đổi sang GroupNorm / LayerNorm.
- BN trong RNN không tự nhiên vì sequence length thay đổi; đó là lý do Transformer mặc định dùng LayerNorm.
- Nếu fine-tune trên domain mới với batch khác biệt hoàn toàn, running stats cũ có thể không còn phù hợp. Cân nhắc
track_running_stats=Truevà thời gian warm-up.
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
"""Conv2d -> BatchNorm2d -> ReLU: khối chuẩn của CNN hiện đại."""
def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3):
super().__init__()
self.conv = nn.Conv2d(
in_ch, out_ch, kernel_size,
padding=kernel_size // 2,
bias=False, # BN có beta => bias của conv thừa
)
# BatchNorm2d: tính thống kê trên (N, H, W) cho từng channel C.
# num_features = out_ch => có out_ch cặp (gamma, beta),
# (running_mean, running_var).
self.bn = nn.BatchNorm2d(
num_features=out_ch,
eps=1e-5,
momentum=0.1, # EMA cho running stats
affine=True, # bật gamma, beta học được
track_running_stats=True,
)
self.act = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(self.bn(self.conv(x)))
# Ví dụ sử dụng
model = nn.Sequential(
ConvBlock(3, 64),
ConvBlock(64, 128),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(128, 10),
)
# Training mode: BN dùng batch stats và update running stats.
model.train()
x = torch.randn(32, 3, 32, 32) # batch 32, giống CIFAR
out = model(x)
print("train output:", out.shape)
# Inference: BN dùng running_mean, running_var.
model.eval()
with torch.no_grad():
single = torch.randn(1, 3, 32, 32)
pred = model(single)
print("eval output:", pred.shape)
# Freeze BN (khi fine-tune)
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval() # đóng băng running stats
for p in m.parameters():
p.requires_grad = Falseclass MyBatchNorm2d(nn.Module):
"""Viết tay BatchNorm2d để hiểu kĩ bên trong."""
def __init__(self, num_features: int, momentum: float = 0.1,
eps: float = 1e-5):
super().__init__()
self.num_features = num_features
self.momentum = momentum
self.eps = eps
# Học được
self.gamma = nn.Parameter(torch.ones(num_features))
self.beta = nn.Parameter(torch.zeros(num_features))
# Buffer: không học, nhưng lưu cùng state_dict
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x shape: (N, C, H, W)
if self.training:
dims = (0, 2, 3)
mu = x.mean(dim=dims)
var = x.var(dim=dims, unbiased=False)
# update running stats
with torch.no_grad():
self.running_mean.mul_(1 - self.momentum).add_(
mu, alpha=self.momentum
)
self.running_var.mul_(1 - self.momentum).add_(
var, alpha=self.momentum
)
else:
mu = self.running_mean
var = self.running_var
x_hat = (x - mu[None, :, None, None]) / torch.sqrt(
var[None, :, None, None] + self.eps
)
return (
self.gamma[None, :, None, None] * x_hat
+ self.beta[None, :, None, None]
)Mô phỏng EMA cho μ và σ² trong 20 epoch đầu:
| Epoch | running μ | running σ² |
|---|---|---|
| 0 | 0.580 | 2.367 |
| 2 | 1.528 | 4.585 |
| 4 | 2.322 | 6.562 |
| 6 | 2.983 | 8.062 |
| 8 | 3.521 | 9.275 |
| 10 | 3.944 | 10.336 |
| 12 | 4.270 | 11.124 |
| 14 | 4.532 | 11.756 |
| 16 | 4.770 | 12.164 |
| 18 | 4.994 | 12.454 |
Với momentum 0.1, sau vài epoch running stats đã hội tụ về gần batch stats thực tế. Đây chính là con số BN dùng khi model.eval().
- BN chuẩn hóa mỗi mini-batch về μ ≈ 0, σ ≈ 1 rồi scale bằng γ và shift bằng β học được.
- Hiệu ứng: ổn định gradient, cho phép learning rate lớn hơn, hội tụ nhanh hơn, regularization nhẹ.
- Khi inference BN dùng running mean/variance (EMA) — nhớ gọi model.eval() trước khi predict.
- Batch quá nhỏ → μ, σ² ồn → model dễ diverge. Chuyển sang GroupNorm hoặc tăng effective batch.
- BN dùng cho CNN; LayerNorm cho Transformer/RNN; GroupNorm cho batch nhỏ; InstanceNorm cho style transfer.
- γ=σ và β=μ sẽ ‘undo’ BN — chứng tỏ BN không hy sinh khả năng biểu diễn, chỉ nới thêm lựa chọn.
Kiểm tra hiểu biết
Batch Normalization khi inference dùng gì thay cho batch statistics (μ, σ của mini-batch)?