译:《图神经网络》
博客地址:《A Gentle Introduction to Graph Neural Networks》
前言
图随处可见,对应的应用包括但不限于抗菌发现、物理模拟、虚假新闻检测、交通预测、推荐系统等。
这篇文章探索和解释了现代图神经网络,并且将工作划分为四个部分:
- 什么数据可以表示成一个图
- 图数据与别的数据的区别,为什么要用GNN而不用别的
- 构建一个GNN
- 提供一个GNN平台
什么是图
图
图通常可以表示为实体(节点)集合之间的关系(边),如图:
图中节点、边或全局的属性可以用嵌入向量表示,如图:
图中黄色、蓝色、粉红色的高度表示对应属性的值
根据节点间关系又可分为有向图和无向图
图片的图形式
每个节点表示一个像素,不同节点通过边连接,每个节点存储表示RGB的三维向量信息
文本的图形式
通过索引将每个字符、单词或token联系,然后使用索引序列表示文本
一些其它的图
图表示分子:
图表示社交网络:
什么类型的问题有图结构数据?
- 图层面任务
- 节点层面任务
- 边层面任务
图层面任务
给定一个图,对图进行分类,如判断图是否有两个环:
节点层面任务
给定一个图,对节点进行分类,如判断某个节点的阵营:
边层面任务
给定一个图,对边的属性进行预测:
神经网络用于图面对的挑战
最核心的问题是如何表示图,才可以与神经网络相适应。
图上面有四种信息:节点属性、边属性、全局信息、连接性
前三个信息相对简单,可以直接使用向量进行表示,而连接性表示起来就会相对困难。可以用前面展示的邻接矩阵的形式,但存在问题是:
- 矩阵可能会非常大,因为得到的矩阵是\(N\times N\)的方阵,当\(node\)较多时,矩阵就会非常大,可以用稀疏矩阵存储,但计算起来又会比较困难
- 邻接矩阵交换任何行或列的顺序不会有影响,这意味着等价的各种邻接矩阵输入到神经网络中后,都应该都得一样的结果
因此,GNN中采用以下表示方法:
将节点和边的属性以及全局信息分别用一个标量或向量表示, 同时维护一个邻接列表,长度和边数一样,第\(i\)个向量表示的是第\(i\)条边连接的哪两个节点。
图神经网络
GNN是关于图的所有属性(节点、边、全局上下文)的可以优化的变换,它保持图的对称性(排列不变性)。这里使用“信息传递网络”来构建GNN。GNN的输入输出都是一个图,它会对图的顶点、边、全局属性等信息进行变换,而不会去改变图的连接性。
最简单的GNN
最简单的GNN如图,是对于顶点向量、边向量和全局向量分别构造一个MLP,这三个MLP组成一个GNN的层,这个层的输入输出都是一个图
但这样的话,输出如何得到需要预测的值呢?
对每个节点进行预测
已经有节点向量
直接将GNN后每个点的向量表示作为输入,使用一个全连接层和一个softmax得到输出,但这里,无论有多少个节点,这里只有一个全连接层,即所有节点共享一个全连接层参数。
没有节点向量
使用pooling,如图,可以使用和对应节点连接的那些边的向量和全局向量全部相加(如果维度不同则需要进行投影等变换操作)
整个过程如下图:(\(\rho_{E_n\rightarrow V_n}\)表示边到节点的polling操作)
对每个边进行预测
同节点预测一样,而对于只有节点而没有边,但要对边进行预测时,也可以同节点一样节点转换成边
对全局属性进行预测
同上
综合
综合上述几种情况,构成端到端结构如下:
首先,将图数据输入到GNN模型,然后变换得到保留原连接信息的图,最后根据需要对其中要做预测的属性添加合适的全连接层等,如果缺失信息则加入合适的pooling层,最后得到预测结果。
问题
虽然这种最简单的GNN很简单,但是会有很大的局限性,主要问题在于GNN blocks中,并没有对其使用图的结构信息,也就是说对于每个属性进行变换时,就是对应的属性进行MLP而没有对应的结构信息(如某个节点和哪些边相连等),所以这里的GNN blocks并没有将整个图的信息更新到属性中,导致最后结果可能并不是非常好。
改进GNN
使用信息传递技术进行改善。
如图,当对某个节点进行更新时,将其与其邻居的向量通过pooling的形式汇聚到一起再输入到MLP,得到这个节点向量的更新。
对应到GNN表示上,如下图所示:(其中的1表示的是1近邻,就是距离为1的点进行pooling)
提前汇聚
在前面最简单的GNN中,对于缺失的属性,会在最后从别的属性中汇聚过来弥补这个属性,而实际上不需要等到最后才这样,而可以在比较早的时候进行。
边的信息传递给节点再传递给边
如图,先将节点的信息给边,再将更新后的边的信息汇聚到节点再做更新
如下图,更新方式有两种,没有哪种更好,但可以同时进行:
全局信息
之前所说都是通过邻居间pooling,而如果图比较大而且连接没那么紧密的时候,导致信息从一个点传递到另一个很远的点需要走很长的步才行。
解决方案是加入一个master node或context vector,这个“节点”是一个虚拟的点,这个点可以和所有的点相连,也和所有边相连(抽象的形式),也即\(U\),也就是说,如图,当将节点信息汇聚到边时,也会将\(U\)汇聚过去,最后更新\(U\)时也会将所有节点信息和边信息汇聚
综合
如图,综上上述情况,对于一个节点,就可以使其与相邻边、相邻节点和全局信息进行汇聚,最后将汇聚结果进行MLP处理,这也有点类似与注意力机制。