Gated Recurrent Unit
Đơn vị hồi quy có cổng
LSTM có 3 cổng và cell state riêng → mạnh nhưng nặng. Nếu bạn muốn đơn giản hóa mà vẫn giữ khả năng nhớ xa, bạn sẽ gộp hoặc bỏ gì?
Nếu LSTM là chiếc xe máy Honda Wave đầy đủ tính năng, thì GRU là chiếc Wave Alpha — bỏ bớt vài tính năng ít dùng, nhẹ hơn, tiết kiệm xăng hơn, mà đi vẫn tốt. Hãy mổ xẻ ô GRU và xem từng cổng làm gì.
Hình minh họa
Nhấn vào cổng để xem chi tiết. Kéo slider để thay đổi input xₜ và hₜ₋₁ — xem ô GRU tính toán thế nào.
Phép nhân ma trận trực tiếp:
Chỉ 2 cổng thay vì 3 — nhưng cổng update kiêm luôn vai trò forget và input. Thử kéo xₜ về -1 và hₜ₋₁ lên 1 xem update gate phản ứng thế nào.
Hãy cùng đi qua một bước GRU đầy đủ, từ input xₜ và hₜ₋₁ cho tới hₜ mới. Mọi phép toán đều là phép nhân/cộng ma trận và hàm kích hoạt quen thuộc.
Tính reset gate rₜ
Giá trị trong [0, 1]. Gần 0 = "quên quá khứ", gần 1 = "giữ đầy đủ quá khứ".
Tính update gate zₜ
Cũng trong [0, 1]. Đây là cổng "thông minh nhất" của GRU — kiểm soát tỷ lệ giữ/thay.
Tính candidate h̃ₜ
Đây là "đề xuất hidden state mới" — chưa được áp dụng ngay. Lưu ý rₜ nhân theo phần tử với hₜ₋₁ trước khi vào tanh.
Trộn cũ và mới → hₜ
Nếu zₜ = 0.9: giữ 90% cũ, chỉ thêm 10% mới. Nếu zₜ = 0.1: gần như thay hoàn toàn bằng candidate.
Đếm "3 ma trận trọng số": Wᵣ cho reset, W_z cho update, W cho candidate. LSTM có 4. Đây là lý do GRU nhẹ hơn ~25%.
Cùng bài toán, cùng dữ liệu — LSTM và GRU thường cho kết quả gần giống nhau. Khác biệt nằm ở tham số, tốc độ train, và một vài tình huống ngoại lệ.
Hình minh họa
GRU (2 cổng)
• 3 ma trận trọng số (Wᵣ, W_z, W)
• 1 trạng thái: hidden (gộp)
• ~3× tham số RNN (ít hơn LSTM 25%)
Dataset nhỏ/trung bình + cần tốc độ? → GRU. Dataset lớn + chuỗi rất dài? → LSTM. Bài toán NLP hiện đại? → Transformer (cả hai đều thua). Trong thực tế, thử cả hai rồi chọn cái tốt hơn trên validation set!
Với cấu hình phổ biến (input_size=128, hidden_size=256), đây là số tham số trainable của GRU so với LSTM:
Hình minh họa
Ít tham số → train nhanh hơn (ít phép nhân ma trận), tốn ít VRAM hơn (đặc biệt khi batch lớn), và ít overfitting (khi data khan hiếm). Đổi lại GRU hơi kém LSTM khi chuỗi cực dài — một đánh đổi hợp lý.
Cổng update là phát minh thông minh nhất của GRU: . Khi : giữ nguyên cũ (forget). Khi : lấy hoàn toàn mới (input). Một cổng, hai vai trò — tổng luôn bằng 1, ràng buộc tự nhiên!
GRU update gate: hₜ = zₜ × hₜ₋₁ + (1-zₜ) × h̃ₜ. Nếu zₜ = 0.9 cho mọi bước, GRU sẽ hoạt động thế nào?
Nếu reset gate rₜ = 0 liên tục, candidate h̃ₜ lúc đó phụ thuộc vào gì?
Giải thích
GRU (Gated Recurrent Unit)được Cho et al. đề xuất năm 2014 trong paper "Learning Phrase Representations using RNN Encoder-Decoder". Ý tưởng chính: đơn giản hóa LSTM mà vẫn giữ khả năng nhớ xa. Hai cổng chính:
Đầy đủ 4 công thức:
LSTMdùng forget gate (fₜ) và input gate (iₜ) độc lập — tổng không cần bằng 1. GRU ép tổng luôn bằng 1: zₜ + (1-zₜ) = 1. Nghĩa là "giữ nhiều cũ" tự động = "thêm ít mới". Ít tham số hơn mà ràng buộc chặt hơn!
Đạo hàm riêng có một thành phần là zₜ — có thể gần 1. Khi backprop qua nhiều bước, gradient không nhân liên tục các số nhỏ (như RNN thường) mà "đi thẳng" qua cổng update. Đây là lý do GRU (và LSTM) tránh được vanishing gradient tốt hơn RNN thường.
Code ví dụ 1: Khai báo GRU trong PyTorch
import torch
import torch.nn as nn
# GRU — đơn giản hơn
gru = nn.GRU(
input_size=128,
hidden_size=256,
num_layers=2,
batch_first=True,
dropout=0.2,
bidirectional=False,
)
# LSTM — để so sánh
lstm = nn.LSTM(
input_size=128,
hidden_size=256,
num_layers=2,
batch_first=True,
dropout=0.2,
bidirectional=False,
)
# So sánh tham số
gru_params = sum(p.numel() for p in gru.parameters())
lstm_params = sum(p.numel() for p in lstm.parameters())
print(f"GRU: {gru_params:,} params") # ~788,480
print(f"LSTM: {lstm_params:,} params") # ~1,050,624
print(f"GRU nhẹ hơn: {(1 - gru_params/lstm_params)*100:.1f}%") # ~25%
# Sử dụng
batch_size, seq_len, feat = 32, 50, 128
x = torch.randn(batch_size, seq_len, feat)
h0 = torch.zeros(2, batch_size, 256) # (num_layers, batch, hidden)
# GRU chỉ trả 1 state (h), LSTM trả 2 (h, c)
output_gru, h_gru = gru(x, h0)
output_lstm, (h_lstm, c_lstm) = lstm(x)
print(output_gru.shape) # torch.Size([32, 50, 256])
print(h_gru.shape) # torch.Size([2, 32, 256])
# Đây là khác biệt rõ nhất khi code: LSTM trả tuple (h, c)
# vì có cell state riêng, GRU thì không.Code ví dụ 2: Viết tay GRU cell từ đầu (để hiểu sâu)
import torch
import torch.nn as nn
import torch.nn.functional as F
class GRUCellManual(nn.Module):
"""GRU cell viết tay — khớp với công thức trong bài."""
def __init__(self, input_size: int, hidden_size: int):
super().__init__()
self.hidden_size = hidden_size
# 3 ma trận trọng số (thay vì 4 của LSTM)
# W_r, W_z, W — mỗi cái xử lý (input + hidden) nối lại
self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
self.W = nn.Linear(input_size + hidden_size, hidden_size)
def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor):
# Nối input và hidden trước đó
combined = torch.cat([x_t, h_prev], dim=-1)
# Reset gate và Update gate — đều dùng cùng combined
r_t = torch.sigmoid(self.W_r(combined))
z_t = torch.sigmoid(self.W_z(combined))
# Candidate — reset gate nhân theo phần tử với h_prev
combined_reset = torch.cat([x_t, r_t * h_prev], dim=-1)
h_tilde = torch.tanh(self.W(combined_reset))
# Trộn cũ và mới — đây là phép "magic" của GRU
h_t = z_t * h_prev + (1 - z_t) * h_tilde
return h_t
# Sử dụng cho chuỗi đầy đủ
cell = GRUCellManual(128, 256)
h = torch.zeros(1, 256) # initial hidden
for t in range(seq_len):
x_t = x[:, t, :] # (batch, input_size)
h = cell(x_t, h)
# h cuối cùng chứa summary toàn chuỗi
# So sánh với nn.GRU có sẵn:
# - nn.GRU tối ưu với cuDNN, nhanh hơn ~5-10x
# - Viết tay hữu ích để debug hoặc custom (vd: attention over hidden)- GRU đơn giản hóa LSTM: 2 cổng (reset, update) thay vì 3, gộp cell state vào hidden state.
- Update gate kiêm cả forget + input: hₜ = zₜ·hₜ₋₁ + (1-zₜ)·h̃ₜ. Tổng luôn bằng 1 — ràng buộc chặt hơn LSTM.
- Reset gate cho phép 'quên sạch' quá khứ khi cần bắt đầu context mới: h̃ₜ = tanh(W·xₜ + U·(rₜ·hₜ₋₁)).
- Ít tham số hơn ~25% → train nhanh hơn ~15-20%, ít overfitting hơn trên dataset nhỏ.
- Hiệu suất thường ngang LSTM — chọn tùy bài toán (nhỏ → GRU, lớn → LSTM, NLP → Transformer).
- Giải vanishing gradient nhờ đường đi 'thẳng' qua zₜ, nhưng vẫn hạn chế với chuỗi >1000 bước.
Kiểm tra hiểu biết
GRU dùng cổng update (zₜ) để làm gì?