ML 学习站
跳到正文

Graph Transformer:把 Transformer 搬到图上

Graphormer / GATv2 / GPS, 大规模 + LLM 结合。

45 分钟4 / 42,995
加载中...

本章深入探讨了将 Transformer 架构应用于图数据的 Graph Transformer,核心在于解决传统 GNN 的局限性。Graph Transformer 通过全局 Attention 机制,使任意两节点直接相连,克服了 GNN 的三个主要局限:过平滑、过压缩和长距离依赖弱。Graph Transformer 与 Vanilla Transformer 的主要区别在于显式编码图结构信息。主流架构包括 Graphormer(引入图结构偏置矩阵)、GATv2(改进 attention 机制)和 GPS(结合 GNN 和 Transformer)。此外,Graph Transformer 需要图位置编码,如拉普拉斯特征向量、随机游走和距离编码。为应对大规模图数据,采用了稀疏 Attention、节点采样、锚点和 NodeFormer 等方法。Graph Transformer 还可与 LLM 结合,用于文本属性图、知识图谱和多模态图。读者将掌握 Graph Transformer 的核心概念、主流模型及其应用场景,并了解其在分子预测等任务中的优势。

Graph Transformer:把 Transformer 搬到图上

GNN 消息传递只能看 K 跳邻居, Graph Transformer 用全局 Attention 让任意两节点直接相连, 解决长距离依赖。

这一章拆解 Graph Transformer 的核心思想和代表模型。

1. 为什么 GNN 不够

GNN 有 3 个核心局限:

  1. 过平滑 (Over-smoothing): 深层 GNN 所有节点表示趋同
  2. 过压缩 (Over-squashing): 信息必须挤过 bottleneck (长路径)
  3. 长距离依赖弱: K 层 GNN 只能看 K-hop, 远了传不到

Graph Transformer 用全局 Attention 一举解决: 任意两节点直接相连, 不受距离限制。

2. Graph Transformer vs Vanilla Transformer

Vanilla Transformer (NLP):

  • 输入: 序列 token
  • Attention: 任意两 token 直接相连

Graph Transformer:

  • 输入: 节点 + 边
  • Attention: 任意两节点直接相连 + 图结构编码

核心差异: 图结构必须显式编码进 attention (否则就是普通 Transformer)。

3. 三种主流架构

3.1 Graphormer (Microsoft, 2021)

用图结构当 Attention 偏置:

Attention(Q, K, V, E) = softmax((Q K^T) / sqrt(d) + B) V

B 是图结构偏置矩阵, 包含三类信息:

  1. 节点度编码 (Centrality Encoding): 在 embedding 里加度的 sin/cos 编码
  2. 空间编码 (Spatial Encoding): b_ij = 节点 i 到 j 的最短路径长度
  3. 边编码 (Edge Encoding): c_ij = 边 (i,j) 的特征 (如有)

效果: OGB-LSC 冠军, 分子预测 SOTA。

class GraphormerLayer(nn.Module):
    def forward(self, x, attn_bias, edge_attr):
        # x: (N, d) 节点特征
        # attn_bias: (N, N) 空间编码 (最短路径)
        # edge_attr: (N, N, de) 边特征
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        attn = (q @ k.T) / sqrt(d) + attn_bias + edge_attr.mean(-1)
        attn = softmax(attn)
        return attn @ v

3.2 GATv2 / Transformer-style Attention

修正 GAT 的 attention, 让 attention 系数真正反映节点相似度:

GAT 公式:

α_ij = softmax(LeakyReLU(a^T [W h_i || W h_j]))

问题: a^T 在 LeakyReLU 之前, 非线性破坏了 attention 分布

GATv2 把 LeakyReLU 移到最后:

α_ij = softmax(a^T LeakyReLU(W [h_i || h_j]))

效果提升 5-10%。

3.3 GPS (Generalizable Permutation Set, 2022)

GNN + Transformer 并行, 最后融合:

h_v = MLP(h_v^GNN + h_v^Transformer)
  • GNN 路径: 局部消息传递 (GCN / GAT)
  • Transformer 路径: 全局 attention
  • 融合: 拼接 / 加权

兼顾局部 + 全局, OGB 多个任务 SOTA。

class GPSLayer(nn.Module):
    def __init__(self, dim):
        self.local = GATConv(dim, dim)
        self.global_attn = MultiheadAttention(dim, num_heads=8)
        self.mlp = MLP(dim * 2, dim)
    def forward(self, x, edge_index):
        local_h = self.local(x, edge_index)
        global_h, _ = self.global_attn(x, x, x)
        return self.mlp(torch.cat([local_h, global_h], dim=-1))

4. 位置编码

Graph Transformer 没有像序列一样的位置, 需要图位置编码

4.1 拉普拉斯特征向量 (LapPE)

图的拉普拉斯矩阵 L = D - A, 做特征分解:

L = U Λ U^T

取前 k 个最小非零特征向量作为位置编码。

4.2 随机游走 (RWPE)

从每个节点出发, 走随机游走, 不同长度走出的频率作为位置编码。

4.3 距离编码

节点对之间的最短路径距离, 或多种距离度量。

5. 大规模 Graph Transformer

亿节点图跑全 attention 不现实 (O(N²) 内存), 几个 trick:

5.1 稀疏 Attention

只对邻居算 attention, 其余 mask 掉 (类似 GAT 但全局):

# 用 edge_index 当 mask
attn = softmax((Q K^T) / sqrt(d) + bias + edge_mask)

5.2 节点采样

每个 batch 只采样一个子图, 子图内做 full attention。

5.3 锚点 (Anchor)

预先选 K 个"锚点" (e.g. PageRank Top-K), 任意节点先和锚点算 attention, 再用锚点信息更新表示:

h_v = Attention(h_v, {h_anchor_1, ..., h_anchor_K})
h_v' = Attention({h_anchor_i}, h_v)

O(N) 复杂度, 处理亿节点。

5.4 NodeFormer (2022)

Kernelized Gumbel-Softmax + 锚点, 千万节点可训练。

6. 与 LLM 结合

Graph Transformer 让 GNN 借力 LLM 的预训练范式。

6.1 文本属性图 (TAG)

节点本身就是文本 (论文摘要 / 商品描述), 用 LLM 提特征 + GNN 学结构:

[论文摘要] -> LLM (GPT/Llama) -> 节点 embedding -> GNN

LLM 提"语义", GNN 学"关系"。

6.2 知识图谱 + LLM

知识图谱 (KG) 事实 + LLM 推理:

  • KG 提供结构化知识
  • LLM 做自然语言理解和生成
  • 互相增强 (KG 增强 LLM, LLM 补全 KG)

6.3 多模态图

节点带文本 / 图像 / 视频, 用多模态 encoder 提特征, 再过 GNN。

7. 实战:OGB-LSC 分子预测

from torch_geometric.nn import GPSConv, GCNConv

class GraphTransformer(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=8):
        super().__init__()
        self.atom_encoder = AtomEncoder(in_dim, hidden_dim)
        self.layers = torch.nn.ModuleList([
            GPSConv(hidden_dim, GCNConv(hidden_dim, hidden_dim), heads=4)
            for _ in range(num_layers)
        ])
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim // 2, out_dim),
        )
    def forward(self, data):
        x = self.atom_encoder(data.x)
        for layer in self.layers:
            x = layer(x, data.edge_index, data.edge_attr)
        # Readout: graph-level prediction
        from torch_geometric.nn import global_add_pool
        graph_x = global_add_pool(x, data.batch)
        return self.classifier(graph_x)

8. 模型选择

场景推荐
节点分类 (小图)GCN / GAT
节点分类 (大图)GraphSAGE + 邻居采样
链接预测 (知识图谱)TransE / RotatE
图分类 (小图)GIN + virtual node
图分类 (中等图)Graphormer
分子属性 (OGB-LSC)GPS / Graphormer
大规模图 (亿节点)NodeFormer + 锚点
文本属性图LLM encoder + GNN

9. 进阶话题

  • 图基础模型 (GFM): GraphGPT / LLaGA, 用 LLM 直接处理图
  • 图指令微调: 让 GNN 听人类指令
  • 图 RL: 决策图 (知识图谱推理路径)
  • 图对比学习: 不依赖标签的预训练
  • 图 + 因果: 用 GNN 学因果关系

10. 总结

整个 GNN 入门 4 章走完, 你应该掌握:

  1. 图基础: 表示 / 度 / 经典算法 (BFS / DFS / PageRank)
  2. GNN 基础: Message Passing / GCN / GraphSAGE / GAT
  3. GNN 应用: 节点 / 边 / 图三大任务 + 工业案例
  4. Graph Transformer: Graphormer / GPS / 大规模 + LLM 结合

下一步看 AutoML (自动机器学习), 学怎么让机器自己设计神经网络。

章末小测验

检验你对《Graph Transformer:把 Transformer 搬到图上》的掌握程度。

1

Graph Transformer 相比传统 GNN 的核心优势是?

2

Graphormer 的核心创新是?

讨论区(0)

加载评论中...