本章深入探讨了将 Transformer 架构应用于图数据的 Graph Transformer,核心在于解决传统 GNN 的局限性。Graph Transformer 通过全局 Attention 机制,使任意两节点直接相连,克服了 GNN 的三个主要局限:过平滑、过压缩和长距离依赖弱。Graph Transformer 与 Vanilla Transformer 的主要区别在于显式编码图结构信息。主流架构包括 Graphormer(引入图结构偏置矩阵)、GATv2(改进 attention 机制)和 GPS(结合 GNN 和 Transformer)。此外,Graph Transformer 需要图位置编码,如拉普拉斯特征向量、随机游走和距离编码。为应对大规模图数据,采用了稀疏 Attention、节点采样、锚点和 NodeFormer 等方法。Graph Transformer 还可与 LLM 结合,用于文本属性图、知识图谱和多模态图。读者将掌握 Graph Transformer 的核心概念、主流模型及其应用场景,并了解其在分子预测等任务中的优势。
Graph Transformer:把 Transformer 搬到图上
GNN 消息传递只能看 K 跳邻居, Graph Transformer 用全局 Attention 让任意两节点直接相连, 解决长距离依赖。
这一章拆解 Graph Transformer 的核心思想和代表模型。
1. 为什么 GNN 不够
GNN 有 3 个核心局限:
- 过平滑 (Over-smoothing): 深层 GNN 所有节点表示趋同
- 过压缩 (Over-squashing): 信息必须挤过 bottleneck (长路径)
- 长距离依赖弱: K 层 GNN 只能看 K-hop, 远了传不到
Graph Transformer 用全局 Attention 一举解决: 任意两节点直接相连, 不受距离限制。
2. Graph Transformer vs Vanilla Transformer
Vanilla Transformer (NLP):
- 输入: 序列 token
- Attention: 任意两 token 直接相连
Graph Transformer:
- 输入: 节点 + 边
- Attention: 任意两节点直接相连 + 图结构编码
核心差异: 图结构必须显式编码进 attention (否则就是普通 Transformer)。
3. 三种主流架构
3.1 Graphormer (Microsoft, 2021)
用图结构当 Attention 偏置:
Attention(Q, K, V, E) = softmax((Q K^T) / sqrt(d) + B) V
B 是图结构偏置矩阵, 包含三类信息:
- 节点度编码 (Centrality Encoding): 在 embedding 里加度的 sin/cos 编码
- 空间编码 (Spatial Encoding): b_ij = 节点 i 到 j 的最短路径长度
- 边编码 (Edge Encoding): c_ij = 边 (i,j) 的特征 (如有)
效果: OGB-LSC 冠军, 分子预测 SOTA。
class GraphormerLayer(nn.Module):
def forward(self, x, attn_bias, edge_attr):
# x: (N, d) 节点特征
# attn_bias: (N, N) 空间编码 (最短路径)
# edge_attr: (N, N, de) 边特征
q, k, v = self.qkv(x).chunk(3, dim=-1)
attn = (q @ k.T) / sqrt(d) + attn_bias + edge_attr.mean(-1)
attn = softmax(attn)
return attn @ v
3.2 GATv2 / Transformer-style Attention
修正 GAT 的 attention, 让 attention 系数真正反映节点相似度:
GAT 公式:
α_ij = softmax(LeakyReLU(a^T [W h_i || W h_j]))
问题: a^T 在 LeakyReLU 之前, 非线性破坏了 attention 分布。
GATv2 把 LeakyReLU 移到最后:
α_ij = softmax(a^T LeakyReLU(W [h_i || h_j]))
效果提升 5-10%。
3.3 GPS (Generalizable Permutation Set, 2022)
GNN + Transformer 并行, 最后融合:
h_v = MLP(h_v^GNN + h_v^Transformer)
- GNN 路径: 局部消息传递 (GCN / GAT)
- Transformer 路径: 全局 attention
- 融合: 拼接 / 加权
兼顾局部 + 全局, OGB 多个任务 SOTA。
class GPSLayer(nn.Module):
def __init__(self, dim):
self.local = GATConv(dim, dim)
self.global_attn = MultiheadAttention(dim, num_heads=8)
self.mlp = MLP(dim * 2, dim)
def forward(self, x, edge_index):
local_h = self.local(x, edge_index)
global_h, _ = self.global_attn(x, x, x)
return self.mlp(torch.cat([local_h, global_h], dim=-1))
4. 位置编码
Graph Transformer 没有像序列一样的位置, 需要图位置编码。
4.1 拉普拉斯特征向量 (LapPE)
图的拉普拉斯矩阵 L = D - A, 做特征分解:
L = U Λ U^T
取前 k 个最小非零特征向量作为位置编码。
4.2 随机游走 (RWPE)
从每个节点出发, 走随机游走, 不同长度走出的频率作为位置编码。
4.3 距离编码
节点对之间的最短路径距离, 或多种距离度量。
5. 大规模 Graph Transformer
亿节点图跑全 attention 不现实 (O(N²) 内存), 几个 trick:
5.1 稀疏 Attention
只对邻居算 attention, 其余 mask 掉 (类似 GAT 但全局):
# 用 edge_index 当 mask
attn = softmax((Q K^T) / sqrt(d) + bias + edge_mask)
5.2 节点采样
每个 batch 只采样一个子图, 子图内做 full attention。
5.3 锚点 (Anchor)
预先选 K 个"锚点" (e.g. PageRank Top-K), 任意节点先和锚点算 attention, 再用锚点信息更新表示:
h_v = Attention(h_v, {h_anchor_1, ..., h_anchor_K})
h_v' = Attention({h_anchor_i}, h_v)
O(N) 复杂度, 处理亿节点。
5.4 NodeFormer (2022)
Kernelized Gumbel-Softmax + 锚点, 千万节点可训练。
6. 与 LLM 结合
Graph Transformer 让 GNN 借力 LLM 的预训练范式。
6.1 文本属性图 (TAG)
节点本身就是文本 (论文摘要 / 商品描述), 用 LLM 提特征 + GNN 学结构:
[论文摘要] -> LLM (GPT/Llama) -> 节点 embedding -> GNN
LLM 提"语义", GNN 学"关系"。
6.2 知识图谱 + LLM
知识图谱 (KG) 事实 + LLM 推理:
- KG 提供结构化知识
- LLM 做自然语言理解和生成
- 互相增强 (KG 增强 LLM, LLM 补全 KG)
6.3 多模态图
节点带文本 / 图像 / 视频, 用多模态 encoder 提特征, 再过 GNN。
7. 实战:OGB-LSC 分子预测
from torch_geometric.nn import GPSConv, GCNConv
class GraphTransformer(torch.nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers=8):
super().__init__()
self.atom_encoder = AtomEncoder(in_dim, hidden_dim)
self.layers = torch.nn.ModuleList([
GPSConv(hidden_dim, GCNConv(hidden_dim, hidden_dim), heads=4)
for _ in range(num_layers)
])
self.classifier = torch.nn.Sequential(
torch.nn.Linear(hidden_dim, hidden_dim // 2),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim // 2, out_dim),
)
def forward(self, data):
x = self.atom_encoder(data.x)
for layer in self.layers:
x = layer(x, data.edge_index, data.edge_attr)
# Readout: graph-level prediction
from torch_geometric.nn import global_add_pool
graph_x = global_add_pool(x, data.batch)
return self.classifier(graph_x)
8. 模型选择
| 场景 | 推荐 |
|---|---|
| 节点分类 (小图) | GCN / GAT |
| 节点分类 (大图) | GraphSAGE + 邻居采样 |
| 链接预测 (知识图谱) | TransE / RotatE |
| 图分类 (小图) | GIN + virtual node |
| 图分类 (中等图) | Graphormer |
| 分子属性 (OGB-LSC) | GPS / Graphormer |
| 大规模图 (亿节点) | NodeFormer + 锚点 |
| 文本属性图 | LLM encoder + GNN |
9. 进阶话题
- 图基础模型 (GFM): GraphGPT / LLaGA, 用 LLM 直接处理图
- 图指令微调: 让 GNN 听人类指令
- 图 RL: 决策图 (知识图谱推理路径)
- 图对比学习: 不依赖标签的预训练
- 图 + 因果: 用 GNN 学因果关系
10. 总结
整个 GNN 入门 4 章走完, 你应该掌握:
- 图基础: 表示 / 度 / 经典算法 (BFS / DFS / PageRank)
- GNN 基础: Message Passing / GCN / GraphSAGE / GAT
- GNN 应用: 节点 / 边 / 图三大任务 + 工业案例
- Graph Transformer: Graphormer / GPS / 大规模 + LLM 结合
下一步看 AutoML (自动机器学习), 学怎么让机器自己设计神经网络。
章末小测验
检验你对《Graph Transformer:把 Transformer 搬到图上》的掌握程度。
Graph Transformer 相比传统 GNN 的核心优势是?
Graphormer 的核心创新是?