Recurrent Neural Network
Mạng nơ-ron hồi quy
Bạn đọc câu "Tôi yêu ___". Từ tiếp theo dễ đoán là "mèo", "phở", "em"... Não bạn đã dùng gì để đoán?
Hãy tưởng tượng bạn đọc một truyện ngắn, tay cầm cuốn sổ nhỏ. Sau mỗi từ, bạn không viết lại cả câu vào sổ — chỉ cập nhật vài dòng ngắn ghi lại "ý chính đến lúc này". Cuốn sổ ấy chính là trạng thái ẩn hₜ của RNN. Cùng một bàn tay (cùng bộ trọng số) cập nhật cuốn sổ ở mọi bước; chỉ có nội dung cuốn sổ là thay đổi.
Khi mở cuộn (unroll) mạng theo thời gian, một ô RNN "lặp lại" trở thành nhiều bản sao nối tiếp nhau — tất cả cùng dùng chung W_xh, W_hh và b. Hình bên dưới cho bạn thấy điều đó diễn ra với câu "Tôi yêu mèo". Hãy bấm Bước tiếp hoặc Tự động chạy để quan sát hₜ truyền từ trái sang phải.
Hình minh họa
h₀: chưa có ký ức — khởi tạo bằng vector 0.
Trạng thái ẩn hₜ là "tóm tắt" của mọi thứ RNN đã thấy từ đầu đến bước t. Nó là một vector cố định chiều, không phải danh sách các token.
Gradient chảy ngược (màu nhạt dần = tín hiệu học yếu dần)
Một RNN đã huấn luyện xong không nhớ câu "Tôi yêu mèo" như một cái máy ghi âm. Nó có một cuốn sổ nhỏ cố định chiều (hₜ) và một cây bút cố định (W_xh, W_hh, b). Cây bút đó viết lại cuốn sổ sau mỗi từ; chính cuốn sổ mới là "ký ức" — không phải danh sách từng từ.
Đó là lý do RNN có thể xử lý chuỗi bất kỳ độ dài: 10 từ hay 10.000 từ đều dùng chung một bộ trọng số, chỉ có cuốn sổ chạy theo thời gian.
RNN xử lý câu 6 từ cần 6 bước tuần tự. Transformer xử lý cùng câu đó cần tối đa mấy bước?
Bạn huấn luyện RNN thuần trên câu dài 100 từ, loss giảm rất chậm và gradient ở các tầng đầu gần bằng 0. Nguyên nhân phù hợp nhất là gì?
Giải thích
Mạng nơ-ron hồi quy (Recurrent Neural Network — RNN) là một họ kiến trúc xử lý chuỗi bằng cách duy trì một vector trạng thái ẩn hₜ và cập nhật nó tuần tự qua các bước thời gian. Tại mỗi bước, đầu vào xₜ và trạng thái ẩn cũ hₜ₋₁ được đưa vào cùng một ô (cell) — ô này có chung bộ trọng số ở mọi bước (weight sharing theo thời gian).
Công thức hồi quy cốt lõi của RNN thuần (Elman RNN):
Trong đó là đầu vào (thường là embedding của token t), là trạng thái ẩn, và là đầu ra (có thể bỏ qua nếu chỉ cần hₜ cuối). Ba ma trận cùng bias được chia sẻ ở mọi bước.
Ba thành phần cần nắm vững:
- Đầu vào xₜ: vector biểu diễn cho token/điểm dữ liệu tại thời điểm t. Với văn bản thường là embedding (dim 100–1024); với chuỗi số (giá cổ phiếu, sensor) có thể chỉ vài chiều.
- Trạng thái ẩn hₜ: "cuốn sổ" cố định chiều, chứa tóm tắt tất cả những gì đã thấy. Khi bắt đầu, h₀ thường được khởi tạo bằng vector 0 (hoặc học được như một tham số).
- Cổng phi tuyến tanh: nén tổ hợp tuyến tính về khoảng (−1, 1). Nếu bỏ tanh, RNN sụp đổ thành một phép biến đổi tuyến tính thuần và mất khả năng học quan hệ phi tuyến theo thời gian.
Hiện thực RNN bằng PyTorch có thể rất ngắn gọn. Dưới đây là phiên bản "từ gốc" để bạn thấy đúng công thức trên:
import torch
import torch.nn as nn
class SimpleRNNCell(nn.Module):
"""
Một ô RNN thuần tuân đúng công thức Elman:
h_t = tanh(W_xh * x_t + W_hh * h_{t-1} + b_h)
Dùng nn.Parameter để PyTorch tự tính gradient qua BPTT.
"""
def __init__(self, input_dim: int, hidden_dim: int):
super().__init__()
# Khởi tạo theo Xavier để tránh vanishing/exploding gradient
# ngay từ bước đầu. nn.Linear chỉ là tiện nghi cho W·x + b.
self.W_xh = nn.Linear(input_dim, hidden_dim, bias=True)
self.W_hh = nn.Linear(hidden_dim, hidden_dim, bias=False)
def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
# x_t: (batch, input_dim); h_prev: (batch, hidden_dim)
return torch.tanh(self.W_xh(x_t) + self.W_hh(h_prev))
class SimpleRNN(nn.Module):
"""
Quấn nhiều bước thời gian lại: nhận chuỗi (batch, T, input_dim),
trả về toàn bộ hidden states (batch, T, hidden_dim) và h_T.
"""
def __init__(self, input_dim: int, hidden_dim: int):
super().__init__()
self.hidden_dim = hidden_dim
self.cell = SimpleRNNCell(input_dim, hidden_dim)
def forward(self, x: torch.Tensor, h0: torch.Tensor | None = None):
batch, T, _ = x.shape
if h0 is None:
h0 = x.new_zeros(batch, self.hidden_dim)
h = h0
hs = []
# Vòng lặp tuần tự theo thời gian — đây chính là lý do RNN
# KHÔNG song song hóa được giữa các bước (t và t+1).
for t in range(T):
h = self.cell(x[:, t, :], h)
hs.append(h)
# hs: danh sách T tensor kích thước (batch, hidden_dim)
return torch.stack(hs, dim=1), h
# Thực tế: dùng nn.RNN của PyTorch sẽ nhanh hơn (CuDNN tối ưu).
# Ví dụ sử dụng:
if __name__ == "__main__":
rnn = SimpleRNN(input_dim=50, hidden_dim=128)
# Batch 4, chuỗi 20 bước, mỗi bước embedding 50 chiều
dummy = torch.randn(4, 20, 50)
out, h_final = rnn(dummy)
print(out.shape, h_final.shape) # (4, 20, 128) (4, 128)
Huấn luyện RNN cần Backpropagation Through Time (BPTT) — PyTorch làm tự động nhưng bạn cần thêm gradient clipping để tránh exploding gradient:
import torch
import torch.nn as nn
# Giả sử bạn có sẵn mô hình và data loader
model = nn.RNN(input_size=50, hidden_size=128, batch_first=True)
head = nn.Linear(128, vocab_size) # vocab_size tùy bài toán
opt = torch.optim.Adam(list(model.parameters()) + list(head.parameters()), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
for batch_x, batch_y in train_loader:
# batch_x: (B, T, 50); batch_y: (B, T) — chỉ số token đích
out, _ = model(batch_x) # (B, T, 128)
logits = head(out) # (B, T, V)
loss = loss_fn(logits.reshape(-1, vocab_size), batch_y.reshape(-1))
opt.zero_grad()
loss.backward()
# QUAN TRỌNG: cắt gradient để tránh exploding gradient
# Với RNN thuần, norm gradient rất dễ nổ > 10^3 khi chuỗi dài.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
opt.step()
Gradient của loss theo W_hh chứa tích của T ma trận Jacobi liên tiếp. Nếu spectral radius của ma trận đó < 1, gradient biến mất; nếu > 1, gradient bùng nổ. Hai lối thoát:
- Gradient clipping — cắt norm gradient về ngưỡng (thường 1–5). Chặn được exploding nhưng không hồi phục được vanishing.
- Kiến trúc mới — LSTM, GRU, Transformer: thiết kế để gradient không bị nhân dồn qua tanh.
- Streaming: audio real-time, sensor IoT — RNN chỉ cần hₜ₋₁ để tính hₜ, không cần toàn bộ chuỗi trong RAM.
- Thiết bị biên (edge): RAM và compute hạn chế; RNN nhỏ 1–10MB dễ chạy hơn Transformer 100MB+.
- Chuỗi rất dài: attention O(T²) quá đắt; RNN hoặc biến thể SSM (Mamba) có O(T).
Khi chuỗi quá dài (vd 10.000 bước), backprop qua toàn bộ sẽ tốn bộ nhớ khổng lồ. Giải pháp: truncated BPTT — cứ mỗi k bước (vd k=35) thì detach hₜ khỏi đồ thị tính toán trước đó. Gradient chỉ lan ngược trong cửa sổ k bước gần nhất. Mô hình vẫn nhớ dài hạn thông qua hₜ (forward), chỉ gradient mới bị cắt.
So sánh RNN thuần, LSTM, và GRU:
| Kiến trúc | Cổng | Bộ nhớ | Tham số | Mạnh | Yếu |
|---|---|---|---|---|---|
| RNN (vanilla) | Không có cổng | Một trạng thái ẩn h | W_xh, W_hh, b (ít nhất) | Đơn giản, dễ huấn luyện trên chuỗi ngắn | Gradient biến mất/bùng nổ, khó nhớ xa |
| LSTM | 3 cổng: forget, input, output | Cell state c + hidden state h | 4 bộ (W_f, W_i, W_o, W_c) — nhiều hơn ~4× | Nhớ được ngữ cảnh dài 100–300 bước | Chậm hơn RNN, nhiều tham số hơn |
| GRU | 2 cổng: reset, update | Chỉ có hidden state h (gộp cell + hidden) | 3 bộ — gọn hơn LSTM ~25% | Cân bằng giữa LSTM và RNN, ít tham số | Hơi kém LSTM với chuỗi siêu dài |
Những sai lầm thường gặp khi huấn luyện RNN:
- Quên gradient clipping → loss đột ngột nổ thành NaN sau vài epoch. Luôn đặt
clip_grad_norm_(..., max_norm=5.0). - Khởi tạo W_hh kém → orthogonal hoặc identity-ish là lựa chọn tốt; random Gaussian thường có norm quá lớn gây bùng nổ ngay từ đầu.
- Dùng RNN thuần cho chuỗi > 100 bước → gần như chắc chắn vanishing. Chuyển sang LSTM/GRU hoặc truncated BPTT với attention bổ sung.
- Batch không đồng đều độ dài → phải pad và dùng
pack_padded_sequenceđể RNN bỏ qua pad, tránh học sai. - Quên reset hₜ giữa các câu độc lập → mô hình "rò rỉ" ngữ cảnh từ mẫu này sang mẫu khác, gây bias.
Ứng dụng thực tế của RNN (và biến thể):
- Nhận dạng tiếng nói — Deep Speech (Baidu, 2014) dùng LSTM hai chiều; ngày nay Conformer (CNN+attention) đã thay thế nhưng LSTM vẫn mạnh trên thiết bị biên.
- Dự báo chuỗi thời gian — giá cổ phiếu, nhu cầu điện, tải mạng. RNN/LSTM phổ biến vì cần streaming.
- Machine translation — kiến trúc encoder-decoder LSTM từng là state-of-the-art (2014–2017) trước khi Transformer thay thế.
- Gõ phím gợi ý trên điện thoại — RNN/GRU nhỏ chạy on-device vì latency và quyền riêng tư.
- Phát hiện bất thường trong log hệ thống — RNN học pattern "bình thường" rồi đánh dấu bước có loss cao là bất thường.
Hiểu RNN là nền tảng để bạn đọc được LSTM (thêm cell state và 3 cổng), GRU (rút gọn LSTM còn 2 cổng), và tại sao Transformer lại vượt qua cả ba. Bạn cũng nên xem lại backpropagation để nắm BPTT.
- RNN truyền trạng thái ẩn hₜ qua mỗi bước thời gian — hₜ là 'cuốn sổ' tóm tắt mọi thứ đã thấy.
- Công thức cốt lõi: hₜ = tanh(W_xh·xₜ + W_hh·hₜ₋₁ + b). Weight sharing theo thời gian, tương đương CNN chia sẻ kernel theo không gian.
- Vanishing gradient là vấn đề kinh điển: gradient nhân qua T Jacobi co về 0. Triệu chứng: lớp đầu không học.
- Gradient clipping (max_norm 1–5) là bắt buộc để tránh exploding gradient khi chuỗi dài.
- LSTM và GRU khắc phục vanishing bằng cell state + cổng; Transformer giải quyết triệt để bằng attention nhưng mất lợi thế streaming.
- RNN vẫn hữu ích cho streaming, thiết bị biên, và là nền tảng lý thuyết cho State Space Models (Mamba) hiện đại.
Kiểm tra hiểu biết
RNN xử lý câu 100 từ. Tại bước thứ 100, trạng thái ẩn h₁₀₀ chứa thông tin gì?
Sandbox bổ sung — thử nghiệm chiều hidden state (nâng cao)
Kéo thanh trượt dưới đây để mô phỏng việc tăng chiều hₜ — quan sát số phép nhân ma trận tăng theo O(d²). Đây là lý do RNN chiều lớn rất đắt.
Tham số: d_h · (d_x + d_h + 1)
Bộ nhớ khi forward: O(T · d_h)
Rủi ro: vanishing/exploding cao