GNNExplainer

论文名称:GNNExplainer: Generating Explanations for Graph Neural Networks

论文地址:https://arxiv.org/abs/1903.03894

GNN使用节点的特征和图的结构作为信息沿着边传递。这种整合使得模型的可解释性更加困难。我们建议的模型GNNEXPLAINER,是一种与模型无关的,可以为任何的GNN模型提供解释。GNNEXPLAINER能够识别子图的结构和节点的特征,然后,对样本的实例作出解释。GNNEXPLAINER作为优化器,最大化GNN预测任务和子图结构之间的互信息,能够识别重要的图结构和特征。

GNNEXPLAINER将 trained GNN and its prediction(s)作为输入,返回输入图的子图和对预测结果产生影响的特征(Figure 1)。该方法是与模型无关的,可以解释基于GNN的机器学习任务,包括:节点分类、链路预测、图分类,它可以处理单条和多条样本。当处理单条样本时,GNNEXPLAINER针对该样本进行解释。(a node
label, a new link, a graph-level label)。当处理多条样本时,针对该样本集合进行解释。

GNNEXPLAINER用GNN训练时整个图的子图进行解释,该子图最大化与预测值之间互信息。

在这里插入图片描述

在这里插入图片描述

1. Formulating explanations for graph neural networks

设图为 G G G, 边为 E E E, 节点为 V V V, 节点的特征为 d d d 维, X = { x 1 , … , x n } , x i ∈ R d \mathcal{X}=\left\{x_{1}, \ldots, x_{n}\right\}, x_{i} \in \mathbb{R}^{d} X={x1,,xn},xiRd,其中, n n n是节点的数量。 f f f是节点label的映射函数。 f : V ↦ { 1 , … , C } f: V \mapsto\{1, \ldots, C\} f:V{1,,C}, 将 V V V中的每个节点映射为 C C C类, GNN模型 Φ \Phi Φ在所有训练节点上进行优化,对新的节点进行预测。

1.1 Background on graph neural networks

l l l层, GNN模型包括关键三步。(1)第一步,计算节点对 ( v i , v j ) (v_i,v_j) (vi,vj)之间的message, h i l − 1 \mathbf{h}_i^{l-1} hil1 h j l − 1 \mathbf{h}_j^{l-1} hjl1分别是前一层节点 i i i和节点 j j j的表示, r i j r_{ij} rij是两个节点之间的关系: m i j l = MSG ⁡ ( h i l − 1 , h j l − 1 , r i j ) m_{i j}^{l}=\operatorname{MSG}\left(\mathbf{h}_{i}^{l-1}, \mathbf{h}_{j}^{l-1}, r_{i j}\right) mijl=MSG(hil1,hjl1,rij)(2),第二步,对于每个节点 v i v_i vi, GNN汇总aggregates它的邻居 N v i \mathcal{N}_{v_i} Nvi的信息, aggregated message M i M_i Mi的计算方式: M i l = AGG ⁡ ( { m i j l ∣ v j ∈ N v i } ) M_{i}^{l}=\operatorname{AGG}\left(\left\{m_{i j}^{l} \mid v_{j} \in \mathcal{N}_{v_{i}}\right\}\right) Mil=AGG({mijlvjNvi}). 其中 N v i \mathcal{N}_{v_i} Nvi是节点 v i v_i vi的邻居的节点,它的定义不同会产生不同的GNN变种。(3)GNN 使用聚合函数 M i l M_i^l Mil聚合节点 v i v_i vi的representation h i l − 1 \mathbf{h}_i^{l-1} hil1, 然后进行非线性转换获得节点 v i v_i vi的节点在 l l l层表示 h i l \mathbf{h}_i^l hil: h i l = UPDATE ⁡ ( M i l , h i l − 1 ) \mathbf{h}_{i}^{l}=\operatorname{UPDATE}\left(M_{i}^{l}, \mathbf{h}_{i}^{l-1}\right) hil=UPDATE(Mil,hil1), 然后经过 L L L层获得最后的输出: z i = h i L \mathbf{z}_{i}=\mathbf{h}_{i}^{L} zi=hiL

1.2 GNNEXPLAINER: Problem formulation

我们处理问题的关键是节点 v v v的计算,将节点邻居的信息进行汇总,产生节点 v v v的预测 y ^ \hat{y} y^。节点 v v v的最终输出为 z \mathbf{z} z. 图 G c ( v ) G_c(v) Gc(v)的计算与临接矩阵 A c ( v ) ∈ { 0 , 1 } n × n A_{c}(v) \in\{0,1\}^{n \times n} Ac(v){0,1}n×n和节点特征 X c ( v ) = { x j ∣ v j ∈ G c ( v ) } X_{c}(v)=\left\{x_{j} \mid v_{j} \in G_{c}(v)\right\} Xc(v)={xjvjGc(v)}有关。GNN模型 Φ \Phi Φ学习 Y Y Y的概率分布 P Φ ( Y ∣ G c , X c ) P_{\Phi}\left(Y \mid G_{c}, X_{c}\right) PΦ(YGc,Xc), 其中 Y Y Y代表标签 1 , ⋯   , C {1,\cdots,C} 1,,C随机变量,即每个节点属于 C C C类中每个类别的概率。

GNN的预测 y ^ = Φ ( G c ( v ) , X c ( v ) ) \hat{y}=\Phi\left(G_{c}(v), X_{c}(v)\right) y^=Φ(Gc(v),Xc(v)),模型 Φ \Phi Φ主要是由图的结构信息 G c ( v ) G_c(v) Gc(v)和节点的特征 X c ( v ) X_c(v) Xc(v)决定的。一般地, GNNEXPLAINER将预测值 y ^ \hat{y} y^ 解释为 ( G S , X S F ) \left(G_{S}, X_{S}^{F}\right) (GS,XSF), 其中 G S G_S GS是预测图的子图, X S X_S XS G S G_S GS的节点特征, X S F X_S^F XSF G S G_S GS中节点的子集(通过 F F F进行mask, X S F = { x j F ∣ v j ∈ G S } X_{S}^{F}=\{x_{j}^{F} \mid v_{j} \in G_S\} XSF={xjFvjGS} )。

2 GNNEXPLAINER

接下来,我们介绍一下 GNNEXPLAINER 如何在单条(2.1, 2.2)和多条(2.3)上的预测进行模型解释。最后介绍GNNEXPLAINER在机器学习任务上的应用(2.4),如链路预测和图分类。

2.1 Single-instance explanations

给定一个节点 v v v, 我们的目标是识别子图 G S ⊆ G c G_{S} \subseteq G_{c} GSGc和相关特征 X S = { x j ∣ v j ∈ G S } X_S=\left\{x_{j} \mid v_{j} \in G_{S}\right\} XS={xjvjGS}, 这些对于GNN预测 y ^ \hat{y} y^ 是非常重要的。现在,我们假设 X S X_S XS是子集节点的特征, d d d维。在2.2将要讨论哪一维特征能够对模型进行解释。使用互信息 M I MI MI衡量重要性, GNNEXPLAINER优化框架如下:
max ⁡ G S M I ( Y , ( G S , X S ) ) = H ( Y ) − H ( Y ∣ G = G S , X = X S ) (1) \max _{G_{S}} M I\left(Y,\left(G_{S}, X_{S}\right)\right)=H(Y)-H\left(Y \mid G=G_{S}, X=X_{S}\right)\tag{1} GSmaxMI(Y,(GS,XS))=H(Y)H(YG=GS,X=XS)(1)
对于节点 v v v, M I MI MI是衡量是当计算图被限制在子图 G S G_S GS,节点特征被限制在 X S X_S XS时,预测概率 y ^ = Φ ( G c , X c ) \hat{y}=\Phi\left(G_{c}, X_{c}\right) y^=Φ(Gc,Xc)的变化。

举例来说, v j ∈ G c ( v i ) , v j ≠ v i v_{j} \in G_{c}\left(v_{i}\right), v_{j} \neq v_{i} vjGc(vi),vj=vi,如果移除 v j v_j vj, y ^ i \hat{y}_i y^i的概率显著下降, 则节点 v j v_j vj就是很好反事实解释。类似地, ( v j , v k ) ∈ G c ( v i ) , v j , v k ≠ v i \left(v_{j}, v_{k}\right) \in G_{c}\left(v_{i}\right), v_{j}, v_{k} \neq v_{i} (vj,vk)Gc(vi),vj,vk=vi,如果移除 v j v_j vj v k v_k vk之间的边, y ^ i \hat{y}_i y^i 的预测概率值显著下降,则 v j v_j vj v k v_k vk之间的边是很好的反事实解释。

在Eq.(1)中, 交叉项 H ( Y ) H(Y) H(Y)是常数,因为模型 Φ \Phi Φ已经训练好,因此,最大化 Y Y Y ( G S , X S ) (G_S,X_S) (GS,XS)之间的互信息等于最小化条件熵 H ( Y ∣ G = G S , X = X S ) H\left(Y \mid G=G_{S}, X=X_{S}\right) H(YG=GS,X=XS),如下:
H ( Y ∣ G = G S , X = X S ) = − E Y ∣ G S , X S [ log ⁡ P Φ ( Y ∣ G = G S , X = X S ) ] (2) H\left(Y \mid G=G_{S}, X=X_{S}\right)=-\mathbb{E}_{Y \mid G_{S}, X_{S}}\left[\log P_{\Phi}\left(Y \mid G=G_{S}, X=X_{S}\right)\right]\tag{2} H(YG=GS,X=XS)=EYGS,XS[logPΦ(YG=GS,X=XS)](2)
以子图 G S G_S GS y ^ \hat{y} y^ 进行解释, 实际上最小化 Φ \Phi Φ的不确定性。实际上,最大化概率 y ^ \hat{y} y^。为了给出简介的解释,我们给 G S G_S GS增加限制: ∣ G S ∣ ≤ K M \left|G_{S}\right| \leq K_{M} GSKM, 其中 G S G_S GS最多有 K M K_M KM个节点。这意味着, GNNEXPLAINER通过 K M K_M KM边消除 G C G_C GC的噪声,给出预测的最大互信息。

GNNEXPLAINER’s optimization framework. 对于 G c G_c Gc来说,用于解释 y ^ \hat{y} y^ 的子图 G S G_S GS非常多,直接处理是非常困难的。我们考虑部分邻接矩阵的方式: A S [ j , k ] ≤ A c [ j , k ] A_{S}[j, k] \leq A_{c}[j, k] AS[j,k]Ac[j,k],其中, A S ∈ [ 0 , 1 ] n × n A_{S} \in[0,1]^{n \times n} AS[0,1]n×n, 对于所有 j , k j,k j,k增加以上限制。这个近似可以理解为子图是 G c G_c Gc的近似。我们将 G S ∼ G G_{S} \sim \mathcal{G} GSG看做图的随机变量,目标函数Eq.(2)可以变换为:
min ⁡ G E G S ∼ G H ( Y ∣ G = G S , X = X S ) (3) \min _{\mathcal{G}} \mathbb{E}_{G_{S} \sim \mathcal{G}} H\left(Y \mid G=G_{S}, X=X_{S}\right)\tag{3} GminEGSGH(YG=GS,X=XS)(3)
由于凸的假设,使用Jensen不等式给出上限:
min ⁡ G H ( Y ∣ G = E G [ G S ] , X = X S ) (4) \min _{\mathcal{G}} H\left(Y \mid G=\mathbb{E}_{\mathcal{G}}\left[G_{S}\right], X=X_{S}\right)\tag{4} GminH(YG=EG[GS],X=XS)(4)
在实际中,由于神经网络的复杂性,凸的假设是不成立的,但是,最小化这个目标函数和正则项通常会带来比较的解释。

为了估计 E G \mathbb{E}_{\mathcal{G}} EG, 我们将其分解为multivariate Bernoulli distribution: P G ( G S ) = ∏ ( j , k ) ∈ G c A S [ j , k ] P_{\mathcal{G}}\left(G_{S}\right)=\prod_{(j, k) \in G_{c}} A_{S}[j, k] PG(GS)=(j,k)GcAS[j,k],其中 A S A_S AS ( j , k ) -th (j,k)\text{-th} (j,k)-th条目代表边 ( v j , v k ) (v_j,v_k) (vj,vk)之间是否有边存在。我们经验发现,使用正则项可以使得分解值收敛局部最小,即使GNN是非凸的。将Equation 4中 E G [ G S ] \mathbb{E}_G[G_S] EG[GS]替换为masking 邻接矩阵 A c ⊙ σ ( M ) A_{c} \odot \sigma(M) Acσ(M)进行优化 , M ∈ R n × n M \in \mathbb{R}^{n \times n} MRn×n指的是Mask, ⊙ \odot 指element-wise乘积, σ \sigma σ 指的是将mask映射为 [ 0 , 1 ] n × n [0,1]^{n \times n} [0,1]n×n.

在一些应用中,用户更关注如何将训练的模型用于预测想要的label。我们需要修改Equation4:
min ⁡ M − ∑ c = 1 C 1 [ y = c ] log ⁡ P Φ ( Y = y ∣ G = A c ⊙ σ ( M ) , X = X c ) (5) \min _{M}-\sum_{c=1}^{C} \mathbb{1}[y=c] \log P_{\Phi}\left(Y=y \mid G=A_{c} \odot \sigma(M), X=X_{c}\right)\tag{5} Mminc=1C1[y=c]logPΦ(Y=yG=Acσ(M),X=Xc)(5)
该公式 使用Mask机制,将 σ ( M ) \sigma(M) σ(M) A c A_c Ac进行乘积,移除 M M M中小的值,以达到用子图 G S G_S GS解释GNN对节点 v v v的预测值 y ^ \hat{y} y^进行解释的作用。

2.2 Joint learning of graph structural and node feature information

为了识别节点特征对预测值 y ^ \hat{y} y^的重要性, GNNEXPLAINER学习 G S G_S GS节点特征 F F F选择器。与节点所有特征不同, X S = { x j ∣ v j ∈ G S } X_{S}=\left\{x_{j} \mid v_{j} \in G_{S}\right\} XS={xjvjGS}, GNNEXPLAINER考虑 G S G_S GS的子集特征 X S F X_{S}^{F} XSF, 特征的选择通过二值特征选择器 F ∈ { 0 , 1 } d F \in\{0,1\}^{d} F{0,1}d(Figure 2B):
X S F = { x j F ∣ v j ∈ G S } , x j F = [ x j , t 1 , … , x j , t k ]  for  F t i = 1 (6) X_{S}^{F}=\left\{x_{j}^{F} \mid v_{j} \in G_{S}\right\}, \quad x_{j}^{F}=\left[x_{j, t_{1}}, \ldots, x_{j, t_{k}}\right] \text { for } F_{t_{i}}=1\tag{6} XSF={xjFvjGS},xjF=[xj,t1,,xj,tk] for Fti=1(6)
其中, x j F x_j^F xjF是没有被 F F F mask out的节点特征。 ( G S , X S ) (G_S,X_S) (GS,XS)进行联合优化以最大化互信息:
max ⁡ G S , F M I ( Y , ( G S , F ) ) = H ( Y ) − H ( Y ∣ G = G S , X = X S F ) (7) \max _{G_{S}, F} M I\left(Y,\left(G_{S}, F\right)\right)=H(Y)-H\left(Y \mid G=G_{S}, X=X_{S}^{F}\right)\tag{7} GS,FmaxMI(Y,(GS,F))=H(Y)H(YG=GS,X=XSF)(7)
该方程对Eq.(1)目标函数进行调整,同时考虑结构和节点特征两个方面,对预测 y ^ \hat{y} y^进行解释。

Learning binary feature selector F F F. 我们设 X S = X S ⊙ F X_S=X_S\odot F XS=XSF, 其中 F F F是需要学习的参数。如果某个特征不重要,GNN会使得它的权重为0. 实际上,若果这个特征不重要,移除这个特征预测值不会有太大的变化,如果这个特征重要,预测值会显著下降。但是这种方法会忽略一些特征很重要,但是取值接近0。为了解决这个问题,在训练的过程中,我们使用蒙特卡洛从节点 X S X_S XS的边缘经验分布抽样。然后,我们使用参数化技巧进行反向传播,学习feature mask F F F。特别地,随机变量 X X X计算如下: X = Z + ( X S − Z ) ⊙ F X=Z+\left(X_{S}-Z\right) \odot F X=Z+(XSZ)F s.t. ∑ j F j ≤ K F \sum_{j} F_{j} \leq K_{F} jFjKF,其中 Z Z Z是从经验分布抽样的 d d d维随机变量, K F K_F KF是保留的最大特征的数量,是可学习的参数。

Integrating additional constraints into explanations. 为了强化可解释性,我们可以对Eq.(7)增加正则项。例如,为了使得structural and node feature masks to be discrete, 我们使用element-wise entropy,或者增加特定领域限制,如,拉格朗日正则项。我们也可以将mask的元素求和,作为正则项。

最后,需要注意的是对GNN进行解释必须是一个有效的计算图。因为解释 ( G S , X S ) \left(G_{S}, X_{S}\right) (GS,XS)必须允许GNN的message能够流向节点 v v v, 以此来预测 y ^ \hat{y} y^. 重要的是, GNNEXPLAINER 自动可以提供有效计算图,因为它会在整个图上优化structural mask。如果边是没有连接的,它不会被选择,不会影响最终GNN预测。

2.3 Multi-instance explanations through graph prototypes

我们的目标是分析子图如何对一类标签进行解释, GNNEXPLAINER能够基于 graph alignments and prototypes对多实例进行解释。

首先,我们先选择一个类别 c c c的参考样本样本点,例如,将其他节点embedding的均值赋值 c c c。我们利用 G S ( v c ) G_S(v_c) GS(vc) v c v_c vc进行解释,然后将解释赋值给这个类别 c c c的其他节点。如果在大图中,进行匹配是非常具有挑战性的。但是单条样本产生是一个小图,而且near-optimal pairwise graph matchings是非常高效的。

其次,我们将邻接矩阵进行汇总给a graph prototype A proto A_{\text{proto}} Aproto, 例如计算中位数. A proto A_{\text{proto}} Aproto用于识别graph patterns,它在同类别中是共享的。可以用于预测和模型解释。

2.4 GNNEXPLAINER model extensions

Any machine learning task on graphs. 除了能够解释节点分类,在不需要修改优化算法的情况下,GNNEXPLAINER可以用于链路预测和图分类。当对 ( v j , v k ) (v_j,v_k) (vj,vk)进行链路预测时,GNNEXPLAINER会学习 X S ( v j ) X_S(v_j) XS(vj) X S ( v k ) X_S(v_k) XS(vk)两个mask。当进行图分类时,会将我们想解释图的所有邻接矩阵进行union.

Any GNN model. 现在GNN主要基于 message passing构建各种结构, GNNEXPLAINER能够对它们进行解释。

Computational complexity. GNNEXPLAINER的优化取决于计算图 G c G_c Gc的大小, G c ( v ) G_c(v) Gc(v)的邻接矩阵 A c ( v ) A_c(v) Ac(v)等于mask M M M的大小,需要GNNEXPLAINER学习。但是,通常来说,计算图相对较小, 即使输入大图 ,GNNEXPLAINER也能对其进行有效的解释。

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐