6月4日 11:01

RNN、LSTM 和 GRU 有什么区别?怎么选?

RNN 是处理序列数据的基础架构:每一步把当前输入和上一步的隐藏状态拼在一起做变换,输出新的隐藏状态。问题是反向传播时梯度要乘很多次权重矩阵,序列一长梯度就指数级衰减(梯度消失)或爆炸——这就是 RNN 记不住远距离依赖的根本原因。

LSTM 通过引入细胞状态和三个门来解决这个问题:遗忘门决定忘掉什么,输入门决定存什么,输出门决定输出什么。关键在于细胞状态的更新是加法而非乘法:C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t,加法让梯度可以无损地回传,不会逐层衰减。

GRU 是 LSTM 的简化版,把遗忘门和输入门合成一个更新门 z_t,还省掉了细胞状态,直接在隐藏状态上做插值:h_t = (1-z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t。参数少约 30%,训练更快,多数任务上效果和 LSTM 持平。

一个直觉:LSTM 的 f_t ≈ 1, i_t ≈ 0 时细胞状态原样传递——这就是"记忆"。GRU 的 z_t ≈ 0 时隐藏状态原样保留——异曲同工。

python
# PyTorch 中三者用法几乎一致 nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) nn.GRU(input_size, hidden_size, num_layers, batch_first=True) nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

追问

为什么 LSTM 能解决梯度消失而 RNN 不能?

RNN 的隐藏状态更新是全乘法:h_t = tanh(W·h_{t-1} + ...),反向传播时 ∂h_t/∂h_{t-1} 包含 W 和 tanh',连乘 N 次后梯度要么趋于 0 要么爆炸。LSTM 的细胞状态更新是加法:C_t = f_t⊙C_{t-1} + i_t⊙C̃_t,∂C_t/∂C_{t-1} = f_t,只要 f_t 接近 1 梯度就能一路畅通地回传。本质上是加法 vs 乘法的区别。

GRU 和 LSTM 怎么选?

数据量小或需要快速迭代选 GRU(参数少,训练快)。序列特别长或任务特别复杂选 LSTM(表达能力更强)。实际项目中,如果你不确定,先用 GRU 跑 baseline,效果不够再换 LSTM——而不是反过来。2017 年以后大多数新论文默认用 GRU,但工业界 LSTM 仍然广泛部署。

RNN 系列和 Transformer 的核心区别是什么?

RNN 必须按时间步顺序计算,无法并行;Transformer 用自注意力直接看全局,所有位置同时算。代价是 Transformer 的注意力是 O(n²) 复杂度,而 RNN 是 O(n)。所以短序列 RNN 更省内存,长序列 Transformer 更快(因为并行)。另外,RNN 天然适合流式处理(来一个 token 处理一个),Transformer 每次都要重新算整段注意力。

双向 LSTM 和单向有什么区别?

单向 LSTM 只看过去的上下文,双向 LSTM 同时跑一个从左到右和一个从右到左的 LSTM,把两个方向的隐藏状态拼接。双向版本效果更好,但不能用于生成任务——因为你不能偷看未来的 token。文本分类、命名实体识别用双向,语言模型、机器翻译解码器用单向。

标签:NLP