在这里插入图片描述

项目地址:https://github.com/VainF/Torch-Pruning

Torch-Pruning 是一个专用于torch的模型剪枝库,其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法(如 Magnitude Pruning 或 Taylor Pruning)相结合可以达到良好的剪枝效果。

本博文结合项目官网案例,对信息进行结构话,抽离出剪枝技术说明、剪枝模型保存与加载、剪枝技术的基本使用,剪枝技术的具体使用案例。并结合外部信息,分析剪枝对模型性能精度的影响。

1、基本说明

1.1 项目安装

打开https://github.com/VainF/Torch-Pruning,下载项目
在这里插入图片描述
然后在终端中,进入项目目录,并执行pip install -r requirements.txt 安装项目依赖库
在这里插入图片描述
然后在执行 pip install -e . ,将项目安装在当前目录下,并设置为editing模式。
在这里插入图片描述
验证安装:执行命令python -c "import torch_pruning", 如果没有输出报错信息则表示安装成功。
在这里插入图片描述

1.2 DepGraph 技术说明

在结构修剪中,组被定义为深度网络中最小的可移除单元。每个组由多个相互依赖的层组成,需要同时修剪这些层以保持最终结构的完整性。然而,深度网络通常表现出层与层之间错综复杂的依赖关系,这对结构修剪提出了重大挑战。这项研究通过引入一种名为 DepGraph 的自动化机制来解决这一挑战,该机制可以轻松实现参数分组,并有助于修剪各种深度网络。
在这里插入图片描述

直接剪枝会会破坏layer间的依赖关系,会导致forward流程报错。具体如下面代码,移除model.conv1模块中的idxs为0与1的channel,导致后续的bn1层输入输入与参数格式对不上号,然后报错。

from torchvision.models import resnet18
import torch_pruning as tp
import torch

model = resnet18().eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
output = model(torch.randn(1,3,224,224)) # test

在这里插入图片描述
基本在后续层添加剪枝,运行代码也会保存,因为batchnorm的下一层要求的输出channel是64。

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) 
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1])
output = model(torch.randn(1,3,224,224)) 

使用DepGraph剪枝代码如下,先使用tp.DependencyGraph().build_dependenc构建出依赖图,然后基于DG.get_pruning_group函数获取目标剪枝层的依赖关系组,最后在检验关系并进行剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. prune all grouped layers that are coupled with model.conv1 (included).
print(group)
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
    
# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object

代码执行后的输出如下所示,可以看到捕捉到group对应的依赖layer

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

1.3 剪枝模型的保存与加载

剪枝后的模型由于网络结构改变了,如果只保存模型参数,是无法支持原始网络结构,需要将模型结构连参数一并保存。加载时连同参数一起加载。

model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model

或者基于tp库中tp.state_dict函数提取目标参数进行保存,并基于tp.load_state_dict函数将剪枝后的参数赋值到原始模型中形成剪枝模型。

# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')

# create a new model, e.g. resnet18
new_model = resnet18().eval()

# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
print(new_model) # This will be a pruned model.

2、剪枝基本案例

2.1 具有目标结构的剪枝

以下代码使用TaylorImportance指标进行剪枝,设置忽略输出层的剪枝。并设置MagnitudePruner中对通道剪枝50%,一共分iterative_steps步完成剪枝,每一次剪枝都进行微调。
整体来说,具备目标结构的剪枝,效果是最差的。 基于https://blog.csdn.net/a486259/article/details/140407147 分析的数据得出的结论。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    #pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

代码的输出信息如下所示,可以看到macs与nparams在逐步降低。最终输出的模型结构,所有的chanel都减半了,只有输出层例外。

iter 0 | rate:0.8092  0.8111
iter 1 | rate:0.6469  0.6445
iter 2 | rate:0.4971  0.4979
iter 3 | rate:0.3718  0.3695
iter 4 | rate:0.2674  0.2614
ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
PS D:\开源项目\Torch-Pruning-master>
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)

2.2 自动结构剪枝

这里的自动结构是有一个预设目标,即将总体channel剪枝到原模型的多少,但没有预定的目标结构。可能有的laye通道剪枝数多,有的剪枝数少。 与2.1中的代码相比,主要是增加了参数 global_pruning=True。但这个剪枝方式比具有目标结构的剪枝更加有效。就像裁员一样,要求各个部门内裁员比例相同与在公司内控制裁员比例(各个部门裁员比例按重要度排列,裁员比例不一样),必然是第二种方式更有效。第一种方式,使低效率部门的靠前但无用员工保留下来了。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 3 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50%的channel
    ignored_layers=ignored_layers,
    global_pruning=True
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

2.3 稀疏化剪枝

一些修剪器,如BNScalePruner和GroupNormPuner,支持稀疏训练。通过在标准训练循环中插入pruner.update_regulalizer()和pruner.regularization(model),可以很容易地实现这一点。修剪器将把正则化梯度累积到.grad。

for epoch in range(epochs):
    model.train()
    pruner.update_regularizer() # <== initialize regularizer
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward() # after loss.backward()
        pruner.regularize(model) # <== for sparse training
        optimizer.step() # before optimizer.step()

2.4 MagnitudePruner中的参数

指定特定层的剪枝比例 通过pruning_ratio_dict参数,指定model.layer2的剪枝比例为20%,这里适用于有先验经验的layer,控制对特定layer的剪枝比例。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    imp,
    pruning_ratio = 0.5,
    pruning_ratio_dict = {model.layer2: 0.2}
)
pruner.step()
print(model)

代码执行后的层为:ResNet{64, 128, 256, 512} => ResNet{32, 102, 128, 256}

设置最大剪枝比例 通过 max_pruning_ratio 参数设置最大剪枝比例,避免由于稀疏剪枝或者自动剪枝时某个层被严重剪枝或者移除。

剪枝次数与剪枝调度器 您打算分多轮修剪模型,iterative_steps 会很有用。默认情况下,修剪器会逐渐增加模型的稀疏度,直到达到所需的 pruning_ratio。如以下代码,分5次实现剪枝目标。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
)

# prune the model, iteratively if necessary.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Round %d/%d, Params: %.2f M" % (i+1, iterative_steps, nparams/1e6))
    # finetune your model here
    # finetune(model)
    # ...
print(model)

对应输出如下
Round 1/5, Params: 9.44 M
Round 2/5, Params: 7.45 M
Round 3/5, Params: 5.71 M
Round 4/5, Params: 4.20 M
Round 5/5, Params: 2.93 M

设置忽略的层 这主要是避免对输出层进行剪枝,修改模型的输出结构。使用代码如下,通过ignored_layers参数传入忽略的layer对象。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels
    ignored_layers=[model.conv1, model.fc] # ignore the first & last layers
)
pruner.step()
print(model)

channel取整 在很多的时候都认为channel为16的倍数,gpu运行效率最高。使用代码如下,通过round_to参数,保持channel是特定数的倍数。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.3, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    round_to=10 # round to 10x. Note: 10x is not a good practice.
)

pruner.step()
print(model)

channel_groups 某些层(例如 nn.GroupNorm 和 nn.Conv2d)具有 group 参数,这会在层内引入额外的依赖项。修剪后,保持所有组的大小相同至关重要。为了满足这一要求,引入了参数 channel_groups 以启用对这些通道的手动分组。如以下代码,通过channel_groups参数,控制model.group_conv1中的参数为8个一组

pruner = tp.pruner.MagnitudePruner(
            model,
            example_inputs=example_inputs,
            importance=importance,
            iterative_steps=1,
            pruning_ratio=0.5,
            channel_groups = {model.group_conv1: 8} # For Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), groups=8)
        )

额外参数剪枝 有些时候模型具备的可训练参数并非conv、fc等传统layer中,需要基于unwrapped_parameters参数将额外的可剪枝参数传入到剪枝器中。具体如下所示:

from torchvision.models.convnext import CNBlock, ConvNeXt
unwrapped_parameters = []
for m in model.modules():
    if isinstance(m, CNBlock):
        unwrapped_parameters.append( (m.layer_scale, 0) )

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    unwrapped_parameters=unwrapped_parameters 

限定剪枝范围 root_module_types 参数用于指定组的“根”或第一层。在许多情况下,我们专注于修剪线性层和卷积 (Conv) 层。要专门针对这些层启用修剪,我们可以使用以下参数:root_module_types=[nn.Conv2D, nn.Linear]。这可确保将修剪应用于所需的层。

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    root_module_types=[nn.Conv2D, nn.Linear]

3、具体应用案例

3.1 timm模型剪枝

官方代码为:examples\timm_models\prune_timm_models.py
具体详情如下,这里有一个特殊用法,是通过num_heads参数实现对于transformer layer的支持

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
os.environ['TIMM_FUSED_ATTN'] = '0'
import torch
import torch.nn as nn 
import torch.nn.functional as F
from typing import Sequence
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp
import argparse

parser = argparse.ArgumentParser(description='Prune timm models')
parser.add_argument('--model', default=None, type=str, help='model name')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='channel pruning ratio')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--pretrained', default=False, action='store_true', help='global pruning')
parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
args = parser.parse_args()

def main():
    timm_models = timm.list_models()
    if args.list_models:
        print(timm_models)
    if args.model is None: 
        return
    assert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)

    imp = tp.importance.GroupNormImportance()
    print("Pruning %s..."%args.model)
        
    input_size = model.default_cfg['input_size']
    example_inputs = torch.randn(1, *input_size).to(device)
    test_output = model(example_inputs)
    ignored_layers = []
    num_heads = {}

    for m in model.modules():
        if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(model.head)
            print("Ignore classifier layer: ", m.head)
       
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                num_heads[m.qkv] = m.num_heads
                print("Attention layer: ", m.qkv, m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                num_heads[m.qkv_proj] = m.num_heads

    print("========Before pruning========")
    print(model)
    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
    pruner = tp.pruner.MetaPruner(
                    model, 
                    example_inputs, 
                    global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
                    importance=imp, # importance criterion for parameter selection
                    iterative_steps=1, # the number of iterations to achieve target pruning ratio
                    pruning_ratio=args.pruning_ratio, # target pruning ratio
                    num_heads=num_heads,
                    ignored_layers=ignored_layers,
                )
    for g in pruner.step(interactive=True):
        g.prune()

    for m in model.modules():
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                m.num_heads = num_heads[m.qqkv_projkv]
                m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

    print("========After pruning========")
    print(model)
    test_output = model(example_inputs)
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
    print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
    print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))

if __name__=='__main__':
    main()

3.2 llm模型剪枝

在examples\LLMs\prune_llama.py中提供了一个对于llama模型的剪枝案例.
核心代码如下,可以看到也是基于num_heads记录transformer的结构信息,然后在剪枝后将num_heads数据赋值到对应模型参数上。与原始代码相比,这里删除了模型精度验证相关的代码。


# Code adapted from 
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wanda

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))

import argparse
import os 
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
import time
import torch
import torch.nn as nn
from collections import defaultdict
import fnmatch
import numpy as np
import random

print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="./cache"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float16, 
        cache_dir=cache_dir, 
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings 
    return model

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--cache_dir", default="./cache", type=str )
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument("--eval_zero_shot", action="store_true")
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)       
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    ##############
    # Pruning
    ##############
    print("----------------- Before Pruning -----------------")
    print(model)
    text = "Hello world."
    inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
    import torch_pruning as tp 
    num_heads = {}
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_key_value_heads
            num_heads[m.v_proj] = model.config.num_key_value_heads
            
    head_pruning_ratio = args.pruning_ratio
    hidden_size_pruning_ratio = args.pruning_ratio
    pruner = tp.pruner.MagnitudePruner(
        model, 
        example_inputs=inputs,
        importance=tp.importance.GroupNormImportance(),
        global_pruning=False,
        pruning_ratio=hidden_size_pruning_ratio,
        ignored_layers=[model.lm_head],
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=head_pruning_ratio,
    )
    pruner.step()

    # Update model attributes
    num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
    num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
    model.config.num_attention_heads = num_heads
    model.config.num_key_value_heads = num_key_value_heads
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            m.hidden_size = m.q_proj.out_features
            m.num_heads = num_heads
            m.num_key_value_heads = num_key_value_heads
        elif name.endswith("mlp"):
            model.config.intermediate_size = m.gate_proj.out_features
    print("----------------- After Pruning -----------------")
    print(model)

    #ppl_test = eval_ppl(args, model, tokenizer, device)
    #print(f"wikitext perplexity {ppl_test}")

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

if __name__ == '__main__':
    main()

3.3 目标检测模型剪枝

在Torch-Pruning 库中提供了针对yolov8、yolov7、yolov5的剪枝案例。关于yolov8还提供了剪枝后的训练策略,其主要技巧在与对不可剪枝层的可剪枝话处理(C2f模块的剪枝,其含split操作,不利于剪枝索引)。后续会补充博客,说明对yolov8的剪枝使用。

4、其他信息

4.1 剪枝器中的评价指标

在torch_pruning\pruner\importance.py中有很多个剪枝评价指标

__all__ = [
    # Base Class
    "Importance",

    # Basic Group Importance
    "GroupNormImportance",
    "GroupTaylorImportance",
    "GroupHessianImportance",

    # Aliases
    "MagnitudeImportance",
    "TaylorImportance",
    "HessianImportance",

    # Other Importance
    "BNScaleImportance",
    "LAMPImportance",
    "RandomImportance",
]

整体来看是TaylorImportance最好,一直使用该值即可。
来看

4.2 剪枝对性能精度的影响

在博客https://blog.csdn.net/a486259/article/details/140407147?spm=1001.2014.3001.5501 中基本确定了剪枝50%,对模型精度是没有任何影响的。这里对Torch-Pruning 库相关的论文数据进行二次核验,以致于分析出剪枝中速度提升对精度的影响。

以DepGraph: Towards Any Structural Pruning数据为例,可以发现最高支持6x速度剪枝后保持模型性能。
在这里插入图片描述
以LLM-Pruner: On the Structural Pruning of Large Language Models 论文数据为例,可以发现使用Vector评价方法的剪枝,移除10%的参数,zero-shot下对模型精度影响不大。而图4更表明,剪枝方法正确的话,移除50%的参数对模型性能影响也不大。
在这里插入图片描述
以论文 Structural Pruning for Diffusion Models 的数据为分析,同样可以发现剪枝50%左右的通道,对结果影响不对。
在这里插入图片描述

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐