ML 学习站
跳到正文

GNN 基础:GCN / GraphSAGE / GAT

Message Passing 框架, 三大经典 GNN 与实战。

50 分钟2 / 42,265
加载中...

本章介绍了图神经网络(GNN)的基础知识,重点讲解了三种经典模型:GCN、GraphSAGE 和 GAT。GNN 的核心思想是每个节点通过与邻居节点的信息传递来更新自身的表示。GCN(Graph Convolutional Network)通过邻居特征的平均值和自身特征进行线性变换和激活函数处理,结构简单但存在邻居等权和深层训练易出现过平滑的问题。GraphSAGE 通过邻居采样和多种聚合器(如均值、最大池化和 LSTM)提高了模型在大图上的可扩展性,但采样过程具有随机性。GAT(Graph Attention Network)引入了注意力机制,自动学习邻居的重要性,具有更好的解释性,但计算量较大且训练速度较慢。读者将掌握如何利用这些模型进行节点分类任务,并了解过平滑、异质图和大图等关键挑战的解决方法。

GNN 基础:GCN / GraphSAGE / GAT

CNN 处理图像 (欧几里得结构, 规则网格), RNN 处理序列 (链式结构), GNN 处理图 (任意拓扑结构)。

这一章拆解三大经典 GNN 模型。

1. 为什么需要 GNN

节点的特征包含结构信息 (邻居长什么样) 和自身信息 (节点属性)。传统 MLP 只看节点自身, 完全忽略结构。

GNN 核心思想: 每个节点都从邻居那里"学习", 更新自己的表示。

2. 通用框架:Message Passing

所有 GNN 都可以抽象成三步:

1. 聚合 (Aggregate):  节点收集邻居的消息
2. 更新 (Update):     节点聚合消息 + 自身, 更新表示
3. 循环 (Iterate):    重复 K 次, 信息传到 K 跳外

形式化:

m_v^(k) = AGGREGATE({ h_u^(k-1) : u in N(v) })   # 收集邻居
h_v^(k) = UPDATE(h_v^(k-1), m_v^(k))              # 更新自己

经过 K 层后, h_v^(K) 包含节点 v 的 K-hop 邻域信息

3. GCN (Graph Convolutional Network, 2017)

图卷积网络, Kipf & Welling 提出, 最经典的 GNN。

核心公式

H^(l+1) = σ(D^(-1/2) A D^(-1/2) H^(l) W^(l))
  • A = 邻接矩阵 (加自环 A+I)
  • D = 度矩阵 (D_ii = sum_j A_ij)
  • H^(l) = 第 l 层节点特征 (N×d)
  • W^(l) = 可学习权重 (d×d')
  • σ = 激活函数 (ReLU)

直觉

每个节点的新特征 = 邻居特征的平均 (D 归一化) + 自身特征, 过线性变换 + ReLU。

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=0.5)
        x = self.conv2(x, edge_index)
        return x

优缺点

  • ✅ 简单高效, 训练快
  • ❌ 邻居等权 (没办法区分重要邻居)
  • ❌ 浅层 (2-3 层), 太深会过平滑 (over-smoothing)

4. GraphSAGE (2017)

SAmple and aggreGatE, Hamilton 等提出。

核心改进

GCN 用全图邻居 + 固定聚合 (平均), GraphSAGE 提出:

  • 邻居采样: 大图只采样固定数量邻居 (e.g. 25), 可扩展到亿级图
  • 多种聚合器: Mean / MaxPool / LSTM

三种聚合器

# Mean 聚合 (最常用)
h_v = mean({h_u : u in N(v)})

# MaxPool 聚合
h_v = max({MLP(h_u) : u in N(v)})

# LSTM 聚合 (邻居有顺序假设, 实际打乱顺序)
h_v = LSTM([h_u for u in random_perm(N(v))])

LSTM 聚合效果最好, 但需要随机排列邻居保证 permutation-invariant。

优缺点

  • ✅ 大图可扩展 (邻居采样)
  • ✅ 多种聚合器
  • ❌ 邻居采样是随机的, 不同次结果不同

5. GAT (Graph Attention Network, 2018)

GCN 给所有邻居等权, GAT 用 Attention 自动学权重。

核心公式

每个节点 v 对每个邻居 u 计算 attention 系数:

e_uv = LeakyReLU(a^T [W h_u || W h_v])      # 拼接 + 单层前馈
α_uv = softmax_u(e_uv) = exp(e_uv) / Σ_k exp(e_uk)
h_v' = σ(Σ_{u in N(v)} α_uv W h_u)

多头注意力

和 Transformer 一样, 用 K 个独立 attention 头, 然后拼接 / 平均:

h_v' = ||_{k=1}^{K} σ(Σ α_uv^k W^k h_u)

优缺点

  • ✅ 自动学邻居重要性
  • ✅ 解释性好 (attention 权重可视化)
  • ❌ 计算量大 (每个边算 attention)
  • ❌ 训练慢

6. 三大模型对比

模型聚合方式优点缺点
GCN均值 (度加权)简单快等权
GraphSAGE采样 + 多聚合器大图可扩展采样随机
GATAttention自动学权重

7. 实战: Cora 节点分类

Cora 是引文网络, 2708 篇论文分 7 类, 每篇用 1433 维词袋特征表示:

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]
# data.x: (2708, 1433), data.edge_index: (2, 10556)
# data.y: (2708,), data.train_mask: 140 个 True

model = GCN(dataset.num_features, 16, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# 测试: 通常 80%+ 准确率

8. 关键挑战

过平滑 (Over-smoothing)

GCN / GAT 深层后, 所有节点表示趋于相同 (因为反复平均), 节点不可区分。

解法:

  • 用残差连接 (类似 ResNet)
  • 用 PairNorm / DropEdge / DropNode 随机丢边 / 节点
  • 加初始残差 (APPNP)

异质图

真实图常有多类节点 (用户 + 物品 + 评论), 普通 GNN 不直接支持。

解法:

  • 每类节点单独投影到同一空间
  • 用 HAN / HGT 等异质 GNN

大图

亿级节点 / 边无法全图训练。

解法:

  • 邻居采样 (GraphSAGE)
  • 子图采样 (Cluster-GCN / GraphSAINT)
  • 分布式训练 (DGL / GraphScope)

9. 进阶话题

  • GNNExplainer: 解释 GNN 预测, 找出关键子图
  • 对抗攻击: 在边上加扰动让 GNN 误分类
  • 图对比学习: GraphCL / GRACE, 无监督表征
  • 图生成: GraphVAE / GraphGAN, 生成新分子
  • 异质图: HAN / HGT / RGCN
  • 动态图: TGN / EvolveGCN

总结

GNN = Message Passing, 三大经典: GCN (均值) / GraphSAGE (采样) / GAT (Attention)。

下一章GNN 应用, 看节点分类 / 链接预测 / 图分类三大任务。

章末小测验

检验你对《GNN 基础:GCN / GraphSAGE / GAT》的掌握程度。

1

GNN 中 Message Passing 的三步是?

2

GCN / GraphSAGE / GAT 三者核心区别是?

3

GNN 深层后'过平滑'现象是指?

讨论区(0)

加载评论中...