ChatGPT老师教你学图神经网络
图神经网络(Graph Neural Network, GNN)是一类专门处理图结构数据的神经网络。这里的“图”不是曲线图,而是由**节点(node)和边(edge)**组成的数据结构,例如社交网络、分子结构、知识图谱、交通路网、论文引用网络。GNN 的核心思想是:每个节点反复从邻居收集信息,再更新自己的表示。这一过程通常叫 message passing(消息传递) 或 neighbor aggregation(邻居聚合)。(Distill)
你可以把它理解成这样: 普通神经网络适合表格,CNN 适合网格状图像,RNN/Transformer 适合序列;而 GNN 适合“谁和谁相连”本身很重要的数据。Distill 对 GNN 的介绍也强调,图数据与图像、文本不同,关键难点在于图没有固定顺序、邻居数量不固定、结构本身携带信息,因此模型需要具备对节点排列变化的鲁棒性,并通过邻域传播来建模依赖关系。(Distill)
# 1. 什么是图神经网络
更形式化地说,GNN 的目标是学习图中元素的表示(representation / embedding)。这些表示可以是:
- 节点表示:每个节点一个向量
- 边表示:每条边一个向量
- 整图表示:整个图一个向量
然后再基于这些表示做下游任务,比如节点分类、链路预测、图分类、回归等。Stanford CS224W 课程材料也把 GNN 的输出概括为一组节点嵌入,再接预测头与损失函数完成训练。(SNAP)
# 2. GNN 的运行流程是怎样的
现代最常见的 GNN 可以概括成一个 “输入图 → 多层消息传递 → 读出/预测” 的流程。
# 第一步:准备输入图
输入通常包括:
- 节点集合 (V)
- 边集合 (E)
- 节点特征 (X):如用户属性、原子类型、论文词向量
- 边特征(可选):如边类型、距离、权重、时间戳
- 图结构:通常用邻接矩阵或边列表表示
在 PyTorch Geometric 里,这类图连接关系常用 edge_index 表示,本质上就是稀疏图的边索引。(PyG Documentation)
# 第二步:初始化表示
每个节点先有一个初始向量: [ h_v^{(0)} = x_v ] 也就是把节点特征作为第 0 层隐藏状态。若原始节点没有特征,也常通过 one-hot、可学习 embedding、结构特征等方式初始化。这个初始化思路是 GNN/MPNN 框架的标准起点。(arXiv)
# 第三步:消息传递
对每一层 (k),每个节点从邻居接收消息。一个通用形式是:
[ m_v^{(k)} = \operatorname{AGGREGATE}^{(k)} \left( { \phi^{(k)}(h_v^{(k-1)}, h_u^{(k-1)}, e_{uv}) : u \in \mathcal{N}(v) } \right) ]
这里意思是:
对节点 (v),对每个邻居 (u) 先算一条消息,再把所有消息聚合起来。PyG 的 MessagePassing 文档把这个过程拆成 message()、aggregate()、update() 三部分,属于很标准的实现抽象。(PyG Documentation)
# 第四步:更新节点表示
聚合完邻居信息后,更新当前节点表示:
[ h_v^{(k)} = \gamma^{(k)}(h_v^{(k-1)}, m_v^{(k)}) ]
也就是把“旧的自己”和“邻居带来的信息”融合,得到新的节点表示。重复 (K) 层后,节点就能感受到 (K)-hop 邻域的信息。GNN survey 和 Distill 都把这种多层传播视为现代 GNN 的核心机制。(arXiv)
# 第五步:读出与任务头
得到最后的节点表示后,根据任务类型决定输出形式:
- 节点任务:直接对每个节点的 (h_v^{(K)}) 做分类/回归
- 边任务:把两个端点的表示组合起来预测边
- 整图任务:对所有节点做 pooling/readout,得到整图向量,再分类/回归
整图读出通常会用 sum / mean / max pooling,Distill 和课程材料都把 pooling/readout 视为图级任务的关键步骤。(Distill)
# 第六步:损失与训练
最后将预测结果与标签计算损失,用反向传播训练参数。课程材料中常写成:
Input Graph → GNN → Node Embeddings → Prediction Head → Loss → Evaluation。(SNAP)
# 3. GNN 包含哪些核心组件
一个典型 GNN 一般包含这些组件。
# 3.1 图结构表示
用于说明谁和谁连接。常见形式:
- 邻接矩阵 (A)
- 边列表
edge_index - 稀疏矩阵
这是消息传递的基础。没有图结构,模型就不知道该从哪些邻居收集信息。(PyG Documentation)
# 3.2 节点特征 / 边特征
每个节点、边附带的属性向量。 例如分子图里节点是原子,边是化学键;社交图里节点是用户,边是好友关系。GNN 通过结构和属性共同学习表示。(Distill)
# 3.3 消息函数(Message Function)
决定邻居 (u) 如何向中心节点 (v) 发送消息。
它可以非常简单,也可以带权重、注意力、边特征。PyG 文档中对应 message()。(PyG Documentation)
# 3.4 聚合函数(Aggregation Function)
把来自不同邻居的消息合并起来。常见有:
- sum
- mean
- max
聚合函数通常要满足对邻居顺序不敏感,因为图的邻居没有天然顺序。PyG 明确列出了 add、mean、max 等典型聚合方式。(PyG Documentation)
# 3.5 更新函数(Update Function)
把聚合结果和节点旧状态结合,生成新状态。
可理解为每层的“节点状态转移器”。PyG 对应 update()。(PyG Documentation)
# 3.6 多层堆叠
堆叠多层后,节点能看到更远的邻域。 1 层看 1-hop,2 层看 2-hop,依此类推。但层数过多可能导致 over-smoothing 等问题,Stanford 课程中也讨论了消息传递的表达能力与局限。(SNAP)
# 3.7 读出层(Readout / Pooling)
当任务是整图预测时,需要把所有节点向量汇总成一个图向量。常见方法为全局 sum/mean/max pooling。(Distill)
# 3.8 预测头(Task Head)
最后的 MLP / 线性层 / 相似度函数等,用于输出具体任务结果,例如类别概率、实数值、边存在概率。(SNAP)
# 4. 输入和输出是什么
这取决于任务粒度。
# 输入
最标准的输入是:
[ G = (V, E, X, E_f) ]
其中:
- (V):节点
- (E):边
- (X):节点特征矩阵
- (E_f):边特征(可选)
工程实现中通常至少有:
x: 节点特征矩阵,形状约为[num_nodes, num_node_features]edge_index: 边索引,形状约为[2, num_edges]edge_attr: 边特征,可选batch: 多图训练时标记每个节点属于哪张图,可选
这些接口形式可在 PyG 文档中看到。(PyG Documentation)
# 输出
# 节点级任务
输出每个节点一个预测:
- 节点类别
- 节点回归值
- 节点 embedding
例子:论文分类、社交用户标签预测。(Distill)
# 边级任务
输出每条边或一对节点之间的预测:
- 是否有边
- 边类型
- 边权重
例子:好友推荐、知识图谱补全。(arXiv)
# 图级任务
输出整张图一个预测:
- 图类别
- 图属性回归
- 整图 embedding
例子:分子毒性预测、蛋白质功能分类。(Distill)
# 5. 一个最简运行示意
假设有一个节点 (v),它有三个邻居 (u_1,u_2,u_3)。
# 输入
- 自己的特征:(x_v)
- 邻居特征:(x_{u_1},x_{u_2},x_{u_3})
- 图里谁连着谁
# 第 1 层
- 从三个邻居各算一条消息
- 把三条消息做 sum/mean/max
- 与自己的旧表示融合,得到新表示 (h_v^{(1)})
# 第 2 层
重复上面过程,但这次邻居本身已经融合过它们各自的邻居信息了,所以 (v) 间接获得了更大范围的上下文。
# 最后
- 若是节点分类:直接拿 (h_v^{(K)}) 分类
- 若是图分类:把所有节点的 (h^{(K)}) pooling 成一个图向量再分类
这就是 GNN 最核心的执行逻辑。(PyG Documentation)
# 6. 可以把 GNN 记成一句话
一个非常实用的记忆方式是:
GNN = 在图上反复做“邻居收集信息 + 自身更新”,最后把得到的表示送去做预测。 (arXiv)
# 7. 和普通神经网络相比的关键区别
GNN 的特别之处主要有三点:
- 输入不是固定长度向量或规则网格,而是任意图结构。(Distill)
- 参数在所有节点/边上共享,像 CNN 卷积核在不同位置共享一样。(Distill)
- 核心计算不是按空间邻域滑窗,而是按图邻接关系传播。(arXiv)
# 8. 一个简明框架图
可以把整个 GNN 流程画成:
输入图 G = (节点, 边, 节点特征, 边特征)
↓
节点表示初始化
↓
第1层消息传递:邻居发送消息 → 聚合 → 更新
↓
第2层消息传递:邻居发送消息 → 聚合 → 更新
↓
...
↓
得到最终节点表示
↓
[节点任务] 直接预测
或
[图任务] pooling/readout 后预测
2
3
4
5
6
7
8
9
10
11
12
13
14
15
这个框架和 Stanford CS224W 课件中的 “Input Graph → GNN → Node Embeddings → Prediction Head” 是一致的。(SNAP)
# 9. 最后给你一个最稳的总结
- 什么是 GNN:处理图结构数据的神经网络,通过图上的消息传递学习表示。(arXiv)
- 运行流程:输入图 → 初始化节点表示 → 多层消息传递/聚合/更新 → 读出 → 预测。(PyG Documentation)
- 核心组件:图结构、节点/边特征、消息函数、聚合函数、更新函数、读出层、预测头。(PyG Documentation)
- 输入:节点、边、节点特征、边特征、图连接关系。(PyG Documentation)
- 输出:节点级、边级或图级的 embedding / 分类 / 回归结果。(SNAP)