[Deep Learning Notes] Processing BraTS Challenge Data with U-Mamba and nnUNetv2

Code repository: U-Mamba

Introduction to U-Mamba

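Since Mamba was introduced, its efficient long-range modeling and selective scanning of features have led many researchers to apply it to medical image segmentation. U-Mamba is one of these works. Its architecture is shown below:
[Figure: U-Mamba architecture]
As the figure shows, it adds an SSM block after the convolutions to perform feature selection.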

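Downloading the dataset

We use the BraTS 2019 dataset here. BraTS is a brain tumor segmentation challenge that has been held for many years, so I will not go into its background.

Download the BraTS 2019 training set from the Baidu PaddlePaddle (AI Studio) community: https://aistudio.baidu.com/aistudio/datasetdetail/67772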
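After unpacking the download, it can help to check which suffix the image files actually use, because the conversion step later depends on it. The sketch below is only an illustration: it assumes the HGG/LGG folder layout and the `<case>_<modality>` file naming that the conversion script in this post relies on, and the function name `count_suffixes` is my own.

```python
from collections import Counter
from pathlib import Path

# Modalities plus the segmentation, as used by the conversion script in this post.
MODALITIES = ("t1", "t1ce", "t2", "flair", "seg")


def count_suffixes(brats_dir):
    """Count how many BraTS files end in .nii.gz versus .nii.

    Assumes the layout <brats_dir>/<HGG|LGG>/<case>/<case>_<modality>.<suffix>.
    """
    counts = Counter()
    for grade in ("HGG", "LGG"):
        grade_dir = Path(brats_dir) / grade
        if not grade_dir.is_dir():
            continue
        for case_dir in sorted(p for p in grade_dir.iterdir() if p.is_dir()):
            for modality in MODALITIES:
                stem = case_dir / f"{case_dir.name}_{modality}"
                for suffix in (".nii.gz", ".nii"):
                    if Path(str(stem) + suffix).is_file():
                        counts[suffix] += 1
    return counts


# Example call (the path is a placeholder for wherever you unpacked the data):
# print(count_suffixes("/data/BraTS2019_Training"))
```

If only `.nii.gz` shows up in the counts, the file names in the conversion script need the `.nii.gz` suffix as well.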

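Environment setup

After downloading the U-Mamba repository, set up the environment as follows:

  • First install causal-conv1d and mamba-ssm. There are two ways to do this:

    • Install with pip

      pip install "causal-conv1d>=1.2.0"
      pip install mamba-ssm --no-cache-dir
      
    • Install from a whl file (if the first method fails)

      First download the matching whl file from the releases page of the official causal-conv1d repository. Make sure it matches your PyTorch version, CUDA version, and Python version. Then run:

      pip install your_whl_file.whl
      

      Next, download the mamba-ssm whl file from its releases page and install it with the same pip command.

    Finally, enter the umamba folder of the U-Mamba repository and run: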

    cd umamba
    pip install -e .
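
When choosing a whl file as described above, you need to know your Python, PyTorch, and CUDA versions. The following is a small sketch that prints them. The helper names `python_tag` and `environment_report` are my own, and the script only reports what is installed; it does not pick a wheel for you.

```python
import importlib.util
import platform
import sys


def python_tag():
    """Return the CPython wheel tag, e.g. 'cp310' for Python 3.10."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def environment_report():
    """Collect the version information needed to match a prebuilt wheel."""
    report = {
        "python_tag": python_tag(),
        "platform": f"{platform.system()}_{platform.machine()}",
        "torch": None,
        "cuda": None,
    }
    # Only import torch if it is installed, so the script also runs before setup.
    if importlib.util.find_spec("torch") is not None:
        import torch

        report["torch"] = torch.__version__
        report["cuda"] = torch.version.cuda  # None for CPU-only builds
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Compare the printed values with the tags in the whl file name before downloading.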
    

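Preparing the dataset

The dataset has to be converted beforehand into a form that nnUNet can process. The relevant paths are set in U-Mamba/umamba/nnunetv2/paths.py, and there are three folders:

  • nnUNet_raw: the input datasets, which must follow a specific format
  • nnUNet_preprocessed: the output location for the preprocessed datasets
  • nnUNet_results: the trained models and results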
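
If these folders do not exist yet, they can be created up front. The sketch below just creates the three folders under a base directory of your choice. The base directory and the helper name `make_nnunet_dirs` are assumptions of mine, so point paths.py at whatever locations you actually use.

```python
import os

# The three folder names listed above.
NNUNET_FOLDERS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def make_nnunet_dirs(base_dir):
    """Create the three nnUNet folders under base_dir and return their paths."""
    paths = {}
    for name in NNUNET_FOLDERS:
        path = os.path.join(base_dir, name)
        os.makedirs(path, exist_ok=True)  # no error if the folder already exists
        paths[name] = path
    return paths


# Example call (placeholder base directory):
# print(make_nnunet_dirs("/data/umamba"))
```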

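The original dataset goes into the raw folder, laid out as follows:
[Figure: dataset folders under nnUNet_raw]
Each dataset folder then contains:
[Figure: contents of a single dataset folder]
The BraTS 2019 dataset I am using does not come in this layout, so it has to be converted.

First download the conversion script Dataset043_BraTS19.py from the official nnUNet source and place it in the local dataset_conversion directory.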

The code shows that the converted files are saved under nnUNet_raw:

out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")

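Part of the code needs to be changed: the files in this dataset end in .nii.gz, so every .nii in the source file names must be changed to .nii.gz: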

print("copying hggs")
for c in tqdm(case_ids_hgg):
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz'))
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz'))
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz'))
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz'))

    copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii.gz"),
                                                         join(labelstr, c + '.nii.gz'))
print("copying lggs")
for c in tqdm(case_ids_lgg):
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz'))
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz'))
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz'))
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz'))

    copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii.gz"),
                                                         join(labelstr, c + '.nii.gz'))
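
After the conversion has run, a quick sanity check is to confirm that every label file has its four channel images. The sketch below assumes the imagesTr/labelsTr folders and the `_0000` to `_0003` suffixes used in the code above; the function name `find_incomplete_cases` is my own.

```python
from pathlib import Path

NUM_CHANNELS = 4  # t1, t1ce, t2, flair -> _0000 .. _0003


def find_incomplete_cases(dataset_dir, num_channels=NUM_CHANNELS):
    """Return {case_id: [missing image files]} for a converted dataset folder."""
    images = Path(dataset_dir) / "imagesTr"
    labels = Path(dataset_dir) / "labelsTr"
    problems = {}
    for label_file in sorted(labels.glob("*.nii.gz")):
        case_id = label_file.name[: -len(".nii.gz")]
        missing = [
            f"{case_id}_{channel:04d}.nii.gz"
            for channel in range(num_channels)
            if not (images / f"{case_id}_{channel:04d}.nii.gz").is_file()
        ]
        if missing:
            problems[case_id] = missing
    return problems


# Example call (placeholder path to the converted dataset folder):
# print(find_incomplete_cases("/data/umamba/nnUNet_raw/Dataset043_BraTS2019"))
```

An empty dictionary means every labelled case has all four modalities in place.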

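Then run nnUNetv2_plan_and_preprocess -d 43 --verify_dataset_integrity. If the output stays stuck like this for a long time:

[Figure: console output stalled during preprocessing]
change npfp in the plan_and_preprocess_entry function of U-Mamba/umamba/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py. The default was 8; after changing it to 2, the command produces output normally.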

parser.add_argument('-npfp', type=int, default=2, required=False,
                    help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 2')
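
Since -npfp is an ordinary argparse option, passing `-npfp 2` on the command line should have the same effect as editing the default, without touching the source. The sketch below uses a stand-in parser, not the nnUNet entry point itself, to show how the override behaves; `build_parser` is a name I made up for the illustration.

```python
import argparse


def build_parser(default_npfp=8):
    """Stand-in parser with a single -npfp option, mirroring the line above."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-npfp', type=int, default=default_npfp, required=False,
        # %(default)s keeps the help text in sync with the actual default.
        help='[OPTIONAL] Number of processes used for fingerprint extraction. '
             'Default: %(default)s')
    return parser


# Without the flag, the default is used.
print(build_parser().parse_args([]).npfp)
# With the flag, the command-line value wins over the default.
print(build_parser().parse_args(['-npfp', '2']).npfp)
```

With this pattern the command would become nnUNetv2_plan_and_preprocess -d 43 --verify_dataset_integrity -npfp 2.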

The final output files are:
[Figure: contents of the preprocessed output folder]
Here nnUNetPlans.json is the configuration file.

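Running

Run training with the following command, where 43 is the dataset ID, 2d is the configuration, all is the fold setting (train on all training cases instead of a single cross-validation fold), and -tr selects the trainer class: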

nnUNetv2_train 43 2d all -tr nnUNetTrainerUMambaEnc

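A successful run looks like this:
[Figures: console output of a successful training run]

  • The main training code is in the run_training function of U-Mamba/umamba/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py, and the gradient descent step is in the train_step function.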

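Miscellaneous

Dataset loading is in nnunetv2/training/dataloading/nnunet_dataset.py. There, properties_file (a .pkl file) presumably holds the segmentation region information, and seg.npy is the original segmentation file.

Code that converts the labels to regions (training is ultimately done on region labels)

In nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py: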

tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
                                                                       if ignore_label is not None else regions,
                                                                       'target', 'target'))

Before that, the seg key is also renamed to target:

tr_transforms.append(RenameTransform('seg', 'target', True))
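
To make the region conversion above more concrete, here is a pure-Python sketch of the idea: each region is a set of label values, and the label map is turned into one binary mask per region. The region definitions are my assumption of the usual BraTS setup (whole tumor, tumor core, enhancing tumor) and are not taken from this dataset's dataset.json. The three resulting channels do fit the 3 output channels of the seg_layers in the printouts further down. The function name `to_region_channels` is made up.

```python
# Assumed BraTS regions after label conversion; check your dataset.json.
REGIONS = {
    "whole tumor": (1, 2, 3),
    "tumor core": (2, 3),
    "enhancing tumor": (3,),
}


def to_region_channels(seg, regions=REGIONS):
    """Turn a 2D label map (list of lists) into one binary mask per region."""
    return [
        [[1 if value in labels else 0 for value in row] for row in seg]
        for labels in regions.values()
    ]


# A tiny 2x2 label map containing every label once.
seg = [[0, 1],
       [2, 3]]
for name, mask in zip(REGIONS, to_region_channels(seg)):
    print(name, mask)
```

Note that the regions overlap: a pixel with label 3 is set in all three masks, which is why the network is trained with one output channel per region rather than one class per pixel.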

2D network structure

If you do not want to train with nnUNetv2, the network structures printed below may be helpful.
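
When reading the printouts, the sizes of the Linear and Conv1d layers inside each Mamba block follow from the block's input dimension. The sketch below reproduces them under the assumption that the blocks use d_state=16, d_conv=4, expand=2, and dt_rank=ceil(d_model/16). These hyperparameters are inferred so that the numbers match the printouts, not read from the training code, and `mamba_shapes` is a name I chose.

```python
import math


def mamba_shapes(d_model, d_state=16, d_conv=4, expand=2):
    """Return the layer sizes of a Mamba block for a given input dimension.

    The default hyperparameters are assumptions chosen to match the printouts.
    """
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    return {
        "in_proj": (d_model, 2 * d_inner),          # Linear(in, out)
        "conv1d": (d_inner, d_inner, d_conv),       # depthwise Conv1d(in, out, kernel)
        "x_proj": (d_inner, dt_rank + 2 * d_state),  # Linear(in, out)
        "dt_proj": (dt_rank, d_inner),              # Linear(in, out)
        "out_proj": (d_inner, d_model),             # Linear(in, out)
    }


for dim in (512, 64):
    print(dim, mamba_shapes(dim))
```

For d_model=512 this gives in_proj 512 to 2048 and x_proj 1024 to 64, and for d_model=64 it gives x_proj 128 to 36, which are the values that appear in the two model printouts.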

UMambaBot model structure

UMambaBot: UMambaBot(
  (encoder): UNetResEncoder(
    (stem): Sequential(
      (0): BasicResBlock(
        (conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (act1): LeakyReLU(negative_slope=0.01, inplace=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (act2): LeakyReLU(negative_slope=0.01, inplace=True)
        (conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): BasicBlockD(
        (conv1): ConvDropoutNormReLU(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
          (all_modules): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (conv2): ConvDropoutNormReLU(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (all_modules): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          )
        )
        (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (stages): Sequential(
      (0): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (3): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
        )
      )
      (4): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
        )
      )
      (5): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
        )
      )
    )
  )
  (mamba_layer): MambaLayer(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (mamba): Mamba(
      (in_proj): Linear(in_features=512, out_features=2048, bias=False)
      (conv1d): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)
      (act): SiLU()
      (x_proj): Linear(in_features=1024, out_features=64, bias=False)
      (dt_proj): Linear(in_features=32, out_features=1024, bias=True)
      (out_proj): Linear(in_features=1024, out_features=512, bias=False)
    )
  )
  (decoder): UNetResDecoder(
    (encoder): UNetResEncoder(
      (stem): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (stages): Sequential(
        (0): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (1): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (2): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (3): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
          )
        )
        (4): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
          )
        )
        (5): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
          )
        )
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (3): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (4): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (upsample_layers): ModuleList(
      (0): UpsampleLayer(
        (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): UpsampleLayer(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (2): UpsampleLayer(
        (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (3): UpsampleLayer(
        (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (4): UpsampleLayer(
        (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (seg_layers): ModuleList(
      (0): Conv2d(512, 3, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (2): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
      (3): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
      (4): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

UMambaEnc model structure
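
The printout that follows starts with feature_map_sizes, do_channel_token, and three MambaLayer dims. The sketch below shows one way these numbers can arise. It rests on several assumptions of mine: a 192x160 input patch, a stride of 1 in the first stage and 2 in the others (as in the UMambaBot printout), channels of 32 to 512, a stage switching to channel tokens when its number of pixels does not exceed its channel count, and Mamba layers sitting at the odd stage indices. The function names are made up for the illustration.

```python
import math


def plan_tokens(input_size, strides, channels):
    """Compute per-stage feature map sizes and whether channels act as tokens."""
    sizes, channel_token = [], []
    current = list(input_size)
    for stride, n_channels in zip(strides, channels):
        current = [dim // s for dim, s in zip(current, stride)]
        sizes.append(current)
        # Assumed rule: use channel tokens once the map has few pixels.
        channel_token.append(math.prod(current) <= n_channels)
    return sizes, channel_token


def mamba_dims(sizes, channel_token, channels):
    """Dimension of each Mamba layer, assuming one at every odd stage index."""
    dims = []
    for stage in range(1, len(sizes), 2):
        if channel_token[stage]:
            dims.append(math.prod(sizes[stage]))  # tokens are channels, dim is H*W
        else:
            dims.append(channels[stage])          # tokens are pixels, dim is C
    return dims


channels = [32, 64, 128, 256, 512, 512]
strides = [(1, 1)] + [(2, 2)] * 5
sizes, channel_token = plan_tokens((192, 160), strides, channels)
print(sizes)
print(channel_token)
print(mamba_dims(sizes, channel_token, channels))
```

Under these assumptions the last Mamba layer gets dim 30 because the 6x5 feature map is flattened into the feature dimension while the 512 channels become the token sequence.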

feature_map_sizes: [[192, 160], [96, 80], [48, 40], [24, 20], [12, 10], [6, 5]]
do_channel_token: [False, False, False, False, True, True]
MambaLayer: dim: 64
MambaLayer: dim: 256
MambaLayer: dim: 30
UMambaEnc: UMambaEnc(
  (encoder): ResidualMambaEncoder(
    (stem): Sequential(
      (0): BasicResBlock(
        (conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (act1): LeakyReLU(negative_slope=0.01, inplace=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (act2): LeakyReLU(negative_slope=0.01, inplace=True)
        (conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): BasicBlockD(
        (conv1): ConvDropoutNormReLU(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
          (all_modules): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (conv2): ConvDropoutNormReLU(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (all_modules): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          )
        )
        (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (mamba_layers): ModuleList(
      (0): Identity()
      (1): MambaLayer(
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mamba): Mamba(
          (in_proj): Linear(in_features=64, out_features=256, bias=False)
          (conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
          (act): SiLU()
          (x_proj): Linear(in_features=128, out_features=36, bias=False)
          (dt_proj): Linear(in_features=4, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=64, bias=False)
        )
      )
      (2): Identity()
      (3): MambaLayer(
        (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mamba): Mamba(
          (in_proj): Linear(in_features=256, out_features=1024, bias=False)
          (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
          (act): SiLU()
          (x_proj): Linear(in_features=512, out_features=48, bias=False)
          (dt_proj): Linear(in_features=16, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=256, bias=False)
        )
      )
      (4): Identity()
      (5): MambaLayer(
        (norm): LayerNorm((30,), eps=1e-05, elementwise_affine=True)
        (mamba): Mamba(
          (in_proj): Linear(in_features=30, out_features=120, bias=False)
          (conv1d): Conv1d(60, 60, kernel_size=(4,), stride=(1,), padding=(3,), groups=60)
          (act): SiLU()
          (x_proj): Linear(in_features=60, out_features=34, bias=False)
          (dt_proj): Linear(in_features=2, out_features=60, bias=True)
          (out_proj): Linear(in_features=60, out_features=30, bias=False)
        )
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (3): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
        )
      )
      (4): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
        )
      )
      (5): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
        )
      )
    )
  )
  (decoder): UNetResDecoder(
    (encoder): ResidualMambaEncoder(
      (stem): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (mamba_layers): ModuleList(
        (0): Identity()
        (1): MambaLayer(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (mamba): Mamba(
            (in_proj): Linear(in_features=64, out_features=256, bias=False)
            (conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
            (act): SiLU()
            (x_proj): Linear(in_features=128, out_features=36, bias=False)
            (dt_proj): Linear(in_features=4, out_features=128, bias=True)
            (out_proj): Linear(in_features=128, out_features=64, bias=False)
          )
        )
        (2): Identity()
        (3): MambaLayer(
          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (mamba): Mamba(
            (in_proj): Linear(in_features=256, out_features=1024, bias=False)
            (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
            (act): SiLU()
            (x_proj): Linear(in_features=512, out_features=48, bias=False)
            (dt_proj): Linear(in_features=16, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=256, bias=False)
          )
        )
        (4): Identity()
        (5): MambaLayer(
          (norm): LayerNorm((30,), eps=1e-05, elementwise_affine=True)
          (mamba): Mamba(
            (in_proj): Linear(in_features=30, out_features=120, bias=False)
            (conv1d): Conv1d(60, 60, kernel_size=(4,), stride=(1,), padding=(3,), groups=60)
            (act): SiLU()
            (x_proj): Linear(in_features=60, out_features=34, bias=False)
            (dt_proj): Linear(in_features=2, out_features=60, bias=True)
            (out_proj): Linear(in_features=60, out_features=30, bias=False)
          )
        )
      )
      (stages): ModuleList(
        (0): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (1): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (2): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (3): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
          )
        )
        (4): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
          )
        )
        (5): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
          )
        )
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (3): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (4): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (upsample_layers): ModuleList(
      (0): UpsampleLayer(
        (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): UpsampleLayer(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (2): UpsampleLayer(
        (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (3): UpsampleLayer(
        (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (4): UpsampleLayer(
        (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (seg_layers): ModuleList(
      (0): Conv2d(512, 3, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (2): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
      (3): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
      (4): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)
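The numbers in the UMambaEnc printout can be checked by hand. The sketch below is not the repository's code: it only reproduces the printed `feature_map_sizes`, `do_channel_token` and `MambaLayer: dim` values under assumptions read off the printout. Those assumptions are a 192x160 patch, strides of 1, 2, 2, 2, 2, 2, channels of 32, 64, 128, 256, 512, 512, a Mamba layer on every odd-indexed stage, and a switch to channel tokens once the flattened feature map has no more positions than the stage has channels. The helper names `stage_layout` and `mamba_linear_shapes` are hypothetical. The second helper uses the `mamba-ssm` defaults (`d_state=16`, `expand=2`, `dt_rank = ceil(d_model / 16)`) to derive the `Linear` shapes shown inside each `Mamba` block.

```python
import math


def stage_layout(patch_size, strides, features_per_stage):
    """Return (feature_map_sizes, do_channel_token, mamba_dims) for a 2D encoder."""
    sizes, channel_token = [], []
    current = list(patch_size)
    for stride, channels in zip(strides, features_per_stage):
        current = [dim // stride for dim in current]
        sizes.append(current)
        # Assumed rule: use channel tokens once the flattened map has
        # no more positions than the stage has channels.
        channel_token.append(math.prod(current) <= channels)
    # Assumed placement: one MambaLayer on every odd-indexed stage.
    # With channel tokens the layer width is the number of spatial positions.
    mamba_dims = [
        math.prod(sizes[s]) if channel_token[s] else features_per_stage[s]
        for s in range(len(sizes))
        if s % 2 == 1
    ]
    return sizes, channel_token, mamba_dims


def mamba_linear_shapes(d_model, d_state=16, expand=2):
    """(in_features, out_features) of the Linear layers inside one Mamba block."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)  # mamba-ssm default for dt_rank="auto"
    return {
        "in_proj": (d_model, 2 * d_inner),
        "x_proj": (d_inner, dt_rank + 2 * d_state),
        "dt_proj": (dt_rank, d_inner),
        "out_proj": (d_inner, d_model),
    }


sizes, channel_token, mamba_dims = stage_layout(
    patch_size=(192, 160),
    strides=[1, 2, 2, 2, 2, 2],
    features_per_stage=[32, 64, 128, 256, 512, 512],
)
print("feature_map_sizes:", sizes)
print("do_channel_token:", channel_token)
for dim in mamba_dims:
    print("MambaLayer: dim:", dim, mamba_linear_shapes(dim))
```

Under these assumptions the last Mamba layer has width 30 because the deepest feature map is 6x5 = 30 positions, and 30 is not more than that stage's 512 channels, so the layer treats each channel as a token of length 30.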