ML 学习站
跳到正文

GNN 应用:节点 / 边 / 图三大任务

节点分类 / 链接预测 / 图分类 + 工业案例 (推荐/药物/交通)。

55 分钟3 / 42,768
加载中...

本章详细介绍了图神经网络(GNN)在节点级、边级和图级三大任务中的应用,并结合具体案例进行讲解。核心概念包括节点分类、链接预测和图分类。节点分类任务通过半监督学习,利用部分节点标签预测其他节点标签,常用于引文网络和社交网络用户画像。链接预测则关注预测两个节点之间是否存在边,广泛应用于推荐系统和知识图谱补全。图分类任务对整图进行分类预测,常用于分子属性预测和蛋白质功能预测。读者将学习到如何选择合适的GNN模型和评估指标,并掌握PyTorch Geometric和DGL等常用工具库的使用。完成本章后,读者能够应用GNN解决实际问题,如设计推荐系统、进行药物发现和交通预测等。

GNN 应用:节点 / 边 / 图三大任务

GNN 落地主要有三类任务: 节点级 (每个节点一个预测), 边级 (每条边一个预测), 图级 (整图一个预测)。

这一章拆解三大任务和实战案例。

1. 节点级任务 (Node-level)

1.1 节点分类 (Node Classification)

最经典的任务, 已知部分节点标签, 预测其他节点:

  • Cora 引文网络 (2708 论文, 7 类)
  • Reddit 帖子分类 (232K 帖子, 41 类)
  • 社交网络用户画像 (用户分群)

训练方式: 半监督, 只需要部分标签, 消息传递会传播到全图。

评估: 准确率 / Macro-F1。

from torch_geometric.nn import GATConv

class NodeClassifier(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden_dim, heads=8)
        self.gat2 = GATConv(hidden_dim * 8, out_dim, heads=1)
    def forward(self, x, edge_index):
        x = self.gat1(x, edge_index).relu()
        x = self.gat2(x, edge_index)
        return x

1.2 节点回归 (Node Regression)

预测节点的连续值:

  • 分子原子能量 (DFT 计算昂贵, GNN 预测)
  • 城市交通流量预测
  • 蛋白质残基属性

1.3 节点聚类 (Node Clustering)

把节点分群, 无监督:

  • 社区检测 (Louvain / 谱聚类)
  • 用 GNN embedding + KMeans
from torch_geometric.nn import GCNConv
from sklearn.cluster import KMeans

# 训 GCN, 拿 embedding
model = GCN(in_dim, hidden_dim, out_dim)
# ... 训练 ...

with torch.no_grad():
    embeddings = model.conv1(data.x, data.edge_index).cpu().numpy()

# KMeans 聚类
clusters = KMeans(n_clusters=10).fit_predict(embeddings)

2. 边级任务 (Edge-level)

预测两个节点之间是否会有边, 推荐系统核心:

  • 好友推荐 (预测用户 u 和 v 是否会成为好友)
  • 物品推荐 (预测用户 u 是否会喜欢物品 i)
  • 知识图谱补全 (预测三元组 (h, r, ?) 缺什么)

核心思路: 学一个评分函数 f(h, r, t), 给真实三元组高分, 假三元组低分。

from torch_geometric.nn import GCNConv

# 用 GCN 学节点 embedding
node_emb = model(data.x, data.edge_index)

# 边存在性预测 (u, v): 拼接 u 和 v embedding, 过 MLP
edge_pred = torch.sigmoid(MLP(torch.cat([node_emb[u], node_emb[v]], dim=1)))

2.2 三种评分函数

DistMult (最简单):

f(h, r, t) = <h, r, t> = sum_i h_i * r_i * t_i

TransE (平移不变, 几何直观):

f(h, r, t) = -||h + r - t||_2

RotatE (复数空间旋转, 强大):

t = h ∘ r   (复数乘法 = 旋转)
f = -||t - t'||

2.3 负采样

只训练正样本会过拟合, 要负采样生成负例:

  • 随机负采样: 随机替换头或尾
  • 难负采样 (Hard Negative): 用相似但错的样本 (e.g. 同一类型但不同实体)
  • 对比学习: InfoNCE 损失
# 随机负采样
def negative_sample(edge_index, num_nodes):
    neg_u = torch.randint(0, num_nodes, (edge_index.size(1),))
    neg_v = torch.randint(0, num_nodes, (edge_index.size(1),))
    return torch.stack([neg_u, neg_v])

2.4 实战:好友推荐

import torch
from torch_geometric.nn import GCNConv

class FriendRecommender(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)
    def predict(self, z, edge_label_index):
        # 用余弦相似度
        u = z[edge_label_index[0]]
        v = z[edge_label_index[1]]
        return (u * v).sum(dim=-1)

3. 图级任务 (Graph-level)

3.1 图分类 (Graph Classification)

整图一个预测, 每个图独立:

  • 分子属性预测 (毒理 / 溶解度)
  • 蛋白质功能预测
  • 社交网络异常检测

关键: 需要Readout 把节点 embedding 聚合成图 embedding:

from torch_geometric.nn import global_mean_pool, global_max_pool

# 训 GCN 拿节点 embedding
node_emb = gcn(data.x, data.edge_index)

# Readout: 整图 batch 所有节点 pool 到一个向量
graph_emb = global_mean_pool(node_emb, data.batch)  # (B, d)

# 分类
logits = classifier(graph_emb)

Readout 方式:

  • global_mean_pool: 所有节点取平均
  • global_max_pool: 所有节点取最大
  • global_add_pool: 所有节点求和
  • Set2Set: 用 LSTM 聚合
  • SortPool: 排序后取前 K 个

3.2 图回归 (Graph Regression)

整图预测连续值, 如分子能量:

# 同样 readout + 回归
graph_emb = global_mean_pool(node_emb, data.batch)
energy = regressor(graph_emb)  # 标量

3.3 图生成 (Graph Generation)

生成新图, 分子设计 / 药物发现:

  • GraphVAE: VAE 变体, 生成节点 / 边
  • GraphGAN: GAN 思路
  • Diffusion: 图扩散模型, 当前 SOTA
  • MolGPT: Transformer 自回归生成 SMILES
# 用 RDKit + 预训练 GNN 生成新分子
from torch_geometric.nn import GINConv
from rdkit import Chem

generator = MolecularGenerator(node_dim=..., edge_dim=...)
new_smiles = generator.sample(n=1000)
valid = [s for s in new_smiles if Chem.MolFromSmiles(s) is not None]

4. 三大任务对比

任务输入输出例子
节点分类1 个节点 (在全图)标签引文网络
链接预测2 个节点是否有边好友推荐
图分类整图标签分子属性

5. 工具与库

PyTorch Geometric (PyG)

学术最常用:

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool

dataset = TUDataset(root='data/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

class GraphClassifier(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GINConv(MLP(in_dim, hidden_dim))
        self.conv2 = GINConv(MLP(hidden_dim, hidden_dim))
        self.classifier = Linear(hidden_dim, out_dim)
    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_add_pool(x, batch)
        return self.classifier(x)

DGL

工业界更常用, 多 backend (PyTorch / MXNet / TF):

import dgl
import dgl.nn as dglnn

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
conv = dglnn.GraphConv(in_dim, out_dim)
h = conv(g, h)

6. 工业案例

6.1 推荐系统 (Pinterest / Uber Eats)

  • 把用户-物品交互当二部图
  • PinSage (Pinterest 2018): GraphSAGE 变体, 亿级物品图
  • LightGCN (He et al. 2020): 简化 GCN, 去掉特征变换, 适合推荐

6.2 药物发现 (AlphaFold / ESM)

  • 分子图预测毒理 / 溶解度
  • 蛋白质图 (残基 + 接触) 预测功能
  • AlphaFold 3: 用 GNN 预测蛋白质-DNA-小分子复合物结构

6.3 交通预测 (Google Maps / DiDi)

  • 路网 = 图, 路口 = 节点, 路 = 边
  • 预测每条路 5-30 分钟后的车速
  • STGCN / DCRNN: 时空 GNN

6.4 金融风控 (蚂蚁 / PayPal)

  • 用户-设备-账户构成异质图
  • GNN 检测欺诈团伙 / 套现路径

7. 评估指标

任务指标
节点分类Accuracy / Macro-F1
链接预测AUC / Hits@K / MRR
图分类Accuracy / ROC-AUC
图回归MAE / RMSE

8. 数据集速查

  • Cora / CiteSeer / Pubmed: 节点分类 baseline
  • OGB (Open Graph Benchmark): 大规模标准 benchmark
  • MUTAG / PROTEINS: 图分类小数据
  • ZINC: 分子图, 15K-250K 分子
  • OAG: 学术图谱, 异质图

总结

GNN 应用 = 节点 / 边 / 图 三大任务, 每个都有标准 baseline 和数据集。

最后一章Graph Transformer, 看 GNN 与 Transformer 结合的新方向。

章末小测验

检验你对《GNN 应用:节点 / 边 / 图三大任务》的掌握程度。

1

链接预测的核心目标是?

2

图分类任务中 Readout 操作用来?

3

TransE 的核心思想是?

讨论区(0)

加载评论中...