Weight Initialization
Khởi tạo trọng số
10 vận động viên chuẩn bị chạy đua. Nếu TẤT CẢ đứng chung một vạch xuất phát chính xác cùng vị trí, kết quả sẽ thế nào?
Hình minh họa
Trung bình
-0.0020
Phương sai
0.00638
Max |w|
0.2017
Phương sai activation qua 6 lớp (dùng ReLU):
Lớp 0
1.00
Lớp 1
0.50
Lớp 2
0.25
Lớp 3
0.13
Lớp 4
0.06
Lớp 5
0.03
Weight initialization chọn "vạch xuất phát" cho mạng. Zeros = tất cả đứng cùng chỗ (symmetry). Random lớn = nhảy hỗn loạn (exploding). Xavier/He = phân bổ hợp lý để phương sai ổn định qua mọi lớp — giống cách sắp xếp chỗ ngồi trên xe buýt: đều nhau, không ai quá chật, không ai quá thoải mái!
Bạn dùng tanh activation nhưng He initialization (Var = 2/fan_in). Kết quả có thể xảy ra?
Giải thích
Mục tiêu: giữ phương sai ổn định qua các lớp — không tăng (bùng nổ) cũng không giảm (triệt tiêu) (xem vanishing/exploding gradients).
Xavier (Glorot, 2010):
He (Kaiming, 2015):
Hệ số 2 trong He bù đắp việc ReLU "tắt" 50% output (max(0,x) loại bỏ một nửa phân phối) — chọn init phụ thuộc vào hàm kích hoạt.
import torch.nn as nn
model = nn.Sequential(
nn.Linear(784, 256), nn.ReLU(),
nn.Linear(256, 128), nn.ReLU(),
nn.Linear(128, 10),
)
# He init cho ReLU (PyTorch mặc định Kaiming uniform)
for m in model.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
nn.init.zeros_(m.bias) # bias luôn init = 0
# Xavier init cho Sigmoid/Tanh
# nn.init.xavier_normal_(m.weight)
# Kiểm tra phương sai ban đầu
with torch.no_grad():
x = torch.randn(100, 784)
for layer in model:
x = layer(x)
if isinstance(layer, nn.Linear):
print(f"Var = {x.var().item():.4f}")Fan_in = 10000 (lớp đầu tiên của mạng xử lý ảnh lớn). Xavier init cho Var = 1/10000 = 0.0001. Trọng số sẽ rất nhỏ. Có vấn đề gì không?
- Zeros = symmetry problem (mọi nơ-ron giống hệt nhau). Random lớn = bùng nổ gradient.
- Xavier (Var=1/fan_in): giữ phương sai ổn định cho sigmoid/tanh.
- He (Var=2/fan_in): bù ReLU tắt 50% → tiêu chuẩn cho ReLU. Nhân 2 so với Xavier.
- Quy tắc: ReLU → He, Sigmoid/Tanh → Xavier. Với BatchNorm, init ít quan trọng hơn.
- PyTorch mặc định dùng Kaiming — thường không cần thay đổi trừ khi debug gradient.
Kiểm tra hiểu biết
Tại sao khởi tạo tất cả trọng số = 0 là ý tưởng tệ?