Long Short-Term Memory
Bộ nhớ dài-ngắn hạn
Câu: "Tôi sinh ở Huế, học ở Sài Gòn, làm việc ở Hà Nội, nên tôi thích ăn bún bò ___". Từ cuối liên quan đến từ nào cách xa 15 từ? RNN thường nhớ nổi chuỗi bao dài?
Nếu bạn chưa chắc về cách gradient suy giảm qua nhiều bước thời gian, hãy xem lại RNN (Recurrent Neural Network) và Backpropagation Through Time trước. Toàn bộ bài học này xoay quanh câu hỏi: làm sao giữ gradient SỐNG sót qua 100+ bước?
Hãy tưởng tượng bạn đang ghi chép trong lớp học với cuốn vở (cell state — bộ nhớ dài hạn). Tại mỗi thời điểm, bạn cần 4 quyết định: tẩy (cổng quên), nháp (candidate), duyệt viết (cổng nhập), đọc lại (cổng xuất).
Hình minh họa
Nhấn vào từng cổng để xem chức năng chi tiết.
4 khối màu tương ứng 4 phép tính song song trong 1 ô LSTM. 3 cổng (đỏ/xanh lá/xanh dương) dùng sigmoid để "mở–đóng" từ 0 đến 1. Candidate (cam) dùng tanh để đề xuất giá trị MỚI.
Hình minh họa
Trượt các slider để thay đổi xₜ, hₜ₋₁, và cₜ₋₁. Xem các cổng và cell state cập nhật TỨC THÌ. (Trọng số được cố định để bạn thấy hành vi rõ ràng.)
fₜ (forget)
0.761
→ GIỮ hầu hết cell state cũ
iₜ (input)
0.634
→ ghi ít hoặc bỏ qua
g̃ₜ (candidate)
0.744
→ đề xuất TĂNG cell state
oₜ (output)
0.705
→ XUẤT mạnh ra hidden
Kết quả bước này
Cₜ = fₜ · Cₜ₋₁ + iₜ · g̃ₜ = 0.76 × 0.30 + 0.63 × 0.74 = 0.700
hₜ = oₜ · tanh(Cₜ) = 0.70 × tanh(0.70) = 0.426
Thử đặt fₜ về gần 1 (bằng cách tăng xₜ lên ~1.0): bạn sẽ thấy cell state chủ yếu giữ lại Cₜ₋₁ — đây là cơ chế "nhớ lâu" cốt lõi của LSTM.
Giá trị sigmoid nằm trong (0, 1). Hãy nghĩ nó như "độ mở van": 0 = đóng hoàn toàn, 1 = mở hoàn toàn. Với forget gate: 0 = "xóa sạch bộ nhớ", 1 = "giữ nguyên bộ nhớ". Với input gate: 0 = "từ chối ghi", 1 = "ghi toàn bộ candidate".
Cell statelà "đường cao tốc" cho thông tin — chạy thẳng qua mà chỉ bị thay đổi nhẹ nhàng bởi phép nhân/cộng. Gradient cũng chảy thẳng qua đường này → không bị vanishing!
RNN thường: gradient phải đi qua hàm tanh ở mỗi bước → bị nhân nhỏ dần → biến mất. LSTM: gradient đi thẳng qua cell state (nhân với forget gate ≈ 1) → sống sót qua hàng trăm bước! Sau này, Transformer giải quyết vấn đề này triệt để hơn bằng self-attention song song.
Hình minh họa
Cell state qua câu: "Tôi sinh ở Huế nên tôi thích ăn bún bò"
Bước 1: đọc "Tôi"
Cell: [ ]
RNN thường đã quên "Huế" sau 5–6 bước. LSTM giữ được nhờ cell state truyền thẳng — phép nhân với forget gate ≈ 1 = coi như không đổi.
Hình minh họa
Gradient sống sót qua N bước: LSTM vs RNN
Kéo thanh trượt để thay đổi số bước N. Quan sát gradient RNN (đỏ) vs LSTM (xanh).
RNN — gradient ≈
8.67e-19
(tanh′ ≈ 0.25)ᴺ → biến mất rất nhanh
LSTM — gradient ≈
2.15e-1
(f ≈ 0.95)ᴺ → suy giảm chậm hơn ~gấp triệu lần
0.25¹⁰⁰ ≈ 10⁻⁶⁰ — coi như 0. 0.95¹⁰⁰ ≈ 0.006 — nhỏ nhưng không biến mất. Đó là lý do LSTM vẫn học được phụ thuộc dài hạn trong khi RNN thì không. Nhưng LSTM cũng không hoàn hảo — ở 500+ bước, LSTM vẫn gặp khó, và đó là cửa ngõ để Transformer bước vào.
LSTM hidden size = 256, input size = 128. RNN cùng kích thước có ~(256+128)×256 ≈ 98K tham số. LSTM có bao nhiêu (bỏ qua bias)?
Bạn huấn luyện LSTM sinh văn bản. Loss giảm nhưng output chỉ là từ ngữ ngẫu nhiên không có logic dài hạn. Điều nào KHẢ NĂNG CAO nhất?
Giải thích
LSTM (Long Short-Term Memory)được Hochreiter & Schmidhuber đề xuất năm 1997, giải quyết vấn đề vanishing gradient bằng cơ chế 3 cổng và cell state riêng biệt. Đây là bước tiến lớn đầu tiên giúp mạng neural xử lý được phụ thuộc dài hạn thực sự.
Các công thức cốt lõi:
Bí mật nằm ở công thức cell state: . Khi , cell state truyền gần như nguyên vẹn: . Gradient backprop qua phép nhân này cũng ≈ 1 → không biến mất!
Peephole LSTM: các cổng nhìn trực tiếp vào cell state (thêm C vào input của cổng). Bidirectional LSTM: đọc chuỗi cả 2 chiều → hiểu ngữ cảnh đầy đủ hơn. Stacked LSTM: xếp nhiều lớp LSTM → biểu diễn phức tạp hơn. ConvLSTM: thay phép nhân ma trận bằng tích chập — dùng cho video.
(1) Khởi tạo bias forget = 1 để "mặc định nhớ". (2) Dùng gradient clipping ở ngưỡng 5.0 — LSTM vẫn có thể exploding gradient. (3) Chú ý batch-first vs seq-first tuỳ framework. (4) Với chuỗi dài > 500 bước, cân nhắc Transformer thay vì LSTM.
LSTM vẫn vượt trội cho: (a) dữ liệu streaming thời gian thực — xử lý token-theo-token, không cần nhìn cả chuỗi; (b) thiết bị edge với RAM hạn chế — ít tham số hơn Transformer; (c) chuỗi CỰC DÀI có tính chất Markov (time-series tài chính, cảm biến IoT).
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
"""Bộ phân loại văn bản dùng LSTM 2 tầng, 2 chiều."""
def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(
input_size=embed_dim,
hidden_size=hidden_dim,
num_layers=2, # 2 lớp LSTM xếp chồng
batch_first=True,
bidirectional=True, # Đọc cả 2 chiều
dropout=0.3,
)
# Mẹo: khởi tạo bias forget = 1 để "mặc định nhớ"
for name, param in self.lstm.named_parameters():
if "bias" in name:
n = param.size(0)
start, end = n // 4, n // 2
param.data[start:end].fill_(1.0)
# Bidirectional → hidden_dim * 2
self.fc = nn.Linear(hidden_dim * 2, num_classes)
def forward(self, x):
emb = self.embedding(x) # (batch, seq_len, embed_dim)
output, (h_n, c_n) = self.lstm(emb) # output: full sequence
# h_n[-2], h_n[-1]: hidden cuối 2 chiều của LỚP CUỐI
h_cat = torch.cat([h_n[-2], h_n[-1]], dim=1)
return self.fc(h_cat)
# Tổng tham số ≈ 4× RNN cùng hidden size, nhưng nhớ xa hơn rất nhiều.import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMCellManual(nn.Module):
"""LSTM cell viết tay — dạy học & debug. KHÔNG nhanh hơn nn.LSTM."""
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# Gộp 4 bộ trọng số thành 1 ma trận lớn (tối ưu bộ nhớ)
self.W = nn.Linear(input_size + hidden_size, 4 * hidden_size)
# Khởi tạo bias forget = 1 (Jozefowicz 2015)
with torch.no_grad():
self.W.bias[hidden_size:2 * hidden_size].fill_(1.0)
def forward(self, x, state):
"""
x: (batch, input_size)
state: (h_prev, c_prev), mỗi cái (batch, hidden_size)
"""
h_prev, c_prev = state
combined = torch.cat([x, h_prev], dim=1)
gates = self.W(combined) # (batch, 4*hidden)
i, f, g, o = gates.chunk(4, dim=1)
i = torch.sigmoid(i)
f = torch.sigmoid(f)
g = torch.tanh(g)
o = torch.sigmoid(o)
c = f * c_prev + i * g # cell state update
h = o * torch.tanh(c) # hidden state
return h, (h, c)
# Dùng: khởi tạo zeros, lặp qua sequence
cell = LSTMCellManual(input_size=64, hidden_size=128)
batch = 16
x_seq = torch.randn(10, batch, 64) # (seq_len, batch, input)
h = torch.zeros(batch, 128)
c = torch.zeros(batch, 128)
outputs = []
for t in range(x_seq.size(0)):
h, (h, c) = cell(x_seq[t], (h, c))
outputs.append(h)
outputs = torch.stack(outputs) # (seq_len, batch, hidden)- LSTM giải quyết vanishing gradient bằng cell state — 'đường cao tốc' cho gradient truyền thẳng qua phép nhân với forget gate.
- 4 khối tính toán song song: forget gate (xóa gì), input gate (ghi bao nhiêu), candidate (giá trị mới), output gate (xuất gì).
- Cell state C = bộ nhớ dài hạn (truyền trực tiếp); hidden state h = tanh(C) × output gate — phiên bản 'lọc' dùng cho output.
- Tham số gấp ~4× RNN (4 bộ trọng số riêng), nhưng đổi lại nhớ xa hàng trăm bước. Khởi tạo bias forget = 1 để mặc định nhớ.
- Biến thể: Bidirectional (2 chiều), Stacked (nhiều lớp), GRU (đơn giản hơn, 3× params), ConvLSTM (video).
- Ngày nay Transformer thay thế LSTM trong NLP, nhưng LSTM vẫn dùng cho time-series, streaming và thiết bị edge.
Kiểm tra hiểu biết
LSTM có 4 'nhóm' tính toán (3 cổng sigmoid + 1 candidate tanh). Nhóm nào cho phép 'nhớ' thông tin qua hàng trăm bước?