Residual Connections
Kết nối tắt
Mạng 20 lớp đạt accuracy 90%. Bạn thêm thành mạng 56 lớp — kỳ vọng accuracy sẽ tăng. Kết quả thực tế trong thí nghiệm của He et al. (2015)?
Hãy tưởng tượng bạn gửi một tin nhắn bằng miệng qua một dãy 20 người truyền miệng. Qua mỗi người, tin nhắn bị méo mó thêm một chút — thiếu một từ, sai một chữ, hoặc bị hiểu lệch. Đến người thứ 20, tin nhắn gốc đã biến thành chuyện "tam sao thất bản", không còn nhận ra được nữa.
Skip connection giống như việc bạn gửi thêm một bản sao của tin nhắn gốc — bằng giấy, đi thẳng tới người cuối cùng. Dù đường truyền miệng có méo tới đâu, người cuối vẫn có tin nhắn gốc để đối chiếu. Người cuối cùng học cách sửa tin nhắn miệng bằng tin nhắn gốc, thay vì phải dựng lại toàn bộ nội dung từ đầu — nhiệm vụ dễ hơn rất nhiều.
Trong deep learning, "truyền miệng" là các lớp trung gian học đặc trưng, còn "bản sao giấy" là chính input x đi qua đường skip. Mạng chỉ cần học phần chênh lệch F(x) = H(x) − x — phần residual (phần dư). Thay vì buộc mỗi lớp học một ánh xạ phức tạp H(x), ta buộc mỗi lớp học một thay đổi nhỏ F(x). Nếu chẳng có gì đáng thay đổi, F(x) chỉ cần = 0.
Học "không thay đổi gì" (identity) thông qua một chuỗi Conv + BN + ReLU là rất khó — bạn cần nhiều trọng số phối hợp chính xác. Nhưng học "F(x) ≈ 0" chỉ cần đẩy các trọng số về gần 0 — dễ hơn rất nhiều. ResNet biến bài toán khó thành bài toán dễ.
Hình minh họa
ResNet Basic Block — bật/tắt Skip Connection
Bật skip để thấy đường tắt x chảy song song với nhánh chính F(x) = Conv → BN → ReLU → Conv → BN, sau đó cộng lại rồi qua ReLU cuối. Tắt skip để thấy mạng "plain" — gradient sẽ vanishing khi backprop.
Layer được chọn
⊕ (+input)
Phép cộng element-wise: F(x) + x. Đây là trái tim của residual block.
Training curve — depth = 20 (loss giảm theo epoch)
Cùng kiến trúc, cùng hyperparameter — chỉ khác có/không skip connection. Không skip: loss dao động mạnh và drift lên (diverge). Có skip: loss giảm trơn xuống gần 0.
Gradient magnitude per layer (backprop 20 lớp)
Magnitude của ∂L/∂W tại mỗi lớp khi backprop từ loss về đầu mạng. Không skip: gradient bị nhân với tích dF/dx < 1 qua nhiều lớp → nhỏ xíu ở đầu. Có skip: đường tắt đóng góp thành phần 1 → gradient luôn ≥ ~1.
Khi backprop về layer 1 của mạng plain: gradient ≈ 0.85²⁰ ≈ 0.039 — rất nhỏ, trọng số gần như không được cập nhật. Với skip, mỗi block giữ lại thành phần = 1 qua đường tắt → tổng gradient giữ ở mức ≥ 1.
output = F(x) + x. Mạng không cần học toàn bộ ánh xạ H(x) — chỉ cần học phần dư F(x) = H(x) − x. Nếu lớp tối ưu là identity (không cần thay đổi gì), F(x) chỉ cần = 0 — dễ hơn rất nhiều so với việc học lại identity từ một đống Conv + BN + ReLU.
Và gradient của x + F(x) theo x = 1 + dF/dx. Dù dF/dx nhỏ cỡ nào, gradient vẫn có thành phần ≥ 1. Không vanishing. Mạng có thể sâu 100, 152, thậm chí 1000 lớp.
Residual connections là thành phần cốt lõi của Transformer — được dùng quanh cả attention và FFN. Hai câu hỏi dưới đây kiểm tra mức độ nắm vững cơ chế này.
Transformer có 2 residual connections mỗi lớp. Chúng bao quanh gì?
Bạn có một mạng plain 56 lớp và một ResNet 56 lớp cùng số tham số, cùng optimizer, cùng learning rate. Training loss sau 100 epoch kỳ vọng?
ImageNet accuracy theo độ sâu mạng (minh họa)
8 lớp
20 lớp
56 lớp
110 lớp
152 lớp
Không có skip: 8 lớp (90%) > 56 lớp (72%) > 152 lớp (40%). Sâu hơn = TỆ hơn. Có skip: 8 lớp (92%) < 56 lớp (97%) < 152 lớp (97.8%). Sâu hơn = TỐT hơn. Khác biệt duy nhất: skip connection.
Giải thích
Residual Connection(He et al., 2015, "Deep Residual Learning for Image Recognition") là một cơ chế kiến trúc trong đó input của một khối lớp được cộng trực tiếpvào output của khối đó qua một đường "tắt" (shortcut). Thay vì buộc khối phải học ánh xạ mục tiêu H(x) trực tiếp, ta cho khối học phần dư F(x) = H(x) − x. Công thức tổng quát của một residual block là:
với F là hàm residual do các trọng số W₁, W₂, ... biểu diễn (thường là 2 hoặc 3 lớp Conv+BN+ReLU). Khi kích thước của x và F(x) khác nhau (do đổi số kênh hoặc downsample bằng stride), ta dùng projection shortcut:
với W_s là một lớp 1×1 convolution chỉ để đồng bộ dimension. Gradient của loss theo input của block:
Gradient luôn có thành phần 1 từ skip connection cộng với dF/dx. Dù dF/dx nhỏ → gradient vẫn ≥ 1. Tổng gradient qua N block xếp chồng là ∏(1 + dF_i/dx), không phải ∏(dF_i/dx) — đây là lý do ResNet không bị vanishing gradient dù sâu bao nhiêu lớp.
Post-Norm (ResNet gốc, BERT gốc): y = LayerNorm(x + F(x)). Performance tốt nhưng khó train cho mạng rất sâu vì LayerNorm nằm TRÊN đường skip, chặn một phần gradient. Pre-Norm (GPT, LLaMA, hầu hết LLM hiện đại): y = x + F(LayerNorm(x)). Dễ train hơn vì gradient chảy thẳng qua skip không bị LayerNorm chặn — đổi lại cần một LayerNorm cuối cùng trước output.
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicResidualBlock(nn.Module):
"""
ResNet basic block: Conv → BN → ReLU → Conv → BN → (+input) → ReLU
Nếu số kênh đầu vào và ra khác nhau hoặc stride > 1,
dùng projection shortcut (1x1 conv) để khớp dimension.
"""
expansion = 1
def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_ch)
self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_ch)
# Skip connection: identity nếu dimension khớp,
# còn không thì dùng 1x1 conv (projection shortcut).
if stride != 1 or in_ch != out_ch * self.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(in_ch, out_ch * self.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_ch * self.expansion),
)
else:
self.shortcut = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = self.shortcut(x)
# Nhánh chính F(x): Conv → BN → ReLU → Conv → BN
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out, inplace=True)
out = self.conv2(out)
out = self.bn2(out)
# Cộng với skip TRƯỚC khi qua ReLU cuối
out = out + identity
out = F.relu(out, inplace=True)
return out
# Ví dụ: xếp chồng 3 block tạo thành một stage của ResNet
layers = nn.Sequential(
BasicResidualBlock(in_ch=64, out_ch=64),
BasicResidualBlock(in_ch=64, out_ch=64),
BasicResidualBlock(in_ch=64, out_ch=128, stride=2), # downsample
)
x = torch.randn(8, 64, 56, 56)
y = layers(x)
print(y.shape) # torch.Size([8, 128, 28, 28])ResNet-50/101/152 dùng bottleneck blockvới 3 lớp: 1×1 (giảm kênh) → 3×3 (xử lý không gian) → 1×1 (tăng kênh lại). Tên "bottleneck" vì tầng giữa có số kênh nhỏ hơn. Giúp giảm tham số mà vẫn giữ depth lớn. Output cuối cùng có số kênh = 4× input (expansion = 4).
Khi dùng 1×1 conv trên skip path (option B/C), đường tắt không còn là identity thuần khiết — nó trở thành một phép biến đổi tuyến tính có trọng số. Điều này có thể làm giảm nhẹ lợi ích gradient flow, nhưng cần thiết khi downsample. Trong thực hành, người ta chỉ dùng projection ở block đầu của mỗi stage (khi đổi số kênh), còn lại là identity thuần.
import torch
import torch.nn as nn
class PreNormTransformerBlock(nn.Module):
"""
Một block Transformer kiểu Pre-Norm (GPT-2, LLaMA, Mistral).
2 residual connections mỗi block:
1) x = x + MultiHeadAttention(LayerNorm(x))
2) x = x + FFN(LayerNorm(x))
Gradient chảy thẳng qua 2 đường skip không bị LayerNorm chặn.
"""
def __init__(self, d_model: int, n_heads: int, d_ff: int,
dropout: float = 0.1):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True
)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor,
attn_mask: torch.Tensor | None = None) -> torch.Tensor:
# Residual 1: quanh self-attention
a = self.norm1(x)
attn_out, _ = self.attn(a, a, a, attn_mask=attn_mask,
need_weights=False)
x = x + self.drop(attn_out)
# Residual 2: quanh FFN
x = x + self.drop(self.ffn(self.norm2(x)))
return x
# Xếp chồng 48 block như GPT-2 Medium
blocks = nn.Sequential(*[
PreNormTransformerBlock(d_model=1024, n_heads=16, d_ff=4096)
for _ in range(48)
])
x = torch.randn(2, 256, 1024)
y = blocks(x)
print(y.shape) # torch.Size([2, 256, 1024])Ngoài ResNet, skip connection là thành phần bắt buộc của: Transformer (Attention + FFN), U-Net (encoder → decoder concatenate), DenseNet (nối tất cả các lớp trước với mọi lớp sau), Highway Networks (tiền nhiệm của ResNet với gating), Diffusion Models (U-Net backbone), GNN (residual giữa các layer), và thậm chí cả LSTM/GRU (cell-state carry-over là một dạng skip qua thời gian).
Ứng dụng thực tế
- Computer Vision: ResNet-50/101/152 là backbone chuẩn cho object detection (Faster R-CNN), segmentation (Mask R-CNN, DeepLab), pose estimation, v.v. Hầu hết mọi CNN từ 2016 trở đi đều có skip.
- NLP / LLM: Mọi Transformer (BERT, GPT, T5, LLaMA, Claude, Gemini) đều có 2 residual connections mỗi block. GPT-3 96 lớp, GPT-4 ước tính 120+ lớp — không có skip là không train được.
- Generative Models: U-Net backbone của Stable Diffusion có skip từ encoder sang decoder ở mọi resolution. DDPM, Flow Matching, Rectified Flow đều dùng.
- Reinforcement Learning: AlphaGo, AlphaZero, MuZero dùng ResNet làm policy/value network. Dreamer, IMPALA dùng residual trong world model.
- Speech: Whisper, wav2vec 2.0 có skip connection trong cả convolutional feature extractor và Transformer encoder.
Những lỗi thường gặp (pitfalls)
- Quên cộng x trước ReLU cuối: Một số implementation đặt ReLU trước phép cộng — điều này cắt toàn bộ giá trị âm của F(x) và có thể làm mạng kém đi. Đặt ReLU SAU phép cộng F(x) + x.
- Dimension mismatch không xử lý:Khi stride > 1 hoặc đổi số kênh, quên projection shortcut → torch báo lỗi shape. Luôn kiểm tra x.shape == F(x).shape trước khi cộng.
- Khởi tạo F(x) quá lớn: Nếu trọng số của F khởi tạo lớn, F(x) át hẳn x — skip mất tác dụng. Dùng zero-initcho lớp cuối của F hoặc gamma = 0 cho BN cuối ("ZeroInit", "Fixup").
- BN trên skip path: Đặt BN trên đường tắt làm skip không còn là identity. Tránh — chỉ đặt BN trong nhánh F(x) hoặc dùng projection thật cẩn trọng.
- Dropout trên skip path:Tương tự, dropout ngẫu nhiên "ngắt" đường skip → mất lợi thế gradient flow. Dropout chỉ trong nhánh F(x).
Trong thực tế:Gần như mọi kiến trúc sâu hiện đại đều có skip connection ở một dạng nào đó. Khi bạn tự thiết kế mạng > 10 lớp, thêm skip gần như không bao giờ hại — chỉ có lợi hoặc không đổi. Đây là một trong những "default choice" an toàn nhất trong deep learning.
- Công thức: output = F(x) + x. Mạng chỉ học phần dư F(x) = H(x) − x thay vì học toàn bộ H(x).
- Gradient = 1 + dF/dx → luôn có thành phần 1 → không vanishing, train được mạng 100+ lớp (ResNet-152, GPT-3, LLaMA).
- Trường hợp xấu nhất: F(x) = 0 → output = x (identity). Thêm block không bao giờ làm tệ hơn.
- Thứ tự chuẩn trong basic block: Conv → BN → ReLU → Conv → BN → (+x) → ReLU. ReLU cuối SAU phép cộng.
- Dimension mismatch: dùng projection shortcut (1×1 conv với stride phù hợp) để khớp kích thước x với F(x).
- Skip connection có mặt trong Transformer (2 per block), U-Net, DenseNet, Diffusion U-Net, LSTM/GRU — là default của deep learning hiện đại.
Kiểm tra hiểu biết
output = F(x) + x. Nếu F(x) = 0 (lớp không học được gì), output = ?