ML 学习站
跳到正文

循环神经网络 RNN 与 LSTM

序列建模、RNN 梯度消失、LSTM 与 GRU。

40 分钟2 / 61,659
加载中...

循环神经网络(RNN)和长短期记忆网络(LSTM)是处理序列数据的强大工具。RNN的核心思想是让网络具有“记忆”,即每一步的输出都依赖于之前的输入,使其适用于文本、语音等时序数据。然而,RNN存在长期依赖问题,难以学习超过5步的历史信息。LSTM通过引入门控机制(输入门、遗忘门、输出门)和cell state解决了这一问题,使得梯度可以无衰减地流动,从而有效捕捉长距离依赖。GRU是LSTM的简化版本,使用更少的参数并保持相似的性能。双向RNN通过同时处理序列的正向和反向信息,进一步提升了模型的表现。尽管LSTM和GRU缓解了长期依赖问题,但它们在处理超长序列时仍存在局限。读者学完后,能够理解RNN和LSTM的基本原理,掌握它们在处理序列数据中的应用,并了解LSTM和GRU的优缺点及适用场景。

循环神经网络 RNN 与 LSTM

CNN 擅长处理"空间上"的数据(图片),但很多任务是"时间上"有顺序的——文本、语音、股票、视频。RNN(Recurrent Neural Network)就是为序列而生的。

为什么需要 RNN?

普通神经网络假设所有输入是独立的。但实际场景中:

  • 一句话:"我喜欢吃苹果"——"喜欢"暗示"吃",时序很重要
  • 一段语音:音素顺序决定语义
  • 股票价格:今天的价格依赖昨天、上周

RNN 的核心思想:让网络有"记忆"——每一步的输出依赖之前的输入

RNN 的结构

把网络"展开"看,RNN 在每个时间步都做同样的事,但信息会在时间步之间传递:

时间步:    t-1        t         t+1
         ┌───┐      ┌───┐      ┌───┐
    x → │RNN│ → h │RNN│ → h │RNN│ → h → 输出
         └─┬─┘      └─┬─┘      └─┬─┘
           │          │          │
           h ←────────h ←────────h
         (隐藏状态传递给下一步)

关键公式:

h_t = tanh(W_xh * x_t + W_hh * h_{t-1} + b)
y_t = W_hy * h_t

h_t 既是当前步的"记忆",也是当前步的输出。理论上,这个记忆可以一直传下去。

RNN 实战:字符级语言模型

import torch
import torch.nn as nn

class CharRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h0=None):
        # x: (batch, seq_len)
        emb = self.embedding(x)
        out, h_n = self.rnn(emb, h0)
        # out: (batch, seq_len, hidden)
        logits = self.fc(out)
        return logits, h_n

# 训练: 给 "hello" 预测 "ello " (下一个字符)
model = CharRNN(vocab_size=128)
x = torch.tensor([[ord(c) for c in "hello"]])
logits, _ = model(x)
print(logits.shape)  # (1, 5, 128)

RNN 的致命问题:长期依赖

理论上 RNN 能记住任意长的历史,实际上很难学到 5 步以上的依赖

原因:反向传播时,梯度要沿时间步连乘。如果每步的局部导数都 < 1,梯度会指数级衰减到 0——前面步的参数根本不更新。

梯度  = (∂h_t/∂h_{t-1}) * (∂h_{t-1}/∂h_{t-2}) * ... * (∂h_2/∂h_1)
       = 多次 tanh'(h) 相乘 (大部分 &lt; 1)
       ≈ 0

这就是长期依赖问题:RNN 学不会"我 30 步前看到的那个信息对当前决策很重要"。

LSTM:让记忆可控

LSTM(Long Short-Term Memory,1997)用巧妙的"门"机制解决了这个问题:

            ┌───── forget gate (决定忘记什么)
            │
h_{t-1} ───┤───── input gate  (决定记住什么新东西)
            │
x_t ────────┤───── output gate (决定输出什么)
            │
            └──────→ cell state c_t (长期记忆通道)

三个门都是 sigmoid 神经元,输出 0~1,像阀门一样控制信息流

LSTM 的核心:Cell State

c_t = forget_gate * c_{t-1} + input_gate * new_info
h_t = output_gate * tanh(c_t)

关键设计:c_t 的更新是逐元素相乘 + 相加,而不是矩阵乘法——梯度可以无衰减地沿 c_t 流动

LSTM 实战

class CharLSTM(nn.Module):
    def __init__(self, vocab_size, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # PyTorch 内置 LSTM: 一个调用搞定
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        emb = self.embedding(x)
        out, (h_n, c_n) = self.lstm(emb)
        return self.fc(out), (h_n, c_n)

只比 RNN 多一行——PyTorch 帮你把门机制都封装好了。

GRU:LSTM 的简化版

GRU(Gated Recurrent Unit,2014)把 LSTM 的三个门合并成两个:更新门 + 重置门

# PyTorch 一行切换
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

GRU vs LSTM:

  • GRU 参数更少(2 个门 vs 3 个门),训得更快
  • LSTM 表达力稍强,数据多时上限更高
  • 实践中差别不大,GRU 默认就行

双向 RNN

有时候我们不仅想看"过去",还想看"未来"。

比如命名实体识别:"苹果公司发布了新产品"——看到"公司"才知道"苹果"是公司名。

双向 RNN 同时跑两个 RNN:一个从左到右,一个从右到左,然后拼接输出。

self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
# 输出维度变成 2 * hidden_size

RNN/LSTM 的局限

虽然 LSTM/GRU 缓解了长期依赖,但对超长序列仍然力不从心:

  • 训练是顺序的,无法并行(必须等 t-1 算完才能算 t)
  • 长距离依赖仍然要经过很多步,信息会衰减
  • 实践中一般只能处理 100-500 步的依赖

这就是为什么 Transformer 出现了——它用注意力机制直接建模任意两个位置的关系,完全摆脱了"一步步传"。

小结

  • RNN/LSTM 处理序列数据(文本、语音、时序)
  • LSTM 用门控 + cell state 解决长期依赖
  • GRU 是 LSTM 的简化版,参数更少
  • 双向 RNN 可以同时看上下文
  • 对超长序列,Transformer 是更好的选择(下一章)

练习思考

  1. 为什么 RNN 不能像 CNN 那样并行训练?这对训练速度有什么影响?
  2. LSTM 里的 forget gate bias 应该初始化为多少?为什么?(提示:想让网络"默认记住一切")
  3. 用 LSTM 训一个字符级 Shakespeare 生成器,看看能学出什么风格。

章末小测验

检验你对《循环神经网络 RNN 与 LSTM》的掌握程度。

1

RNN的核心思想是什么?

2

以下哪一项是RNN面临的长期依赖问题的原因?

3

LSTM通过什么机制解决了长期依赖问题?

4

GRU与LSTM的主要区别是什么?

5

双向RNN的主要优势是什么?

讨论区(0)

加载评论中...