Vanishing & Exploding Gradients
Gradient biến mất & bùng nổ
Trò chơi 'truyền tin' qua 10 người: mỗi người thì thầm lại cho người kế tiếp. Nếu mỗi người chỉ nghe được 50% rồi truyền lại, người thứ 10 nghe được bao nhiêu % tin gốc?
Hình minh họa
Gradient đi qua N lớp = nhân N hệ số liên tiếp. Nếu hệ số = 0.5 → (triệt tiêu). Nếu = 2.0 → (bùng nổ). Đây là lý do mạng sâu 100+ lớp cần ReLU, BatchNorm, và skip connections!
Đạo hàm sigmoid tối đa = 0.25. Qua 20 lớp sigmoid, gradient còn bao nhiêu % so với ban đầu?
Giải thích
Khi gradient đi qua N lớp trong backpropagation, theo chain rule:
Nếu mỗi → tích tiến về 0 (vanishing). Nếu > 1 → tích tiến về ∞ (exploding).
5 giải pháp chính:
| Giải pháp | Chống | Cơ chế |
|---|---|---|
| ReLU | Vanishing | Đạo hàm = 1 ở vùng dương, không thu nhỏ gradient |
| BatchNorm | Cả hai | Chuẩn hóa phân phối, giữ gradient trong khoảng ổn định |
| Skip Connections | Vanishing | y = F(x) + x → gradient luôn có đường đi qua "+x" |
| Gradient Clipping | Exploding | Cắt gradient không cho vượt ngưỡng (max_norm) |
| He/Xavier Init | Cả hai | Khởi tạo trọng số giữ phương sai ổn định qua các lớp |
import torch
import torch.nn as nn
# 1. ReLU thay sigmoid ở lớp ẩn
model = nn.Sequential(
nn.Linear(256, 256), nn.ReLU(), # đạo hàm = 1
nn.Linear(256, 256), nn.ReLU(),
# ...100 lớp
)
# 2. He initialization cho ReLU
for m in model.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
# 3. Gradient clipping (chống exploding)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# 4. Skip connection (ResNet block)
class ResBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(dim, dim), nn.ReLU(),
nn.Linear(dim, dim),
)
def forward(self, x):
return x + self.layers(x) # gradient đi qua "+x" không triệt tiêuLoss đột ngột trở thành NaN ở epoch 50. Nguyên nhân có thể nhất là gì?
- Gradient qua N lớp = tích N hệ số. Hệ số < 1 → triệt tiêu (0.5^10 ≈ 0.001). Hệ số > 1 → bùng nổ (2^10 = 1024).
- Sigmoid/Tanh có đạo hàm ≤ 0.25 → nguyên nhân chính gây vanishing. ReLU (đạo hàm = 1) giải quyết.
- Skip connections (ResNet) cho gradient đi tắt qua '+x' — đột phá cho mạng sâu 100+ lớp.
- Gradient clipping cắt norm gradient tối đa — first aid cho exploding gradient.
- Combo: ReLU + BatchNorm + He Init + Skip Connections = mạng sâu ổn định.
Kiểm tra hiểu biết
Tại sao ReLU giảm vấn đề vanishing gradient so với sigmoid?