GraphMix¶
GraphMix: Improved Training of GNNs for Semi-Supervised Learning
-
Motivation
-
目前图数据增强的领域很少人探讨,已有方法也都计算量太大,因此作者想提出一种有效的图数据增强方法
-
为了解决interpolation方法的图数据增强后不知道如何保证原本图结构不被破坏的问题,作者希望提出一种方法,能在interpolation图数据增强后仍能保留图数据结构
-
通过interpolation解决GNNs的过平滑问题
-
主要增强方法:对输入未标记样本的多个随机扰动的平均预测和标签锐化
-
建模过程:
-
作者使用 Manifold Mixup 来训练一个辅助的 FCN,基于插值的方法可以跳脱出当前节点的局部邻域来扩充信息缓解 over-smoothing
-
在 FCN 和 GNNs 之间共享网络参数,利用 FCN 中节点的更具区别性的表示,以及图的结构,按照通常的方式计算 GNNs 损失,以进一步细化节点表示
-
-
核心公式
-
FCN loss from labeled data
-
其中\(\mathcal{L}_{MM}\)是Manifold Mixup计算公式,g是从input到hidden,f是hidden到output,即通过类似交叉熵损失来衡量mix后的数据输入网络前后的差距
-
其中 \(\text{Mix}_{\lambda}(a, b) = \lambda * a + (1 - \lambda) * b\) 即interpolation function
-
FCN loss from unlabeled data,其中 \(\hat{Y}_u\) 是GNN预测的结果的分布
-
FCN total loss,其中 \(w(t)\) 是一个sigmoid上升函数,训练过程会从0逐步增大
-
GraphMix total loss
-
-
实验
-
半监督节点分类问题对各种模型的表现都提升,且使得类特定的隐藏状态更加集中:作者将GraphMix用到GCN,GAT和Graph U-Net上,对于不同数据集(包含大数据集),通过准确率,tsne分类可视化,还有Class-specific Soft-Rank验证想法
-
对于只有少量标签数据的数据集使用GraphMix表现更好:作者对比了每类标签只保留5和10个后的少标签数据集下各模型的准确率