原文:GIN: 如何设计最强大的图神经网络图神经网络(Graph Neural Networks,简称GNN)不仅仅局限于对节点进行分类,其中一个最受icon-default.png?t=N7T8https://mp.weixin.qq.com/s/C8_x3nobnNiysq-rWyoDCQ

系列教程GNN-algorithms之七:《图同构网络—GIN》【导读】自GCN异军突起后,图神经网络这个领域也逐渐壮大。但是疑惑也随之而来,为什么GNN会这么有效?论文How Powerful Are Graph Neural Networks给出了答案。本文将手把手教你搭建GIN模型。icon-default.png?t=N7T8https://mp.weixin.qq.com/s/RnZWkq1l6kF3j8nVkmCpwg 

背景

        图神经网络(Graph Neural Networks,简称GNN)不仅仅局限于对节点进行分类,其中一个最受欢迎的应用是图分类。在处理分子时,图分类是一项常见的任务,因为分子可以被表示为图,而每个原子(节点)的特征可以用来预测整个分子的行为

        如GCN和GraphSAGE,都是通过迭代聚合一阶邻居信息来更新节点的特征表示,可以拆分为三个步骤:

  • Aggregate:聚合一阶邻居节点的特征。
  • Combine:将邻域特征与中心节点的特征融合,更新中心节点的特征。
  • Readout:如果是图分类任务,需要把Graph中所有节点特征转换为Graph的特征表示。

然而,GNN只学习节点嵌入,而无法直接得到整个图的嵌入。为了解决这个问题,一种新型的层被提出,称为全局池化(global pooling),用于将节点嵌入组合起来生成整个图的嵌入。

此外,还有一种新型的GNN架构被设计出来,称为图同构网络(Graph Isomorphism Network,简称GIN)。这种架构由Xu等人于2018年提出,它通过对节点嵌入进行一系列的图同构操作,来学习整个图的表示。

通过使用全局池化层和图同构网络,GNN在图分类任务中取得了很好的效果。这些方法的提出使得GNN在处理图数据时更加灵活和强大。

在本文中,我们将详细介绍图同构网络(GIN)相对于图卷积网络(GCN)或GraphSAGE在判别能力方面的优势,并探讨它与Weisfeiler-Lehman测试的关联。除了其强大的聚合器外,GIN还为图神经网络(GNN)的整体提供了令人兴奋的见解。

PROTEINS数据集

图片

PROTEINS是一个在生物信息学中流行的数据集。它由1113个蛋白质图组成,其中

  • 节点表示氨基酸
  • 当两个节点之间的距离小于0.6纳米时,它们之间会有一条边相连

该数据集的目标是将每个蛋白质分类为酶或非酶。酶是一类特殊的蛋白质,在细胞中作为催化剂加速化学反应的速度。它们对于消化(例如脂肪酶)、呼吸(例如氧化酶)和人体其他重要功能至关重要。此外,酶还被用于商业应用,如抗生素的生产。

PROTEINS数据集也可以在TUDataset上找到,并且在PyTorch Geometric中有相应的实现。通过研究这个数据集,我们可以更好地理解GIN在图分类任务中的优势和应用。

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='.', name='PROTEINS').shuffle()

# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {dataset[0].x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
Dataset: PROTEINS(1113)
-------------------
Number of graphs: 1113
Number of nodes: 117
Number of features: 3
Number of classes: 2

如果你对蛋白质感兴趣,让我们将其中一个蛋白质绘制成图形,看看它的样子。

from torch_geometric.utils import to_networkx
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

G = to_networkx(dataset[2], to_undirected=True)

# 3D spring layout
pos = nx.spring_layout(G, dim=3, seed=0)

# Extract node and edge positions from the layout
node_xyz = np.array([pos[v] for v in sorted(G)])
edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])

# Create the 3D figure
fig = plt.figure(figsize=(16,16))
ax = fig.add_subplot(111, projection="3d")

# Suppress tick labels
for dim in (ax.xaxis, ax.yaxis, ax.zaxis):
    dim.set_ticks([])

# Plot the nodes - alpha is scaled by "depth" automatically
ax.scatter(*node_xyz.T, s=500, c="#0A047A")

# Plot the edges
for vizedge in edge_xyz:
    ax.plot(*vizedge.T, color="tab:gray")

# fig.tight_layout()
plt.show()

图片

之前提到的3D结构是随机生成的,因为获得正确的3D表示是一个非常困难的问题,这也是AlphaFold的研究重点。

图形并不是表示分子的唯一方式。另一种常用的方法是使用简化的分子输入线条表示法(SMILES),它使用一行(字符串)的符号来表示分子。SMILES是通过对稍作修改的分子图进行深度优先树遍历时遇到的节点的打印而获得的。

研究人员在处理分子或化合物时经常使用这种表示方法。幸运的是,PROTEINS数据集已经以图形的形式进行了编码。否则,我们可能需要将SMILES字符串转换为networkx图。

from torch_geometric.loader import DataLoader

# Create training, validation, and test sets
train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset   = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_dataset  = dataset[int(len(dataset)*0.9):]

print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set       = {len(test_dataset)} graphs')

# Create mini-batches
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print('\nTrain loader:')
for i, subgraph in enumerate(train_loader):
    print(f' - Subgraph {i}: {subgraph}')

print('\nValidation loader:')
for i, subgraph in enumerate(val_loader):
    print(f' - Subgraph {i}: {subgraph}')

print('\nTest loader:')
for i, subgraph in enumerate(test_loader):
    print(f' - Subgraph {i}: {subgraph}')
Training set   = 890 graphs
Validation set = 111 graphs
Test set       = 112 graphs

Train loader:
 - Subgraph 0: DataBatch(edge_index=[2, 7966], x=[2114, 3], y=[64], batch=[2114], ptr=[65])
 - Subgraph 1: DataBatch(edge_index=[2, 8492], x=[2263, 3], y=[64], batch=[2263], ptr=[65])
 - Subgraph 2: DataBatch(edge_index=[2, 9518], x=[2589, 3], y=[64], batch=[2589], ptr=[65])
 - Subgraph 3: DataBatch(edge_index=[2, 10846], x=[3008, 3], y=[64], batch=[3008], ptr=[65])
 - Subgraph 4: DataBatch(edge_index=[2, 9618], x=[2586, 3], y=[64], batch=[2586], ptr=[65])
 - Subgraph 5: DataBatch(edge_index=[2, 7572], x=[2027, 3], y=[64], batch=[2027], ptr=[65])
 - Subgraph 6: DataBatch(edge_index=[2, 10512], x=[2875, 3], y=[64], batch=[2875], ptr=[65])
 - Subgraph 7: DataBatch(edge_index=[2, 7034], x=[1855, 3], y=[64], batch=[1855], ptr=[65])
 - Subgraph 8: DataBatch(edge_index=[2, 11966], x=[3313, 3], y=[64], batch=[3313], ptr=[65])
 - Subgraph 9: DataBatch(edge_index=[2, 9898], x=[2764, 3], y=[64], batch=[2764], ptr=[65])
 - Subgraph 10: DataBatch(edge_index=[2, 8798], x=[2411, 3], y=[64], batch=[2411], ptr=[65])
 - Subgraph 11: DataBatch(edge_index=[2, 9922], x=[2736, 3], y=[64], batch=[2736], ptr=[65])
 - Subgraph 12: DataBatch(edge_index=[2, 10772], x=[2787, 3], y=[64], batch=[2787], ptr=[65])
 - Subgraph 13: DataBatch(edge_index=[2, 11140], x=[2782, 3], y=[58], batch=[2782], ptr=[59])

Validation loader:
 - Subgraph 0: DataBatch(edge_index=[2, 8240], x=[2088, 3], y=[64], batch=[2088], ptr=[65])
 - Subgraph 1: DataBatch(edge_index=[2, 5626], x=[1503, 3], y=[47], batch=[1503], ptr=[48])

Test loader:
 - Subgraph 0: DataBatch(edge_index=[2, 7946], x=[2156, 3], y=[64], batch=[2156], ptr=[65])
 - Subgraph 1: DataBatch(edge_index=[2, 6222], x=[1614, 3], y=[48], batch=[1614], ptr=[49])

虽然我们不会直接将PROTEINS数据集输入到我们的图神经网络(GNN)中,但是我们可以使用小批量处理来加快训练速度。即使PROTEINS数据集不是很大,使用小批量训练仍然是一个高效的方法。我们可以使用图卷积网络(GCN)或图注意力网络(GAT),但我想介绍一种新的架构,即图同构网络(Graph Isomorphism Network)。

图同构网络(GIN)

A. Weisfeiler-Lehman test

GIN是一种旨在最大化图神经网络(GNN)的表示能力的网络架构。表示能力的衡量方法之一是使用Weisfeiler-Lehman(WL)图同构测试。WL测试可以判断两个图是否非同构,但不能保证它们是同构的。

图片

尽管看起来可能不起眼,但要区分两个大型图形是一项非常困难的任务。实际上,这个问题被认为不可能在多项式时间内解决,也不被认为是NP完全的。它甚至可能处于计算复杂性类别NP-intermediate的中间位置(如果这个类别存在的话)。

// 这意味着找到一个有效的算法来判断两个图是否同构是一个具有挑战性的问题,并且可能需要使用更复杂的方法来解决。这也说明了为什么Weisfeiler-Lehman(WL)测试在图学习领域中引起了广泛的关注,因为它提供了一种近似判断图同构性的方法。通过将WL测试的思想应用于图神经网络(GNN)的学习过程中,我们可以提高GNN的表示能力,并在图分类任务中取得更好的性能。

WL测试与GNN的学习方式非常相似。在WL测试中,每个节点从相同的标签开始,然后将邻居节点的标签进行聚合和哈希处理,生成一个新的标签。重复这个过程,直到标签不再改变。

一些图学习领域的研究人员注意到了WL测试和GNN学习的相似之处,并将其应用于设计更强大的GNN架构,如GIN。通过使用GIN,我们可以提高GNN在图分类任务中的性能,并获得更准确的表示能力。如果你对WL测试感兴趣,我推荐阅读David Bieber的博文和Michael Bronstein的文章。

Weisfeiler-Lehman(WL)测试不仅与图神经网络(GNN)中特征向量聚合的方式相似,而且它具有区分图形的能力,使其比许多其他架构(如GCNs和GraphSAGE)更强大。

        WL test是判断两个Graph结构是否相同的有效方法,主要通过迭代以下步骤来判断Graph的同构性: (初始化:将节点的id作为自身的标签。)

  • 1. 聚合:将邻居节点和自身的标签进行聚合。
  • 2. 更新节点标签:使用Hash表将节点聚合标签映射作为节点的的新标签。

WL test迭代过程如下图:

图片

(此图引用自知乎 https://zhuanlan.zhihu.com/p/62006729,如有侵权,请联系删除)

上图a中的G图中节点1的邻居有节点4;节点2的邻居有节点3和节点5;节点3的邻居有节点2,节点4,节点5;节点4的邻居有节点1,节点3,节点5;节点5的邻居有节点2,节点3,节点4。(步骤1)聚合邻居节点和自身标签后的结果就是b图中的G。然后用Hash将聚合后的结果映射为一个新的标签,进行标签压缩,如图c。用压缩后的标签来替代之前的聚合结果,进行标签更新(步骤二),如图d,G‘同理。

    对于Graph的特征表示,WL test方法用迭代前后图中节点标签的个数作为Graph的表示特征,如图e所示。从上图我们可以看出WL_test的迭代过程和GNN的聚合过程非常相似,并且作者也证明了WL_test是图神经网络聚合邻域信息能力的上限

B.聚合器(aggregator)

为了实现与WL测试相似的效果,研究人员设计了一种新的聚合器,该聚合器在处理非同构图时能够生成不同的节点嵌入。

具体而言,研究人员使用了两个可逆函数作为解决方案,这两个函数的具体形式并不清楚。然而,通过使用多层感知机(MLP)来学习这两个函数的参数,可以实现近似地学习这些可逆函数

  • 图注意力网络(GAT)中,我们使用神经网络来学习任务特定的权重因子

  • 图同构网络(GIN)中,我们通过通用逼近定理来学习两个可逆函数的近似。这种方法使得GIN能够更好地捕捉图的结构信息,并提供更准确的表示能力。

使用GIN计算特定节点的隐藏向量的方法如下:

需要注意的是,MLP的多层结构在这里强调了其重要性。作者指出,在图学习中,单层结构是不够的。

C.全局池化

全局池化或图级读出是使用GNN计算的节点嵌入来生成图嵌入的过程。

一种简单的方法是使用每个节点h_i嵌入的平均值、总和或最大值来获得图嵌入h_G:

作者提出了两个重要观点来考虑图级读出:

  • 为了考虑所有的结构信息,需要保留前几层的嵌入。

  • 总和运算符比平均值和最大值更具表达能力。

基于这些观察结果,作者提出了以下全局池化方法:

对于每一层,将节点嵌入进行求和,然后将结果进行串联。这种解决方案将总和运算符的表达能力与串联中前几次迭代的记忆相结合。

 GIN节点的更新过程

对上述公式的解释

在PyTorch Geometric中的GIN

在PyTorch Geometric中,有一个名为GINConv的层,用于实现Graph Isomorphism Network(GIN)。然而,与原始设计相比,PyTorch Geometric中的实现有一些差异。

在GINConv层中,有几个参数可以调整:

  • nn:用于近似两个可逆函数的多层感知机(MLP)。

  • eps:ɛ的初始值,默认为0。

  • train_eps:一个布尔值,用于确定ɛ是否可训练,默认为False。

值得注意的是,在PyTorch Geometric中的实现中,默认情况下完全去除了ɛ参数。这意味着ɛ是一个可以调整的超参数,但可能不是必需的。

此外,PyTorch Geometric中还有一个名为GINEConv的第二个GIN层。它是根据论文中对GIN的实现而来的,该实现将ReLU函数应用于邻居节点的特征。然而,在本教程中我们不会使用GINEConv层,因为它的优势尚不清楚。

根据原始论文的灵感,我们需要为GINConv层设计一个MLP。

图片

原论文中堆叠了5层,但我们将使用3层。整个架构如下所示:

图片

尽管我找不到任何使用图嵌入串联的GIN实现,但这是我根据原论文的描述设计的版本。总体而言,它提高了1%的准确率。现在,我们将其与使用简单平均池化(无串联)的GCN进行比较。

import torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool


class GCN(torch.nn.Module):
    """GCN"""
    def __init__(self, dim_h):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, dim_h)
        self.conv2 = GCNConv(dim_h, dim_h)
        self.conv3 = GCNConv(dim_h, dim_h)
        self.lin = Linear(dim_h, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings 
        h = self.conv1(x, edge_index)
        h = h.relu()
        h = self.conv2(h, edge_index)
        h = h.relu()
        h = self.conv3(h, edge_index)

        # Graph-level readout
        hG = global_mean_pool(h, batch)

        # Classifier
        h = F.dropout(hG, p=0.5, training=self.training)
        h = self.lin(h)
        
        return hG, F.log_softmax(h, dim=1)

class GIN(torch.nn.Module):
    """GIN"""
    def __init__(self, dim_h):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            Sequential(Linear(dataset.num_node_features, dim_h),
                       BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.lin1 = Linear(dim_h*3, dim_h*3)
        self.lin2 = Linear(dim_h*3, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings 
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        # Graph-level readout
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)

        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)

        # Classifier
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)
        
        return h, F.log_softmax(h, dim=1)

gcn = GCN(dim_h=32)
gin = GIN(dim_h=32)
def train(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                      lr=0.01,
                                      weight_decay=0.01)
    epochs = 100

    model.train()
    for epoch in range(epochs+1):
        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0

        # Train on batches
        for data in loader:
          optimizer.zero_grad()
          _, out = model(data.x, data.edge_index, data.batch)
          loss = criterion(out, data.y)
          total_loss += loss / len(loader)
          acc += accuracy(out.argmax(dim=1), data.y) / len(loader)
          loss.backward()
          optimizer.step()

          # Validation
          val_loss, val_acc = test(model, val_loader)

    # Print metrics every 10 epochs
    if(epoch % 10 == 0):
        print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} '
              f'| Train Acc: {acc*100:>5.2f}% '
              f'| Val Loss: {val_loss:.2f} '
              f'| Val Acc: {val_acc*100:.2f}%')
          
    test_loss, test_acc = test(model, test_loader)
    print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')
    
    return model

@torch.no_grad()
def test(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0

    for data in loader:
        _, out = model(data.x, data.edge_index, data.batch)
        loss += criterion(out, data.y) / len(loader)
        acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

    return loss, acc

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

gcn = train(gcn, train_loader)
gin = train(gin, train_loader)
Epoch 100 | Train Loss: 0.67 | Train Acc: 60.61% | Val Loss: 0.70 | Val Acc: 54.50%
Test Loss: 0.69 | Test Acc: 55.99%
Epoch 100 | Train Loss: 0.49 | Train Acc: 75.61% | Val Loss: 0.53 | Val Acc: 78.99%
Test Loss: 0.60 | Test Acc: 66.93%

这次的比较结果表明,GIN架构在性能上完全超越了GCN。平均准确率提高了10%,这个差距可以归因于以下几个原因:

  • GIN的聚合器经过专门设计,能够更好地区分那些GCN的聚合器无法区分的图形。

  • GIN架构中的每一层都将图隐藏向量进行了连接,而不仅仅考虑最后一层的隐藏向量。这样做的好处是能够充分利用每一层的信息,提升了分类的准确性。

  • 从理论上来说,求和运算符优于平均运算符。在GIN架构中,采用了求和运算符,这进一步增强了分类性能。

为了更直观地展示我们使用GCN和GIN进行分类的结果,让我们进行一下可视化。

fig, ax = plt.subplots(4, 4, figsize=(16,16))
fig.suptitle('GCN - Graph classification')

for i, data in enumerate(dataset[1113-16:]):
    # Calculate color (green if correct, red otherwise)
    _, out = gcn(data.x, data.edge_index, data.batch)
    color = "green" if out.argmax(dim=1) == data.y else "red"

    # Plot graph
    ix = np.unravel_index(i, ax.shape)
    ax[ix].axis('off')
    G = to_networkx(dataset[i], to_undirected=True)
    nx.draw_networkx(G,
                    pos=nx.spring_layout(G, seed=0),
                    with_labels=False,
                    node_size=150,
                    node_color=color,
                    width=0.8,
                    ax=ax[ix]
                    )

图片

fig, ax = plt.subplots(4, 4, figsize=(16,16))
fig.suptitle('GIN - Graph classification')

for i, data in enumerate(dataset[1113-16:]):
    # Calculate color (green if correct, red otherwise)
    _, out = gin(data.x, data.edge_index, data.batch)
    color = "green" if out.argmax(dim=1) == data.y else "red"

    # Plot graph
    ix = np.unravel_index(i, ax.shape)
    ax[ix].axis('off')
    G = to_networkx(dataset[i], to_undirected=True)
    nx.draw_networkx(G,
                    pos=nx.spring_layout(G, seed=0),
                    with_labels=False,
                    node_size=150,
                    node_color=color,
                    width=0.8,
                    ax=ax[ix]
                    )

图片

有趣的是,这两个模型会犯不同的错误。这在机器学习中是常见的现象,当不同的算法应用于同一个问题时会出现。

我们可以利用这种优势来创建一个集成模型。有许多方法可以组合我们的图嵌入。最简单的方法是取归一化输出向量的平均值。

gcn.eval()
gin.eval()
acc_gcn = 0
acc_gin = 0
acc = 0

for data in test_loader:
    # Get classifications
    _, out_gcn = gcn(data.x, data.edge_index, data.batch)
    _, out_gin = gin(data.x, data.edge_index, data.batch)
    out = (out_gcn + out_gin)/2

    # Calculate accuracy scores
    acc_gcn += accuracy(out_gcn.argmax(dim=1), data.y) / len(test_loader)
    acc_gin += accuracy(out_gin.argmax(dim=1), data.y) / len(test_loader)
    acc += accuracy(out.argmax(dim=1), data.y) / len(test_loader)

# Print results
print(f'GCN accuracy:     {acc_gcn*100:.2f}%')
print(f'GIN accuracy:     {acc_gin*100:.2f}%')
print(f'GCN+GIN accuracy: {acc*100:.2f}%')
GCN accuracy:     55.99%
GIN accuracy:     66.93%
GCN+GIN accuracy: 67.45%

幸运的是,我们发现这种集成模型的准确率有所提高。

当然,并不总是这样。更复杂的方法涉及构建一个完全不同的用于分类的机器学习算法,比如随机森林。这个分类器以图嵌入作为输入,并输出最终的分类结果。

结论

在本文中,我们探讨了图同构网络作为理解图神经网络的重要一步。这些网络不仅在多个基准测试中提高了准确性,还提供了一个理论框架来解释为什么一种架构比另一种更好。在本文中:

  • 我们介绍了一个新的任务,即使用全局池化进行图分类;

  • 我们介绍了WL测试,并讨论了它与新的GIN层的关系;

  • 我们实现了一个GIN模型和一个GCN模型,并使用它们的分类结果进行了简单的集成。

虽然GIN在社交图等方面表现良好,但在现实世界中,并不总是能够充分发挥其理论优势。这也适用于其他“可证明强大”的架构,如3WLGNN。

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐