U-Net
Kiến trúc U-Net
Bạn cần gán nhãn MỖI pixel trong ảnh (ví dụ: pixel này là đường, pixel kia là xe). Autoencoder nén ảnh rồi giải nén, nhưng chi tiết pixel bị mất. Làm sao khắc phục?
Hãy tưởng tượng bạn vẽ bản đồ từ ảnh vệ tinh. Bước 1: bạn zoom out dần để thấy bức tranh tổng thể (encoder). Bước 2: bạn zoom in lại, vẽ chi tiết ranh giới (decoder). Ghi chú từ mỗi mức zoom (skip connections) giúp bạn không quên chi tiết.
Hình minh họa
Nhấn vào từng cấp để xem skip connection truyền chi tiết từ encoder sang decoder.
U-Net = Encoder (nén) + Decoder (phóng to) + Skip connections (giữ chi tiết). Hình chữ U! Encoder nắm ngữ cảnh ("đây là con đường"), decoder vẽ chính xác ranh giới pixel, skip connections truyền chi tiết không bị mất khi nén.
Stable Diffusion dùng U-Net với 3 bổ sung: (1) Timestep embedding — cho U-Net biết đang ở bước khử nhiễu nào. (2) Cross-attention — nhận text embedding từ CLIP, cho phép sinh ảnh theo mô tả. (3) Latent space — hoạt động trên latent 64×64 thay vì pixel 512×512.
Encoder U-Net: 256×256→128→64→32→16 (bottleneck). Mỗi lần max pool 2×2. Skip connection ở cấp 64×64 truyền feature map 64×64. Decoder nhận concat: kích thước?
Giải thích
U-Net (Ronneberger et al., 2015) là kiến trúc encoder-decoder hình chữ U dựa trên CNN, ban đầu cho phân đoạn ảnh y tế, nay là backbone chính của diffusion models như Stable Diffusion.
import torch.nn as nn
class UNet(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super().__init__()
# Encoder
self.enc1 = self.conv_block(in_ch, 64)
self.enc2 = self.conv_block(64, 128)
self.pool = nn.MaxPool2d(2)
# Bottleneck
self.bottleneck = self.conv_block(128, 256)
# Decoder
self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.dec2 = self.conv_block(256, 128) # 128+128=256 input (concat!)
self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.dec1 = self.conv_block(128, 64) # 64+64=128 input
self.out = nn.Conv2d(64, out_ch, 1)
def conv_block(self, in_ch, out_ch):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(),
)
def forward(self, x):
e1 = self.enc1(x) # Skip 1
e2 = self.enc2(self.pool(e1)) # Skip 2
b = self.bottleneck(self.pool(e2)) # Bottleneck
d2 = self.dec2(torch.cat([self.up2(b), e2], 1)) # Concat skip!
d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
return self.out(d1)- U-Net = Encoder (nén) + Bottleneck + Decoder (phóng to) + Skip connections (concat chi tiết).
- Skip connections concat (nối) feature map encoder với decoder cùng cấp → giữ chi tiết pixel-level.
- Khác ResNet skip (cộng): U-Net skip nối → decoder nhận gấp đôi channel → phong phú thông tin hơn.
- Ứng dụng: phân đoạn ảnh y tế, vệ tinh, tự lái. Là backbone của Stable Diffusion (diffusion models).
- Trong Diffusion: U-Net + timestep embedding + cross-attention (text) → sinh ảnh theo mô tả.
Kiểm tra hiểu biết
U-Net skip connections khác ResNet skip connections thế nào?