GNNExplainer
GNNExplainer论文名称:GNNExplainer: Generating Explanations for Graph Neural Networks论文地址:https://arxiv.org/abs/1903.03894GNN使用节点的特征和图的结构作为信息沿着边传递。这种整合使得模型的可解释性更加困难。我们建议的模型GNNEXPLAINER,是一种与模型无关的,可以为任何的GNN模
GNNExplainer
论文名称:GNNExplainer: Generating Explanations for Graph Neural Networks
论文地址:https://arxiv.org/abs/1903.03894
GNN使用节点的特征和图的结构作为信息沿着边传递。这种整合使得模型的可解释性更加困难。我们建议的模型GNNEXPLAINER,是一种与模型无关的,可以为任何的GNN模型提供解释。GNNEXPLAINER能够识别子图的结构和节点的特征,然后,对样本的实例作出解释。GNNEXPLAINER作为优化器,最大化GNN预测任务和子图结构之间的互信息,能够识别重要的图结构和特征。
GNNEXPLAINER将 trained GNN and its prediction(s)作为输入,返回输入图的子图和对预测结果产生影响的特征(Figure 1)。该方法是与模型无关的,可以解释基于GNN的机器学习任务,包括:节点分类、链路预测、图分类,它可以处理单条和多条样本。当处理单条样本时,GNNEXPLAINER针对该样本进行解释。(a node
label, a new link, a graph-level label)。当处理多条样本时,针对该样本集合进行解释。
GNNEXPLAINER用GNN训练时整个图的子图进行解释,该子图最大化与预测值之间互信息。
1. Formulating explanations for graph neural networks
设图为 G G G, 边为 E E E, 节点为 V V V, 节点的特征为 d d d 维, X = { x 1 , … , x n } , x i ∈ R d \mathcal{X}=\left\{x_{1}, \ldots, x_{n}\right\}, x_{i} \in \mathbb{R}^{d} X={x1,…,xn},xi∈Rd,其中, n n n是节点的数量。 f f f是节点label的映射函数。 f : V ↦ { 1 , … , C } f: V \mapsto\{1, \ldots, C\} f:V↦{1,…,C}, 将 V V V中的每个节点映射为 C C C类, GNN模型 Φ \Phi Φ在所有训练节点上进行优化,对新的节点进行预测。
1.1 Background on graph neural networks
在 l l l层, GNN模型包括关键三步。(1)第一步,计算节点对 ( v i , v j ) (v_i,v_j) (vi,vj)之间的message, h i l − 1 \mathbf{h}_i^{l-1} hil−1和 h j l − 1 \mathbf{h}_j^{l-1} hjl−1分别是前一层节点 i i i和节点 j j j的表示, r i j r_{ij} rij是两个节点之间的关系: m i j l = MSG ( h i l − 1 , h j l − 1 , r i j ) m_{i j}^{l}=\operatorname{MSG}\left(\mathbf{h}_{i}^{l-1}, \mathbf{h}_{j}^{l-1}, r_{i j}\right) mijl=MSG(hil−1,hjl−1,rij)(2),第二步,对于每个节点 v i v_i vi, GNN汇总aggregates它的邻居 N v i \mathcal{N}_{v_i} Nvi的信息, aggregated message M i M_i Mi的计算方式: M i l = AGG ( { m i j l ∣ v j ∈ N v i } ) M_{i}^{l}=\operatorname{AGG}\left(\left\{m_{i j}^{l} \mid v_{j} \in \mathcal{N}_{v_{i}}\right\}\right) Mil=AGG({mijl∣vj∈Nvi}). 其中 N v i \mathcal{N}_{v_i} Nvi是节点 v i v_i vi的邻居的节点,它的定义不同会产生不同的GNN变种。(3)GNN 使用聚合函数 M i l M_i^l Mil聚合节点 v i v_i vi的representation h i l − 1 \mathbf{h}_i^{l-1} hil−1, 然后进行非线性转换获得节点 v i v_i vi的节点在 l l l层表示 h i l \mathbf{h}_i^l hil: h i l = UPDATE ( M i l , h i l − 1 ) \mathbf{h}_{i}^{l}=\operatorname{UPDATE}\left(M_{i}^{l}, \mathbf{h}_{i}^{l-1}\right) hil=UPDATE(Mil,hil−1), 然后经过 L L L层获得最后的输出: z i = h i L \mathbf{z}_{i}=\mathbf{h}_{i}^{L} zi=hiL。
1.2 GNNEXPLAINER: Problem formulation
我们处理问题的关键是节点 v v v的计算,将节点邻居的信息进行汇总,产生节点 v v v的预测 y ^ \hat{y} y^。节点 v v v的最终输出为 z \mathbf{z} z. 图 G c ( v ) G_c(v) Gc(v)的计算与临接矩阵 A c ( v ) ∈ { 0 , 1 } n × n A_{c}(v) \in\{0,1\}^{n \times n} Ac(v)∈{0,1}n×n和节点特征 X c ( v ) = { x j ∣ v j ∈ G c ( v ) } X_{c}(v)=\left\{x_{j} \mid v_{j} \in G_{c}(v)\right\} Xc(v)={xj∣vj∈Gc(v)}有关。GNN模型 Φ \Phi Φ学习 Y Y Y的概率分布 P Φ ( Y ∣ G c , X c ) P_{\Phi}\left(Y \mid G_{c}, X_{c}\right) PΦ(Y∣Gc,Xc), 其中 Y Y Y代表标签 1 , ⋯ , C {1,\cdots,C} 1,⋯,C随机变量,即每个节点属于 C C C类中每个类别的概率。
GNN的预测 y ^ = Φ ( G c ( v ) , X c ( v ) ) \hat{y}=\Phi\left(G_{c}(v), X_{c}(v)\right) y^=Φ(Gc(v),Xc(v)),模型 Φ \Phi Φ主要是由图的结构信息 G c ( v ) G_c(v) Gc(v)和节点的特征 X c ( v ) X_c(v) Xc(v)决定的。一般地, GNNEXPLAINER将预测值 y ^ \hat{y} y^ 解释为 ( G S , X S F ) \left(G_{S}, X_{S}^{F}\right) (GS,XSF), 其中 G S G_S GS是预测图的子图, X S X_S XS是 G S G_S GS的节点特征, X S F X_S^F XSF是 G S G_S GS中节点的子集(通过 F F F进行mask, X S F = { x j F ∣ v j ∈ G S } X_{S}^{F}=\{x_{j}^{F} \mid v_{j} \in G_S\} XSF={xjF∣vj∈GS} )。
2 GNNEXPLAINER
接下来,我们介绍一下 GNNEXPLAINER 如何在单条(2.1, 2.2)和多条(2.3)上的预测进行模型解释。最后介绍GNNEXPLAINER在机器学习任务上的应用(2.4),如链路预测和图分类。
2.1 Single-instance explanations
给定一个节点
v
v
v, 我们的目标是识别子图
G
S
⊆
G
c
G_{S} \subseteq G_{c}
GS⊆Gc和相关特征
X
S
=
{
x
j
∣
v
j
∈
G
S
}
X_S=\left\{x_{j} \mid v_{j} \in G_{S}\right\}
XS={xj∣vj∈GS}, 这些对于GNN预测
y
^
\hat{y}
y^ 是非常重要的。现在,我们假设
X
S
X_S
XS是子集节点的特征,
d
d
d维。在2.2将要讨论哪一维特征能够对模型进行解释。使用互信息
M
I
MI
MI衡量重要性, GNNEXPLAINER优化框架如下:
max
G
S
M
I
(
Y
,
(
G
S
,
X
S
)
)
=
H
(
Y
)
−
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
(1)
\max _{G_{S}} M I\left(Y,\left(G_{S}, X_{S}\right)\right)=H(Y)-H\left(Y \mid G=G_{S}, X=X_{S}\right)\tag{1}
GSmaxMI(Y,(GS,XS))=H(Y)−H(Y∣G=GS,X=XS)(1)
对于节点
v
v
v,
M
I
MI
MI是衡量是当计算图被限制在子图
G
S
G_S
GS,节点特征被限制在
X
S
X_S
XS时,预测概率
y
^
=
Φ
(
G
c
,
X
c
)
\hat{y}=\Phi\left(G_{c}, X_{c}\right)
y^=Φ(Gc,Xc)的变化。
举例来说, v j ∈ G c ( v i ) , v j ≠ v i v_{j} \in G_{c}\left(v_{i}\right), v_{j} \neq v_{i} vj∈Gc(vi),vj=vi,如果移除 v j v_j vj, y ^ i \hat{y}_i y^i的概率显著下降, 则节点 v j v_j vj就是很好反事实解释。类似地, ( v j , v k ) ∈ G c ( v i ) , v j , v k ≠ v i \left(v_{j}, v_{k}\right) \in G_{c}\left(v_{i}\right), v_{j}, v_{k} \neq v_{i} (vj,vk)∈Gc(vi),vj,vk=vi,如果移除 v j v_j vj和 v k v_k vk之间的边, y ^ i \hat{y}_i y^i 的预测概率值显著下降,则 v j v_j vj和 v k v_k vk之间的边是很好的反事实解释。
在Eq.(1)中, 交叉项
H
(
Y
)
H(Y)
H(Y)是常数,因为模型
Φ
\Phi
Φ已经训练好,因此,最大化
Y
Y
Y和
(
G
S
,
X
S
)
(G_S,X_S)
(GS,XS)之间的互信息等于最小化条件熵
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
H\left(Y \mid G=G_{S}, X=X_{S}\right)
H(Y∣G=GS,X=XS),如下:
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
=
−
E
Y
∣
G
S
,
X
S
[
log
P
Φ
(
Y
∣
G
=
G
S
,
X
=
X
S
)
]
(2)
H\left(Y \mid G=G_{S}, X=X_{S}\right)=-\mathbb{E}_{Y \mid G_{S}, X_{S}}\left[\log P_{\Phi}\left(Y \mid G=G_{S}, X=X_{S}\right)\right]\tag{2}
H(Y∣G=GS,X=XS)=−EY∣GS,XS[logPΦ(Y∣G=GS,X=XS)](2)
以子图
G
S
G_S
GS对
y
^
\hat{y}
y^ 进行解释, 实际上最小化
Φ
\Phi
Φ的不确定性。实际上,最大化概率
y
^
\hat{y}
y^。为了给出简介的解释,我们给
G
S
G_S
GS增加限制:
∣
G
S
∣
≤
K
M
\left|G_{S}\right| \leq K_{M}
∣GS∣≤KM, 其中
G
S
G_S
GS最多有
K
M
K_M
KM个节点。这意味着, GNNEXPLAINER通过
K
M
K_M
KM边消除
G
C
G_C
GC的噪声,给出预测的最大互信息。
GNNEXPLAINER’s optimization framework. 对于
G
c
G_c
Gc来说,用于解释
y
^
\hat{y}
y^ 的子图
G
S
G_S
GS非常多,直接处理是非常困难的。我们考虑部分邻接矩阵的方式:
A
S
[
j
,
k
]
≤
A
c
[
j
,
k
]
A_{S}[j, k] \leq A_{c}[j, k]
AS[j,k]≤Ac[j,k],其中,
A
S
∈
[
0
,
1
]
n
×
n
A_{S} \in[0,1]^{n \times n}
AS∈[0,1]n×n, 对于所有
j
,
k
j,k
j,k增加以上限制。这个近似可以理解为子图是
G
c
G_c
Gc的近似。我们将
G
S
∼
G
G_{S} \sim \mathcal{G}
GS∼G看做图的随机变量,目标函数Eq.(2)可以变换为:
min
G
E
G
S
∼
G
H
(
Y
∣
G
=
G
S
,
X
=
X
S
)
(3)
\min _{\mathcal{G}} \mathbb{E}_{G_{S} \sim \mathcal{G}} H\left(Y \mid G=G_{S}, X=X_{S}\right)\tag{3}
GminEGS∼GH(Y∣G=GS,X=XS)(3)
由于凸的假设,使用Jensen不等式给出上限:
min
G
H
(
Y
∣
G
=
E
G
[
G
S
]
,
X
=
X
S
)
(4)
\min _{\mathcal{G}} H\left(Y \mid G=\mathbb{E}_{\mathcal{G}}\left[G_{S}\right], X=X_{S}\right)\tag{4}
GminH(Y∣G=EG[GS],X=XS)(4)
在实际中,由于神经网络的复杂性,凸的假设是不成立的,但是,最小化这个目标函数和正则项通常会带来比较的解释。
为了估计 E G \mathbb{E}_{\mathcal{G}} EG, 我们将其分解为multivariate Bernoulli distribution: P G ( G S ) = ∏ ( j , k ) ∈ G c A S [ j , k ] P_{\mathcal{G}}\left(G_{S}\right)=\prod_{(j, k) \in G_{c}} A_{S}[j, k] PG(GS)=∏(j,k)∈GcAS[j,k],其中 A S A_S AS的 ( j , k ) -th (j,k)\text{-th} (j,k)-th条目代表边 ( v j , v k ) (v_j,v_k) (vj,vk)之间是否有边存在。我们经验发现,使用正则项可以使得分解值收敛局部最小,即使GNN是非凸的。将Equation 4中 E G [ G S ] \mathbb{E}_G[G_S] EG[GS]替换为masking 邻接矩阵 A c ⊙ σ ( M ) A_{c} \odot \sigma(M) Ac⊙σ(M)进行优化 , M ∈ R n × n M \in \mathbb{R}^{n \times n} M∈Rn×n指的是Mask, ⊙ \odot ⊙指element-wise乘积, σ \sigma σ 指的是将mask映射为 [ 0 , 1 ] n × n [0,1]^{n \times n} [0,1]n×n.
在一些应用中,用户更关注如何将训练的模型用于预测想要的label。我们需要修改Equation4:
min
M
−
∑
c
=
1
C
1
[
y
=
c
]
log
P
Φ
(
Y
=
y
∣
G
=
A
c
⊙
σ
(
M
)
,
X
=
X
c
)
(5)
\min _{M}-\sum_{c=1}^{C} \mathbb{1}[y=c] \log P_{\Phi}\left(Y=y \mid G=A_{c} \odot \sigma(M), X=X_{c}\right)\tag{5}
Mmin−c=1∑C1[y=c]logPΦ(Y=y∣G=Ac⊙σ(M),X=Xc)(5)
该公式 使用Mask机制,将
σ
(
M
)
\sigma(M)
σ(M)和
A
c
A_c
Ac进行乘积,移除
M
M
M中小的值,以达到用子图
G
S
G_S
GS解释GNN对节点
v
v
v的预测值
y
^
\hat{y}
y^进行解释的作用。
2.2 Joint learning of graph structural and node feature information
为了识别节点特征对预测值
y
^
\hat{y}
y^的重要性, GNNEXPLAINER学习
G
S
G_S
GS节点特征
F
F
F选择器。与节点所有特征不同,
X
S
=
{
x
j
∣
v
j
∈
G
S
}
X_{S}=\left\{x_{j} \mid v_{j} \in G_{S}\right\}
XS={xj∣vj∈GS}, GNNEXPLAINER考虑
G
S
G_S
GS的子集特征
X
S
F
X_{S}^{F}
XSF, 特征的选择通过二值特征选择器
F
∈
{
0
,
1
}
d
F \in\{0,1\}^{d}
F∈{0,1}d(Figure 2B):
X
S
F
=
{
x
j
F
∣
v
j
∈
G
S
}
,
x
j
F
=
[
x
j
,
t
1
,
…
,
x
j
,
t
k
]
for
F
t
i
=
1
(6)
X_{S}^{F}=\left\{x_{j}^{F} \mid v_{j} \in G_{S}\right\}, \quad x_{j}^{F}=\left[x_{j, t_{1}}, \ldots, x_{j, t_{k}}\right] \text { for } F_{t_{i}}=1\tag{6}
XSF={xjF∣vj∈GS},xjF=[xj,t1,…,xj,tk] for Fti=1(6)
其中,
x
j
F
x_j^F
xjF是没有被
F
F
F mask out的节点特征。
(
G
S
,
X
S
)
(G_S,X_S)
(GS,XS)进行联合优化以最大化互信息:
max
G
S
,
F
M
I
(
Y
,
(
G
S
,
F
)
)
=
H
(
Y
)
−
H
(
Y
∣
G
=
G
S
,
X
=
X
S
F
)
(7)
\max _{G_{S}, F} M I\left(Y,\left(G_{S}, F\right)\right)=H(Y)-H\left(Y \mid G=G_{S}, X=X_{S}^{F}\right)\tag{7}
GS,FmaxMI(Y,(GS,F))=H(Y)−H(Y∣G=GS,X=XSF)(7)
该方程对Eq.(1)目标函数进行调整,同时考虑结构和节点特征两个方面,对预测
y
^
\hat{y}
y^进行解释。
Learning binary feature selector F F F. 我们设 X S = X S ⊙ F X_S=X_S\odot F XS=XS⊙F, 其中 F F F是需要学习的参数。如果某个特征不重要,GNN会使得它的权重为0. 实际上,若果这个特征不重要,移除这个特征预测值不会有太大的变化,如果这个特征重要,预测值会显著下降。但是这种方法会忽略一些特征很重要,但是取值接近0。为了解决这个问题,在训练的过程中,我们使用蒙特卡洛从节点 X S X_S XS的边缘经验分布抽样。然后,我们使用参数化技巧进行反向传播,学习feature mask F F F。特别地,随机变量 X X X计算如下: X = Z + ( X S − Z ) ⊙ F X=Z+\left(X_{S}-Z\right) \odot F X=Z+(XS−Z)⊙F s.t. ∑ j F j ≤ K F \sum_{j} F_{j} \leq K_{F} ∑jFj≤KF,其中 Z Z Z是从经验分布抽样的 d d d维随机变量, K F K_F KF是保留的最大特征的数量,是可学习的参数。
Integrating additional constraints into explanations. 为了强化可解释性,我们可以对Eq.(7)增加正则项。例如,为了使得structural and node feature masks to be discrete, 我们使用element-wise entropy,或者增加特定领域限制,如,拉格朗日正则项。我们也可以将mask的元素求和,作为正则项。
最后,需要注意的是对GNN进行解释必须是一个有效的计算图。因为解释 ( G S , X S ) \left(G_{S}, X_{S}\right) (GS,XS)必须允许GNN的message能够流向节点 v v v, 以此来预测 y ^ \hat{y} y^. 重要的是, GNNEXPLAINER 自动可以提供有效计算图,因为它会在整个图上优化structural mask。如果边是没有连接的,它不会被选择,不会影响最终GNN预测。
2.3 Multi-instance explanations through graph prototypes
我们的目标是分析子图如何对一类标签进行解释, GNNEXPLAINER能够基于 graph alignments and prototypes对多实例进行解释。
首先,我们先选择一个类别 c c c的参考样本样本点,例如,将其他节点embedding的均值赋值 c c c。我们利用 G S ( v c ) G_S(v_c) GS(vc)对 v c v_c vc进行解释,然后将解释赋值给这个类别 c c c的其他节点。如果在大图中,进行匹配是非常具有挑战性的。但是单条样本产生是一个小图,而且near-optimal pairwise graph matchings是非常高效的。
其次,我们将邻接矩阵进行汇总给a graph prototype A proto A_{\text{proto}} Aproto, 例如计算中位数. A proto A_{\text{proto}} Aproto用于识别graph patterns,它在同类别中是共享的。可以用于预测和模型解释。
2.4 GNNEXPLAINER model extensions
Any machine learning task on graphs. 除了能够解释节点分类,在不需要修改优化算法的情况下,GNNEXPLAINER可以用于链路预测和图分类。当对 ( v j , v k ) (v_j,v_k) (vj,vk)进行链路预测时,GNNEXPLAINER会学习 X S ( v j ) X_S(v_j) XS(vj)和 X S ( v k ) X_S(v_k) XS(vk)两个mask。当进行图分类时,会将我们想解释图的所有邻接矩阵进行union.
Any GNN model. 现在GNN主要基于 message passing构建各种结构, GNNEXPLAINER能够对它们进行解释。
Computational complexity. GNNEXPLAINER的优化取决于计算图 G c G_c Gc的大小, G c ( v ) G_c(v) Gc(v)的邻接矩阵 A c ( v ) A_c(v) Ac(v)等于mask M M M的大小,需要GNNEXPLAINER学习。但是,通常来说,计算图相对较小, 即使输入大图 ,GNNEXPLAINER也能对其进行有效的解释。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)