从ANN到SNN的转换:实现、原理及两种归一化方法

引言

随着神经形态计算的迅猛发展,脉冲神经网络(Spiking Neural Networks, SNNs)作为一种仿生神经计算模型,逐渐展现出其在低功耗和事件驱动计算领域的巨大潜力。不同于传统的人工神经网络(Artificial Neural Networks, ANNs),SNN通过二值化的脉冲信号进行信息传递,从而更接近生物神经元的行为。其离散时间、事件触发的处理模式使得SNN在能效和计算效率上具有天然的优势,尤其在神经形态硬件上更为适合。

尽管SNN具备诸多优点,但由于脉冲神经元的异质性以及神经元发放模式的离散性,直接训练SNN模型存在较大挑战。为此,基于ANN到SNN的转换方法成为了当前热门的研究方向。通过将预先训练好的ANN转换为SNN,研究人员能够在保留ANN性能的前提下,充分利用SNN的能效优势。本文将介绍如何通过一套系统的方法实现ANN到SNN的转换,并深入探讨两种归一化方法:MaxNorm和RobustNorm,帮助我们更好地理解这一过程的细节。

1. ANN2SNN转换概述

1.1 ANN与SNN的核心差异

人工神经网络(ANN) 中的神经元采用连续的激活函数,如ReLU、Sigmoid或Tanh等,激活值可以是任意实数。这种方式虽然能够实现复杂的非线性映射,但其计算能耗较高,且不具备生物神经元的事件驱动特性。

脉冲神经网络(SNN) 的工作原理与ANN有显著区别。SNN的神经元使用脉冲(spike)作为信息载体,激活方式通过离散脉冲的形式表现。每个神经元的发放过程是基于输入电压的累积,当累积的电压达到某个阈值时,神经元会“发放”脉冲信号。SNN中的常见神经元模型包括:

  • 积分发放神经元(Integrate-and-Fire, IF Neuron):IF神经元通过累积输入电压,当电压超过阈值时,神经元发放脉冲,随后电压重置为初始值。
  • 泄露积分发放神经元(Leaky Integrate-and-Fire, LIF Neuron):在IF模型的基础上增加了泄露机制,使得神经元的电压在没有持续输入时会随时间衰减,更加接近生物神经元的动态特性。

由于ANN和SNN在信息传递机制上的本质差异,直接将ANN的权重应用于SNN是不可行的。因此,在实现ANN到SNN的转换时,需要对神经元的行为和模型的结构进行调整。具体来说,主要挑战在于如何将ANN中的连续激活值有效地映射到SNN中的脉冲发放行为上。

1.2 ANN2SNN的转换流程

ANN到SNN的转换是一个系统化的过程,核心步骤包括训练ANN、激活值归一化处理以及神经元替换。整个流程可概括为以下五个步骤:

  1. 训练ANN模型:首先使用标准的机器学习框架(如PyTorch、TensorFlow)训练一个高性能的ANN模型。通常采用卷积神经网络(CNN)架构,在任务(如图像分类)上进行训练。
  2. 激活值记录:在ANN的训练过程中,插入电压钩子(Voltage Hook)以记录每层网络的激活值。这一步的目的是获取每层神经元的激活范围,便于后续的归一化处理。
  3. 归一化处理:对每层神经元的激活值进行归一化,确保ANN中的权重在SNN中依然能产生合理的神经元发放行为。最常用的两种归一化方法是基于最大值的MaxNorm和基于分位数的RobustNorm。
  4. 替换为脉冲神经元:将ANN中的连续激活函数(如ReLU)替换为SNN中的脉冲神经元(如IF或LIF神经元),并应用归一化系数对输入电压进行缩放。
  5. SNN仿真与验证:在多步时间仿真下运行SNN模型,并在特定任务(如图像分类)上验证SNN的性能。

2. 两种归一化方法

归一化处理是ANN2SNN转换中的关键步骤。
在这里插入图片描述

可以发现,两者的曲线几乎一致。需要注意的是,脉冲频率不可能高于1,因此IF神经元无法拟合ANN中ReLU的输入大于1的情况。

由于SNN神经元的发放特性不同于ANN中的连续激活函数,为了保证模型在转换后的SNN中依旧具有良好的表现,需要对输入的电压或电流进行适当的缩放。本文讨论了两种归一化方法:MaxNorm和RobustNorm。

2.1 MaxNorm归一化

MaxNorm是最简单的归一化方式,适用于没有大量噪声或异常激活值的数据。该方法的核心思想是将每层神经元的输入电压缩放到其激活值的最大范围内,以确保神经元能够有效发放脉冲。

  1. 激活值的最大值收集:遍历训练数据集,记录每一层ReLU激活的最大值( s m a x s_{max} smax)。
  2. 转换为SNN:替换ReLU层为IF神经元,激活值通过一个比例缩放:
    输入 = 输入 s m a x 输出 = 输出 × s m a x \text{输入} = \frac{\text{输入}}{s_{max}} \quad \text{输出} = \text{输出} \times s_{max} 输入=smax输入输出=输出×smax
    即,输入电压缩放为 1 / s m a x 1/s_{max} 1/smax ,IF神经元发放脉冲后,再将输出电压放大回 s m a x s_{max} smax

这种归一化方法的优点在于简单高效,尤其是在输入数据比较规整、没有极端异常值的情况下,能够较好地保持ANN模型的性能。

代码示例:
model._modules[name] = nn.Sequential(
    VoltageScaler(1.0 / max_item),    # 缩放输入
    neuron.IFNode(v_threshold=1., v_reset=None),    # IF神经元
    VoltageScaler(max_item)    # 恢复输出
)

2.2 RobustNorm归一化

RobustNorm归一化是一种更加稳健的归一化策略,特别适用于数据中可能包含噪声或异常激活值的情况。与MaxNorm不同,RobustNorm不直接使用最大值进行归一化,而是使用激活值的某个高分位数(如99.9%)来确定归一化系数。
在这里插入图片描述

这种方法减少了极端激活值对归一化过程的影响,确保模型在数据分布复杂或含有噪声的情况下能够保持性能。

  1. 激活值的分位数收集:遍历训练数据集,记录每一层ReLU激活的某个高分位数(如99.9%)。
  2. 归一化权重和偏置:在替换神经元之前,对权重和偏置进行缩放,确保层与层之间的比例一致。
  3. 转换为SNN:类似MaxNorm,将激活值进行分位数缩放。

这种方法通过调整每一层的权重,进一步优化了层间的信息传递,减少了转换过程中精度的损失。

代码示例:
# 在替换神经元之前,调整权重
if self.prev_scale is not None:
    current_scale = max_item
    prev_scale = self.prev_scale
    module.weight.data = module.weight.data * (prev_scale / current_scale)
    if hasattr(module, 'bias') and module.bias is not None:
        module.bias.data = module.bias.data * (prev_scale / current_scale)
self.prev_scale = max_item

3. 实现流程

在代码中,首先训练了一个具有较好精度的卷积神经网络(CNN)模型。随后使用VoltageHook来遍历训练数据,收集激活值的范围。根据收集到的最大激活值或分位数,进行归一化并替换成SNN中的IF神经元。

接下来详细解释代码中几个关键模块的功能,包括VoltageHookVoltageScalerConverter等。

3.1 VoltageHook

VoltageHook是一个自定义层,用于记录ANN中每一层的激活值。这个激活值在SNN中用于归一化(scaling)。在ANN的ReLU激活后,我们需要知道激活值的范围,以便后续归一化。

  • scale:保存激活层的尺度,用于后续的SNN模型归一化。
  • mode:决定使用最大值(MaxNorm)还是分位数(RobustNorm)来记录激活值。
class VoltageHook(nn.Module):
    def __init__(self, scale=1.0, mode='Max'):
        """
        确定在ANN推理中激活的范围。
        """
        super().__init__()
        self.register_buffer('scale', torch.tensor(scale))
        self.mode = mode

    def forward(self, x):
        if self.mode.lower() in ['max']:
            s_t = x.max().detach()  # 获取该层的最大激活值
        else:
            s_t = torch.tensor(np.percentile(x.detach().cpu(), float(self.mode[:-1])))  # 获取指定分位数的激活值

        self.scale = s_t  # 将激活值的最大值或分位数保存为该层的scale
        return x

3.2 VoltageScaler

VoltageScaler的作用是在SNN中对输入和输出进行电压的缩放。由于SNN神经元的行为与ANN不同,我们需要根据先前收集到的激活值对神经元输入和输出电压进行缩放。

  • scale:用于缩放输入电压或恢复输出电压。
  • forward:将输入乘以scale进行缩放。
class VoltageScaler(nn.Module):
    def __init__(self, scale=1.0):
        """
        缩放SNN推理中电流
        """
        super().__init__()
        self.register_buffer('scale', torch.tensor(scale))

    def forward(self, x):
        return x * self.scale  # 对输入电压进行缩放

3.3 Converter

Converter类负责从ANN到SNN的转换,并处理激活值归一化的过程。它包含三个主要功能:

  • 设置VoltageHook:遍历模型的每一层,并在ReLU激活层后插入VoltageHook,用于收集激活值。
  • 数据收集:通过训练数据集,计算每一层的激活值最大值或分位数。
  • 替换为IFNode:将ReLU层替换为SNN的IF神经元,并根据之前收集的scale进行电压的归一化。
class Converter(nn.Module):
    def __init__(self, dataloader, mode='Max'):
        super().__init__()
        self.mode = mode
        self.dataloader = dataloader
        self.device = None
        self.prev_scale = None  # 添加一个变量,用于存储前一层的最大激活值

    def forward(self, origin_model):
        # 创建模型的副本
        relu_model = copy.deepcopy(origin_model)
        if self.device is None:
            self.device = next(relu_model.parameters()).device
        relu_model.eval()

        # 插入 VoltageHook
        model = self.set_voltagehook(relu_model, mode=self.mode).to(self.device)

        # 使用训练数据集遍历模型,收集激活值
        for _, (imgs, _) in enumerate(tqdm(self.dataloader)):
            model(imgs.to(self.device))

        # 替换为 IFNode
        model = self.replace_by_ifnode(model)
        return model

    @staticmethod
    def set_voltagehook(model, mode='MaxNorm'):
        """
        在每个ReLU层后插入 VoltageHook,用于收集该层的激活值
        """
        for name, module in model._modules.items():
            if hasattr(module, "_modules"):
                model._modules[name] = Converter.set_voltagehook(module, mode=mode)
                if module.__class__.__name__ == 'ReLU':
                    model._modules[name] = nn.Sequential(
                        nn.ReLU(),
                        VoltageHook(mode=mode)  # 插入 VoltageHook
                    )
        return model

    def replace_by_ifnode(self, model):
        """
        将每层 ReLU 层替换为 IFNode 神经元
        """
        for name, module in model._modules.items():
            if hasattr(module, "_modules"):
                model._modules[name] = self.replace_by_ifnode(module)

                # 检查是否为ReLU层,并且有VoltageHook
                if module.__class__.__name__ == 'Sequential' and len(module) == 2 and \
                        module[0].__class__.__name__ == 'ReLU' and \
                        module[1].__class__.__name__ == 'VoltageHook':

                    max_item = module[1].scale.item()  # 获取 VoltageHook 中记录的最大激活值

                    # 替换为 SNN 层
                    model._modules[name] = nn.Sequential(
                        VoltageScaler(1.0 / max_item),  # 归一化输入
                        neuron.IFNode(v_threshold=1., v_reset=None),  # 替换为 IFNode
                        VoltageScaler(max_item)  # 恢复输出电压
                    )

        return model

3.4 核心流程概述

  1. 训练ANN模型:首先,训练一个卷积神经网络(CNN)在MNIST数据集上进行分类。
  2. 激活值收集:通过VoltageHook层,记录每一层ReLU的激活值,采用两种归一化方式:MaxNorm和RobustNorm。
  3. 模型转换:使用收集到的激活值,将ANN模型转换为SNN,将ReLU替换为IF神经元。
  4. SNN模拟与验证:通过多个时间步仿真SNN,并评估其精度。

通过上述流程,用户可以将ANN模型转换为SNN,并根据不同的归一化方式对其性能进行比较。

4. 实验结果

在实验中,采用了两种归一化方式进行SNN转换,分别为MaxNorm和RobustNorm。转换后的SNN通过50个时间步进行仿真,并比较了两种方法的精度随时间步长的变化。

  • MaxNorm:SNN模型使用最大值归一化,随着时间步长的增加,SNN的精度逐渐提高,最终在50个时间步的仿真下达到较高精度。
  • RobustNorm:基于99.9%分位数的归一化方式,SNN精度表现类似,但对异常激活值的敏感性较低。
    在这里插入图片描述

精度结果展示了两种转换方式的优势和不足,MaxNorm简单直接,而RobustNorm更加稳健。

5. 结论

通过本文的分析和实验,我们展示了ANN到SNN转换的一般方法,以及两种不同的归一化策略。MaxNorm适合简单的场景,而RobustNorm在噪声较大的数据上具有更好的鲁棒性。SNN模型的转换不仅能提升计算效率,还能在硬件中实现低功耗的神经网络应用,为未来神经形态计算的发展提供了有效的路径。

附录:完整代码

from tqdm import tqdm
from spikingjelly.clock_driven import neuron
import copy
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import numpy as np


class VoltageHook(nn.Module):
    def __init__(self, scale=1.0, mode='Max'):
        """
            确定在ANN推理中激活的范围。
        """
        super().__init__()
        self.register_buffer('scale', torch.tensor(scale))
        self.mode = mode

    def forward(self, x):
        if self.mode.lower() in ['max']:
            s_t = x.max().detach()
        else:
            s_t = torch.tensor(np.percentile(x.detach().cpu(), float(self.mode[:-1])))

        self.scale = s_t

        return x


class VoltageScaler(nn.Module):
    def __init__(self, scale=1.0):
        """
            缩放SNN推理中电流
        """
        super().__init__()
        self.register_buffer('scale', torch.tensor(scale))

    def forward(self, x):
        return x * self.scale

    # def extra_repr(self):
    #     return '%f' % self.scale.item()

class Converter(nn.Module):
    def __init__(self, dataloader, mode='Max'):
        super().__init__()
        self.mode = mode
        self.dataloader = dataloader
        self.device = None
        self.prev_scale = None  # 添加一个变量,用于存储前一层的最大激活值

    def forward(self, origin_model):
        relu_model = copy.deepcopy(origin_model)
        if self.device is None:
            self.device = next(relu_model.parameters()).device
        relu_model.eval()

        model = self.set_voltagehook(relu_model, mode=self.mode).to(self.device)

        for _, (imgs, _) in enumerate(tqdm(self.dataloader)):
            model(imgs.to(self.device))

        model = self.replace_by_ifnode(model)
        return model

    @staticmethod
    def set_voltagehook(model, mode='MaxNorm'):
        for name, module in model._modules.items():
            if hasattr(module, "_modules"):
                model._modules[name] = Converter.set_voltagehook(module, mode=mode)
                if module.__class__.__name__ == 'ReLU':
                    model._modules[name] = nn.Sequential(
                        nn.ReLU(),
                        VoltageHook(mode=mode)
                    )
        return model

    def replace_by_ifnode(self, model):
        for name, module in model._modules.items():
            if hasattr(module, "_modules"):
                model._modules[name] = self.replace_by_ifnode(module)

                # 检查是否为ReLU层,并且有VoltageHook
                if module.__class__.__name__ == 'Sequential' and len(module) == 2 and \
                        module[0].__class__.__name__ == 'ReLU' and \
                        module[1].__class__.__name__ == 'VoltageHook':

                    max_item = module[1].scale.item()

                    # # 在替换神经元之前,调整权重
                    # if self.prev_scale is not None:
                    #     # 获取前一层的最大值 (𝜆^(𝑙−1)) 和当前层的最大值 (𝜆^𝑙)
                    #     current_scale = max_item
                    #     prev_scale = self.prev_scale
                    #
                    #     # 按照 𝐖^𝑙 → 𝐖^𝑙 * (𝜆^(𝑙−1) / 𝜆^𝑙) 调整权重
                    #     if hasattr(module, 'weight'):
                    #         module.weight.data = module.weight.data * (prev_scale / current_scale)
                    #     elif hasattr(module, 'bias') and module.bias is not None:
                    #         module.bias.data = module.bias.data * (prev_scale / current_scale)
                    #
                    # # 更新 prev_scale 为当前层的最大值
                    # self.prev_scale = max_item

                    # 替换为 SNN 层
                    model._modules[name] = nn.Sequential(
                        VoltageScaler(1.0 / max_item),
                        neuron.IFNode(v_threshold=1., v_reset=None),
                        VoltageScaler(max_item)
                    )

        return model


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),  # 输入通道为1,输出通道为32,卷积核为3x3,步长为1
            nn.BatchNorm2d(32),  # 批归一化
            nn.ReLU(),  # ReLU激活
            nn.Conv2d(32, 32, 3, 1),  # 第二个卷积层,通道数不变
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 2x2最大池化

            nn.Conv2d(32, 64, 3, 1),  # 第三个卷积层,输出通道为64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1),  # 第四个卷积层,通道数保持64
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 2x2最大池化

            nn.Flatten(),  # 展平操作,将卷积层的输出展平为一维向量
            nn.Linear(64 * 4 * 4, 512),  # 全连接层,输入为卷积层输出的展平结果
            nn.ReLU(),  # ReLU激活
            nn.Linear(512, 10),  # 最后一层,全连接层,输出为10类
            nn.Softmax(dim=1)  # Softmax 输出
        )
    #
    # def __init__(self):
    #     super().__init__()
    #     self.network = nn.Sequential(
    #         nn.Conv2d(1, 32, 3, 1),
    #         nn.BatchNorm2d(32),
    #         nn.ReLU(),
    #         nn.AvgPool2d(2, 2),
    #
    #         nn.Conv2d(32, 32, 3, 1),
    #         nn.BatchNorm2d(32),
    #         nn.ReLU(),
    #         nn.AvgPool2d(2, 2),
    #
    #         nn.Conv2d(32, 32, 3, 1),
    #         nn.BatchNorm2d(32),
    #         nn.ReLU(),
    #         nn.AvgPool2d(2, 2),
    #
    #         nn.Flatten(),
    #         nn.Linear(32, 10)
    #     )

    def forward(self, x):
        x = self.network(x)
        return x


def val(net, device, data_loader, T=None):
    net.eval().to(device)
    correct = 0.0
    total = 0.0
    if T is not None:
        corrects = np.zeros(T)
    with torch.no_grad():
        for batch, (img, label) in enumerate(tqdm(data_loader)):
            img = img.to(device)
            if T is None:
                out = net(img)
                correct += (out.argmax(dim=1) == label.to(device)).float().sum().item()
            else:
                for m in net.modules():
                    if hasattr(m, 'reset'):
                        m.reset()
                for t in range(T):
                    if t == 0:
                        out = net(img)
                    else:
                        out += net(img)
                    corrects[t] += (out.argmax(dim=1) == label.to(device)).float().sum().item()
            total += out.shape[0]
    return correct / total if T is None else corrects / total


def main():
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    device = 'cuda'
    dataset_dir = '../MNIST'
    batch_size = 100
    T = 50
    # 训练参数
    lr = 1e-3
    epochs = 10

    model = CNN().to(device)
    train_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_data_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False)
    test_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_data_dataset,
        batch_size=50,
        shuffle=True,
        drop_last=False)
    #
    # # 定义损失函数和优化器
    # loss_function = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    # # 开始训练模型
    # for epoch in range(epochs):
    #     model.train()
    #     running_loss = 0.0
    #     for (img, label) in train_data_loader:
    #         optimizer.zero_grad()
    #         out = model(img.to(device))
    #         loss = loss_function(out, label.to(device))
    #         loss.backward()
    #         optimizer.step()
    #         running_loss += loss.item()
    #
    #     print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_data_loader):.4f}')
    #
    #     # 保存模型
    #     torch.save(model.state_dict(), 'paper_mnist_cnn_model.pth')
    #
    #     # 每个epoch后验证精度
    #     acc = val(model, device, test_data_loader)
    #     print(f'Validation Accuracy after epoch {epoch + 1}: {acc:.3f}')
    #     print()

    model.load_state_dict(torch.load('paper_mnist_cnn_model.pth'))
    acc = val(model, device, test_data_loader)
    print('ANN Validating Accuracy: %.4f' % (acc))

    # 使用转换后的模型
    print('---------------------------------------------')
    print('Converting using MaxNorm')
    model_converter = Converter(mode='max', dataloader=train_data_loader)
    snn_model = model_converter(model)
    print('Simulating...')
    mode_max_accs = val(snn_model, device, test_data_loader, T=T)
    print(f'SNN accuracy (simulation {T} time-steps): {mode_max_accs[-1]:.4f}')

    # 后续其他转换逻辑保持不变
    print('---------------------------------------------')
    print('Converting using RobustNorm')
    model_converter = Converter(mode='99.9%', dataloader=train_data_loader)
    snn_model = model_converter(model)
    print('Simulating...')
    mode_robust_accs = val(snn_model, device, test_data_loader, T=T)
    print(f'SNN accuracy (simulation {T} time-steps): {mode_robust_accs[-1]:.4f}')

    # 绘制不同转换方式下的精度随时间步长的变化
    fig = plt.figure()
    plt.plot(np.arange(0, T), mode_max_accs, label='mode: max')
    plt.plot(np.arange(0, T), mode_robust_accs, label='mode: 99.9%')
    plt.legend()
    plt.xlabel('t')
    plt.ylabel('Acc')
    plt.show()


if __name__ == '__main__':
    main()

参考链接:
ANN转换SNN — spikingjelly alpha 文档
Frontiers | Conversion of Continuous-Valued Deep Networks to Efficient Event-Driven Networks for Image Classification

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐