ML 学习站
跳到正文

Transformer 与注意力机制

Self-Attention、Multi-Head、Positional Encoding。

50 分钟3 / 62,332
加载中...

Transformer 与注意力机制

2017 年,Google 发表了论文 "Attention is All You Need",提出了 Transformer 架构。它彻底改变了深度学习,催生了 GPT、BERT、ChatGPT 等革命性模型。

注意力机制:核心直觉

人类看一张图,不会"平均"地看每个像素——而是有重点地关注某些区域。

翻译 "The cat sat on the mat" 时,翻译 "猫" 的时候要重点关注 "cat" 这个词。

注意力机制就是把这种"重点关注"做成数学:

给定查询 Q、键 K、值 V,注意力输出 = V 的加权平均,权重由 Q 和 K 的相似度决定。

Self-Attention 的数学

对于输入序列的每个位置,Self-Attention 计算它与所有位置(包括自己)的相关性:

Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

直觉拆解:

  1. Q * K^T: 算每对位置的相关性分数 (n×n 矩阵)
  2. / sqrt(d_k): 缩放,防止分数过大导致 softmax 梯度消失
  3. softmax: 变成概率分布(权重和为 1)
  4. *** V**: 用这些权重对 V 加权求和

一个具体例子

翻译 "I love you" → "我 爱 你":

  • 处理 "爱" 时,英文的 self-attention 矩阵可能是:
           I    love   you
      I  [0.7   0.2   0.1]
    love [0.3   0.5   0.2]
    you  [0.2   0.3   0.5]
    
  • 翻译 "爱" 时,中文 decoder 的 cross-attention 关注 "love" 权重最大

Multi-Head Attention

一个注意力头只能学一种"关注模式"。Transformer 用 多个头并行,每个头学不同的模式:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
   where head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)

直觉:

  • 头 1 可能学"语法关系"(主谓)
  • 头 2 可能学"指代关系"(代词指代谁)
  • 头 3 可能学"长距离依赖"
  • 头 4 可能学"否定"

多个头让模型同时捕捉多种关系。

位置编码(Positional Encoding)

Self-Attention 是置换不变的——把输入打乱顺序,输出也对应打乱。这对序列来说是 bug。

位置编码给每个位置加一个"位置向量",让模型知道"谁在前谁在后":

# Transformer 论文用正弦位置编码
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

直觉:不同维度用不同频率的正弦/余弦,让模型能学到"位置 i 离位置 j 有多远"。

现代 LLM 更常用 RoPE(旋转位置编码)和 ALiBi(线性偏置),效果更好。

Transformer 完整架构

输入 → Embedding → + Positional Encoding
            ↓
    ┌───────────────────┐
    │   Encoder × N     │  (N=6 in original, N=96+ in GPT-3)
    │ ┌───────────────┐ │
    │ │ Multi-Head    │ │  ← 自注意力: 看自己
    │ │ Self-Attention│ │
    │ ├───────────────┤ │
    │ │ Add & Norm    │ │
    │ ├───────────────┤ │
    │ │ Feed Forward  │ │
    │ ├───────────────┤ │
    │ │ Add & Norm    │ │
    │ └───────────────┘ │
    └───────────────────┘
            ↓
       Encoder 输出
            ↓
    ┌───────────────────┐
    │   Decoder × N     │
    │ ┌───────────────┐ │
    │ │ Masked        │ │  ← 自注意力但只看过去
    │ │ Self-Attention│ │
    │ ├───────────────┤ │
    │ │ Cross         │ │  ← 跨注意力: encoder 输出当 K/V
    │ │ Attention     │ │
    │ ├───────────────┤ │
    │ │ Feed Forward  │ │
    │ └───────────────┘ │
    └───────────────────┘
            ↓
       输出预测

三大组件:

  • Self-Attention: 编码器内"看自己"
  • Masked Self-Attention: 解码器内"只能看过去"——保证训练/推理一致
  • Cross-Attention: 解码器看编码器输出(用于翻译等)

为什么 Transformer 这么强大?

维度RNN/LSTMTransformer
并行训练顺序,不能并行完全并行(每个位置独立算注意力)
长距离依赖要经过中间步,衰减直接算任意两点的关系
可扩展性难训大模型容易 scale 到上千亿参数
计算效率O(seq_len) 串行O(seq_len^2) 但并行

代价:Self-Attention 的复杂度是 O(seq_len²)——序列长时显存爆炸。这也是 LLM 长上下文研究的热点(FlashAttention、稀疏注意力等)。

三种典型架构

1. Encoder-only(BERT 路线)

  • 只用 encoder
  • 适合理解类任务:文本分类、命名实体识别、问答
  • 代表:BERT、RoBERTa

2. Decoder-only(GPT 路线)

  • 只用 decoder(去掉 cross-attention)
  • 适合生成类任务:文本生成、对话、代码
  • 代表:GPT 系列、Llama、Claude

3. Encoder-Decoder(T5 / BART 路线)

  • 两个都用
  • 适合序列到序列任务:翻译、摘要
  • 代表:T5、BART、原始 Transformer

PyTorch 极简实现

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, L, D = x.shape
        Q = self.W_q(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, L, self.n_heads, self.d_k).transpose(1, 2)

        # 注意力分数
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        out = attn @ V  # (B, n_heads, L, d_k)

        out = out.transpose(1, 2).contiguous().view(B, L, D)
        return self.W_o(out)

实际用 nn.MultiheadAttentionnn.Transformer 即可,这是教学版。

小结

  • Transformer = Self-Attention + 位置编码 + 前馈网络
  • 三大优势:并行训练、长距离依赖、可扩展
  • 三种架构:Encoder-only / Decoder-only / Encoder-Decoder
  • 统治了 NLP,正在扩展到图像(AlexNet 之后的 ViT)、音频、多模态

练习思考

  1. Self-Attention 复杂度是 O(n²),输入 100k token 时显存够吗?有什么优化方向?
  2. 为什么 Decoder 用 Masked Self-Attention?训练时不 mask 会怎样?
  3. 用 Hugging Face 的 transformers 库加载一个预训练 BERT,做一次文本分类。

讨论区(0)

加载评论中...