ImGCL:Revisiting Graph Contrastive Learning on Imbalanced Node Classification
论文地址:Revisiting Graph Contrastive Learning on Imbalanced Node Classification.pdf
Contribution
利用一个自适应采样策略的对比学习框架解决了数据集不平衡(长尾数据集)的性能受限问题,具体来说,根据算法目前的学习情况形成伪标签,逐步将数据集调整至平衡
Motivation
前人的工作大多忽略了无监督学习数据集的长尾问题,对比当时的SOTA—GBTbaseline在Amazon-Computers数据集 有通过更改尾部数据集学习频次从而解决数据集长尾问题的方案但由于需要标签,几乎不能在无监督学习上work
Method
思路:利用对原图不同类别和不同重要性的节点进行不同比例的down-sampling获得子图,作为新的iteration的输入,从而对齐最后的testset数据比例,即均等比例
Progressively Balanced Sampling (PBS)
Sampling Strategies
p
k
=
N
k
q
∑
i
=
1
K
N
i
q
\begin{aligned}p_k&=\frac{N_k^q}{\sum_{i=1}^KN_i^q}\end{aligned}
pk=∑i=1KNiqNkq表示了从
K
K
K个类中采样某一个节点的概率,
N
k
q
N^q_k
Nkq表示第
k
k
k类的节点个数,
q
∈
[
0
,
1
]
q\in[0,1]
q∈[0,1]根据不同策略调整 PBS 数据集设置:训练集long-tail,测试集balance 令
q
=
1
q=1
q=1,有
p
k
R
=
N
k
∑
i
=
1
K
N
i
p_{k}^{R}=\frac{N_{\boldsymbol{k}}}{\sum_{i=1}^{K}N_{\boldsymbol{i}}}
pkR=∑i=1KNiNk,为适应训练集和测试集过渡,引入
α
=
1
−
t
T
\alpha=1-\frac tT
α=1−Tt,有
p
k
P
B
=
α
∗
p
k
R
+
(
1
−
α
)
∗
p
k
M
=
α
∗
N
k
∑
i
=
1
K
N
i
+
(
1
−
α
)
∗
1
K
\begin{aligned}p_{k}^{\mathrm{PB}}& =\alpha*p_k^R+(1-\alpha)*p_k^M \\&=\alpha*\frac{N_k}{\sum_{i=1}^KN_i}+(1-\alpha)*\frac1K\end{aligned}
pkPB=α∗pkR+(1−α)∗pkM=α∗∑i=1KNiNk+(1−α)∗K1 Online Clustering Based PBS
为利用PBS方法适应性调整数据集类别分布,利用online cluster方法生成伪标签 聚类方法使用K-Means,K是一个超参数数量等于下游分类任务的类别数 学习一个形状为
D
×
K
D\times K
D×K的质心矩阵C,对所有节点的embedding表示计算使下列值最小,则认为该节点属于这个类别
K表示簇的数量,D是hidden dimension等同于节点embedding的长度,下式计算节点embedding和簇中心embedding的均方误差大小
min
C
∈
R
D
×
K
1
N
∑
n
=
1
N
min
y
^
n
∥
z
t
,
n
−
C
y
^
n
∥
2
2
such that
y
^
n
⊤
1
K
=
1
\begin{aligned}\min_{C\in\mathbb{R}^{D\times K}}\frac{1}{N}\sum_{n=1}^{N}\min_{\hat{y}_{n}}\|z_{t,n}-C\hat{y}_{n}\|_{2}^{2}\text{ such that}\quad\hat{y}_{n}^{\top}1_{K}=1\end{aligned}
C∈RD×KminN1n=1∑Ny^nmin∥zt,n−Cy^n∥22 such thaty^n⊤1K=1
z
t
,
n
z_{t,n}
zt,n表示第t个iteration时第n个节点的embedding,
z
t
,
n
∈
R
D
z_{t,n}\in\mathbb{R}^D
zt,n∈RD 以此获得独热向量
y
^
n
∈
R
+
K
\hat{y}_n\in\mathbb{R}_+^K
y^n∈R+K,表示节点属于第k个簇/类 Node Centrality Based PBS
计算节点中心性,根据节点重要性对不同类别节点进行down-sampling 利用PageRank方法进行节点重要性/中心性的计算,
σ
=
α
A
D
−
1
+
1
, where
σ
∈
R
N
\sigma=\alpha AD^{-1}+1\text{, where }\sigma\in\mathbb{R}^N
σ=αAD−1+1, where σ∈RN,
A
A
A是节点邻接矩阵,
D
D
D是节点的度矩阵,循环多次直到稳定获得节点重要性
即节点重要性的影响因素是节点本身的度,连接的节点的重要性,因此需要循环传播节点连接重要性
对某个类的某个节点的采样概率为
p
v
,
j
N
P
B
=
min
{
σ
v
−
σ
min
σ
max
−
σ
min
⋅
p
j
P
B
,
p
τ
}
p_{v,j}^{\mathrm{NPB}}=\min\left\{\frac{\sigma_{v}-\sigma_{\min}}{\sigma_{\max}-\sigma_{\min}}\cdot p_{j}^{\mathrm{PB}},p_{\tau}\right\}
pv,jNPB=min{σmax−σminσv−σmin⋅pjPB,pτ}
p
j
P
B
p_{j}^{\mathrm{PB}}
pjPB是PBS采样概率经正则化的概率小,即采样每类节点的概率,用于调整类别之间数量由不平衡线性过渡到平衡,
p
τ
p_{\tau}
pτ表示最低采样限度,防止部分边缘节点无法被采样
summary
Details
Dataset 利用了四个常用的数据集作为直推节点分类任务的数据,分别是Wiki-CS, Amazon-computers, Amazon-photo, and DBLP Train Set 8:1:1划分数据集,其中测试集验证集为平衡数据集,训练集不均等 Different Type Imbalance
Exp Imbalance 训练集不同类的采样比例遵循指数分布,参数越大越不平衡 Pareto Imbalance 训练集不同类的采样比例遵循Pareto分布,参数越小越不平衡
推荐链接
发表评论