目录

SwinTransformer之CV模型详解

第一代CV大模型:Vision Transformer

第二代CV大模型:Swin Transformer

两代模型PK(VIT和Swin Transformer)

Swin Transformer是什么CV模型?

Swin Transformer应用场景是什么?

Swin Transformer到底解决了什么问题?

Swin Transformer网络架构

Patch Embbeding介绍

window_partition介绍

W-MSA(Window Multi-head Self Attention)

Window_reverse

SW-MSA(Shifted Window Multi-head Self Attention)

模型参数

核心代码讲解


SwinTransformer视觉大模型详解

第一代CV大模型:Vision Transformer

温馨提示:如果您不了解Transformer的黑科技,请补一下原理:AI大模型的知识科普(深入浅出讲原理)-CSDN博客

一、Vision Transformer如何工作?

    
    Transformer模型最开始是用于自然语言处理(NLP)领域的,NLP主要处理的是文本、句子、段落等,即序列数据。但是视觉领域处理的是图像数据,因此将Transformer模型应用到CV领域(图像数据处理)上面临着诸多挑战,分析如下:

1. 与单词、句子、段落等文本数据不同,图像中包含更多的信息,并且是以像素值的形式呈现。


2. 如果按照处理文本的方式来处理图像,即逐像素处理的话,即使是目前的硬件条件也很难。


3. Transformer缺少CNN的归纳偏差,比如平移不变性和局部受限感受野。


4. CNN是通过相似的卷积操作来提取特征,随着模型层数的加深,感受野也会逐步增加。但是由于Transformer的本质,其在计算量上会比CNN更大。


5. Transformer无法直接用于处理基于网格的数据,比如图像数据。

    总结一下,Transformer与卷积神经网络(CNN)有许多不同之处,其主要优势包括:

1. 更好的处理序列数据能力:Transformer架构在序列数据建模方面表现非常出色,它通过自注意力机制对序列中的不同位置进行加权处理,从而实现了更好的序列建模能力。相比之下,CNN对于序列建模的能力较弱,主要用于图像等非序列数据的处理。

2. 并行计算能力:Transformer中的自注意力机制允许每个时间步进行并行计算,因此Transformer的训练速度相对于CNN要更快。相比之下,CNN需要在每个时间步上执行串行卷积操作,这使得CNN在处理较长的序列时计算效率较低。

3. 更好的处理长距离依赖关系的能力:Transformer中的自注意力机制允许模型从序列中任意位置获取信息,这使得Transformer能够更好地处理长距离依赖关系,而CNN则需要通过增加卷积层数来处理这种长距离依赖。

4. 更容易扩展到其他任务:由于Transformer在序列建模方面表现优异,它在许多NLP任务中表现出色,如机器翻译、语言模型等。相比之下,CNN主要用于计算机视觉领域,如图像分类、目标检测等。因此,Transformer更容易扩展到处理其他NLP任务,而CNN则需要进行更多的改进才能适用于NLP任务。

    CNN的在处理VC大模型遇到了困境的原因分析:卷积进行中,越来越多的网络结构,必须堆叠多层卷积,逐层对特征图进行处理中,感受野才不断增大,慢慢才有了全局的信息提取;从小规模数据开始,进行模型训练。
    Transfomer网络处理VC的大模型优势表现突出,是因为从第一层开始,就全局计算序列中各个向量的关联权重。但是需要足够多的数据,全局学习需要非常大量的数据才能表现卓越,这是所有论文中模型的测试效果好的前提条件。预训练模型开始,对其微调就可以适合个性化场景。

    总之,Transformer和CNN在不同的任务中表现出色,但在处理序列数据方面,Transformer具有更好的建模能力和计算效率,可以处理更长的序列,更容易扩展到其他NLP任务。


二、Vision Transformer是第一代CV大模型

    为了解决上述问题,Google的研究团队提出了ViT模型。ViT是谷歌提出的把Transformer应用到图像分类的模型,虽然不是第一篇将transformer应用在视觉任务的论文,但是因为其模型“简单”且效果好,可扩展性强(模型越大效果越好),成为了transformer在CV领域应用的里程碑著作。

    ViT原论文中最核心的结论是,当拥有足够多的数据进行预训练的时候,ViT的表现就会超过CNN,突破transformer缺少归纳偏置的限制,可以在下游任务中获得较好的迁移效果。但是当训练数据集不够大的时候,ViT的表现通常比同等大小的ResNets要差一些,因为Transformer和CNN相比缺少归纳偏置(inductive bias),即一种先验知识,提前做好的假设。CNN具有两种归纳偏置一种是局部性,即图片上相邻区域具有相似的特征,一种是平移不变性,CNN具有上面两种归纳偏置,就有了很多先验信息,需要相对少的数据就可以学习到一个比较好的模型

    对比CNN,ViT表现出更强的性能,这是由于以下几个原因:

1. 全局视野和长距离依赖:ViT引入了Transform模型的注意力机制,可以对整个图像的全局信息进行建模。相比之下,CNN在处理图像时使用局部感受野,只能捕捉图像的局部特征。ViT通过自注意力层可以建立全局关系,并学习图像中不同区域之间的长距离依赖关系,从而更好地理解图像的结构和语义。
2. 可学习的位置编码:ViT通过对输入图像块进行位置编码,将位置信息引入模型中。这使得ViT可以处理不同位置的图像块,并学习它们之间的位置关系。相比之下,CNN在卷积和池化过程中会导致空间信息的丢失,对位置不敏感。

3. 数据效率和泛化能力:
ViT在大规模数据集上展现出出色的泛化能力。由于ViT基于Transform模型,它可以从大量的数据中学习到更丰富、更复杂的图像特征表示。相比之下,CNN在小样本数据集上可能需要更多的数据和调优才能取得好的结果。


4. 可解释性和可调节性:
ViT的自注意机制使其在解释模型预测和注意力权重时具有优势。相比之下,CNN的特征表示通常较难解释,因为它们是通过卷积和池化操作获得的。

三、ViT模型架构


我们先结合下面的动图来粗略地分析一下ViT的工作流程,如下:

1. 将一张图片分成patches;
2. 将patches铺平;
3. 将铺平后的patches的线性映射到更低维的空间;
4. 添加位置embedding编码信息;
5. 将图像序列数据送入标准Transformer encoder中去;
6. 在较大的数据集上预训练;
7. 在下游数据集上微调用于图像分类;

四、ViT模型简洁代码架构
## from https://github.com/lucidrains/vit-pytorch
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

 # einops张量操作神器
# helpers


def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.1):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x): ## 最重要的都是forword函数了
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        ## 对tensor张量分块 x :1 197 1024   qkv 最后 是一个元组,tuple,长度是3,每个元素形状:1 197 1024
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # 分成多少个Head,与TRM生成qkv 的方式不同, 要更简单,不需要区分来自Encoder还是Decoder

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
# 1. VIT整体架构从这里开始
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        # 初始化函数内,是将输入的图片,得到 img_size ,patch_size 的宽和高
        image_height, image_width = pair(image_size) ## 224*224 *3
        patch_height, patch_width = pair(patch_size)## 16 * 16  *3
        #图像尺寸必须能被patch大小整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width) ## 步骤1.一个图像 分成 N 个patch
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),# 步骤2.1将patch 铺开
            nn.Linear(patch_dim, dim), # 步骤2.2 然后映射到指定的embedding的维度
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)  ## img 1 3 224 224  输出形状x : 1 196 1024
        b, n, _ = x.shape ## 
        #将cls 复制 batch_size 份
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 将cls token在维度1 扩展到输入上
        x = torch.cat((cls_tokens, x), dim=1)
        # 添加位置编码
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        # 输入TRM
        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)



v = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

preds = v(img)   # (1, 1000)



第二代CV大模型:Swin Transformer

Swin Transformer是2021年微软研究院发表在ICCV上的一篇best paper。该论文已在多项视觉任务中霸榜(分类、检测、分割)。

《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》
论文地址:https://arxiv.org/pdf/2103.14030.pdf

两代模型PK(VIT和Swin Transformer)


1. 图像分块方式不同

VIT模型将图像分成固定大小的小块,每个小块都被视为一个“图像片段”,并通过Transformer编码器进行处理。而Swin Transformer模型采用了一种新的分块方式,称为“局部窗口注意力”,它将图像分成一系列大小相同的局部块

2. Transformer编码器的层数不同

VIT模型中使用的Transformer编码器层数较少,通常只有12层。而Swin Transformer模型中使用了更多的Transformer编码器层,通常为24层或48层。

3. 模型的参数量不同

由于Swin Transformer模型采用了更多的Transformer编码器层,因此其参数量比VIT模型更大。例如,Swin Transformer模型中的最大模型参数量可以达到1.5亿,而VIT模型中的最大模型参数量只有1.2亿。

4. 模型的性能不同

在ImageNet数据集上进行的实验表明,Swin Transformer模型的性能优于VIT模型。例如,在ImageNet-1K上,Swin Transformer模型的Top-1准确率为87.4%,而VIT模型的Top-1准确率为85.8%。

最后总结,二者的不同之处:

首先,Swin-Transformer所构建的特征图是具有层次性的,很像我们之前将的卷积神经网络那样,随着特征提取层的不断加深,特征图的尺寸是越来越小的(4x、8x、16x下采样)。正因为Swin Transformer拥有像CNN这样的下采样特性,能够构建出具有层次性的特征图。在论文中作者提到,这样的好处就是:正是因为这样具有层次的特征图,Swin Transformer对于目标检测和分割任务相比ViT有更大的优势。
在ViT模型中,是直接对特征图下采样16倍,在后面的结构中也一致保持这样的下采样规律不变(只有16x下采样,不Swin Transformer那样有多种下采样尺度 -> 这样就导致ViT不能构建出具有层次性的特征图

其次,在Swin Transformer的特征图中,它是用一个个窗口的形式将特征图分割开的。窗口与窗口之间是没有重叠的。而在ViT中,特征图是是一个整体,并没有对其进行分割。其中的窗口(Window)就是我们一会儿要讲的Windows Multi-head Self-attention。引入该结构之后,Swin Transformer就可以在每个Window的内部进行Multi-head Self-Attention的计算。Window与Window之间是不进行信息的传递的。这样做的好处是:可以大大降低运算量,尤其是在浅层网络,下采样倍率比较低的时候,相比ViT直接针对整张特征图进行Multi-head Self-Attention而言,能够减少计算量。

Swin Transformer是什么CV模型?

Swin Transformer是一种为视觉领域设计的分层Transformer结构。它的两大特性是滑动窗口和分层表示。滑动窗口在局部不重叠的窗口中计算自注意力,并允许跨窗口连接。分层结构允许模型适配不同尺度的图片,并且计算复杂度与图像大小呈线性关系。Swin Transformer借鉴了CNN的分层结构,不仅能够做分类,还能够和CNN一样扩展到下游任务,用于计算机视觉任务的通用主干网络,可以用于图像分类、图像分割、目标检测等一系列视觉下游任务。

Swin Transformer应用场景是什么?

Swin-Transformer是一种通过不重叠的和重叠的滑窗操作实现在一个窗口中注意力机制计算的Transformer模型。它作为计算机视觉的通用骨干网络Backbone在物体分类、目标检测、语义和实例分割和目标跟踪等任务中取得很好的性能和效果,所以Swin-Transformer大有取代CNN的趋势。不仅源码公开了,预训练模型也公开了,预训练模型提供大中小三个版本。

Swin-Transformer以及swin-transformer-ocr的工程源码地址分别为https://github.com/microsoft/Swin-Transformer.githttps://github.com/YongWookHa/swin-transformer-ocr.git

Swin Transformer到底解决了什么问题?

1. 超高分辨率的图像所带来的计算量问题,怎么办?

答:参考卷积网络的工作方式,获得全局注意力能力的同时,又将计算量从图像大小的平方关系降为线性关系,大大地减少了运算量,串联窗口自注意力运算(W-MSA)以及滑动窗口自注意力运算(SW-MSA)。

2. 最初的Vision Transformer是不具备多尺度预测,怎么办?

答:通过特征融合的方式PatchMerging(可参考卷积网络里的池化操作),每次特征抽取之后都进行一次下采样,增加了下一次窗口注意力运算在原始图像上的感受野,从而对输入图像进行了多尺度的特征提取。

3. 核心技术是什么?

SwinTransformer 针对ViT使用了“窗口”和“分层”的方式来替代长序列进行改进。

Swin Transformer网络架构

  1. 输入:首先输入还是一张图像数据,224(宽) ∗ 224(高) ∗ 3(通道) 
  2. 处理过程:通过卷积得到多个特征图,把特征图分成每个Patch,堆叠Swin Transformer Block,与Swin TransformerBlock在每次堆叠后长宽减半,特征图个数翻倍。
  3. Block含义:最核心的部分是对Attention的计算方法做出了改进,每个Block包括了一个W-MSA和一个SW-MSA,成对组合才能串联成一个Block。W-MSA是基于窗口的注意力计算。SW-MSA是窗口滑动后重新计算注意力。
Patch Embbeding介绍
  1. 输入:图像数据(224,224,3)
  2. 输出:(3136,96)相当于序列长度是3136个,每个的向量是96维特征
  3. 处理过程:通过卷积得到,Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4)),3136也就是 (224/4) * (224/4)得到的,也可以根据需求更改卷积参数
  4. 实际上就是一个下采样的操作,是不同于池化,这个相当于间接的对H和W维度进行间隔采样后拼接在一起,得到H/2,W/2,C*4。

window_partition介绍
  1. 输入:特征图(56,56,96)
  2. 默认窗口大小为7,所以总共可以分成8*8个窗口
  3. 输出:特征图(64,7,7,96)
  4. 处理过程:之前的单位是序列,现在的单位是窗口(共64个窗口),56=224/4,5656分成每个都是7*7大小的窗口,一共可以的得到8*8的窗口,因此输出为(64,7,7,96),因此输入变成了64个窗口不再是序列了。
W-MSAWindow Multi-head Self Attention
  1. 对得到的窗口,计算各个窗口自己的自注意力得分。
  2. qkv三个矩阵放在一起了:(3,64,3,49,32),3个矩阵,64个窗口,heads为3,窗口大小7*7=49,每个head特征96/3=32。
  3. attention结果为:(64,3,49,49) 每个头都会得出每个窗口内的自注意力
  4. 原来有64个窗口,每个窗口都是7*7的大小,对每个窗口都进行Self Attention的计算(3,64,3,49,32),第一个3表示的是QKV这3个,64代表64个窗口,第二个3表示的是多头注意力的头数,49就是77的大小,每头注意力机制对应32维的向量。
  5. attention权重矩阵维度(64,3,49,49),64表示64个窗口,3还是表示的是多头注意力的头数,49*49表示每一个窗口的49个特征之间的关系

Window_reverse
  1. 通过得到的attention计算得到新的特征(64,49,96),总共64个窗口,每个窗口7*7的大小,每个点对应96维向量。
  2. window_reverse就是通过reshape操作还原回去(56,56,96),还原的目的是为了循环,得到了跟输入特征图一样的大小,但是其已经计算过了attention,attention权重与(3,64,3,49,32)乘积结果为(64,49,96),这是新的特征的维度,96还是表示每个向量的维度,这个时候的特征已经经过重构,96表示了在一个窗口的每个像素与每个像素之间的关系。
SW-MSAShifted Window Multi-head Self Attention

原因分析:为什么要shift?原来的window都是算自己内部的,这样就会导致只有内部计算,没有它们之间的关系,容易上模型局限在自己的小领地,可以通过shift操作来改善

通过W-MSA我们得到的是每个窗口内的特征,还没有每个窗口与窗口之间的特征,SW-MSA就是用来得到每个窗口与窗口之间的特征。窗口与窗口之间的特征,是用一种滑动shift 的方式计算。

处理过程:实际上SW-MSA的偏移就是窗口在水平和垂直方向上分别偏移一定数量的像素,不管是SW-MSA还是W-MSA,实际上都是在做self-Attention的计算,只不过W-MSA是只对一个窗口内部做self-Attention的计算,SW-MSA是使用了一种偏移的方式,但是还是对一个窗口内部做self-Attention的计算。

实际上就是像素点发生了挪动

如图所示,红色线是窗口的分割,灰色是patch的分割,W-MSA将相邻的patch进行拼凑成窗口,但是这就导致了,窗口之间没有办法连接,SW-MSA的偏移计算会重新划分窗口,但是窗口不可以重叠的情况下,窗口由4个变成了9个。窗口的数量和大小都发生了变化,如图所示原文给出了一个办法,将窗口的大小做出了限制。

论文中使用了pad和mask的方法解决了这一问题,如上图中cyclic shift部分,对边缘部分尺寸较小的windows进行了填充(图中蓝色、绿色和黄色部分),使得每个windows都能够保持原来的大小,并且论文还采用了mask的方法来使得模型只在除了pad的部分做self-attention计算,这样一来就能够解决上面所提到的问题。

如图所示,4自始至终都没有改变,原来在W-MSA使用self-Attention进行计算,在SW-MSA还是使用self-Attention进行计算,但是比如1和7发生了变化,7和1的计算,假如了mask和padding的一些处理。一开始是4个窗口,经过偏移后变成了9个,但是计算不方便,还是按照4个窗口进行计算,多出来的值mask掉就行了。

所以一个Swin Transformer Block就是先后经过W-MSA和SW-MSA,而Swin Transformer主要就是Swin Transformer Block的堆叠。

模型参数

以下展示了Swin Transformer的模型参数,分为四中不同规模:Tiny、Small、Base、Larger。如Swin-T:concat为Patch Partition和Patch Merging操作,4×4表明高和宽变为原来的1/4,96-d表示输出通道为96维。下面×2表示堆叠两个Swin Transformer Block,窗口大小维7×7,输出通道维度为96,多头注意力机制的头数为3,其他的都类似。需要注意的是,在堆叠Swin Transformer Block时,含SW-MSA的块和含W-MSA的块是成对进行的,因此每一个stage的堆叠数都是偶数。(即就是第一块是W-MSA的Block时,则下一个块必须为SW-MSA)

核心代码讲解

 1. Patch Partition代码模块

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    split image into non-overlapping patches   即将图片划分成一个个没有重叠的patch
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],   # 表示宽度方向右侧填充数
                          0, self.patch_size[0] - H % self.patch_size[0],   # 表示高度方向底部填充数
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

2. Patch Merging代码模块

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
        步长为2,间隔采样
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C    即输入x的通道排列顺序
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # 以2为间隔进行采样
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  #  ————————>  [B, H/2, W/2, 4*C]   在channael维度上进行拼接
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # 将img_mask划分成一个一个窗口
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]           # 输出的是按照指定的window_size划分成一个一个窗口的数据
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]  使用了广播机制
        # [nW, Mh*Mw, Mh*Mw]
        # 因为需要求得的是自身注意力机制,所以,所以相同的区域使用0表示,;不同的区域不等于0,填入-100,这样,在求得
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))   # 即对于不等于0的位置,赋值为-100;否则为0
        return attn_mask

3. mask掩码生成和stage堆叠的代码模块

  def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # 将img_mask划分成一个一个窗口
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]           # 输出的是按照指定的window_size划分成一个一个窗口的数据
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]  使用了广播机制
        # [nW, Mh*Mw, Mh*Mw]
        # 因为需要求得的是自身注意力机制,所以,所以相同的区域使用0表示,;不同的区域不等于0,填入-100,这样,在求得
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))   # 即对于不等于0的位置,赋值为-100;否则为0
        return attn_mask

4.stage堆叠部分代码:

class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2  # 表示向右和向下偏移的窗口大小   即窗口大小除以2,然后向下取整

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,   # 通过判断shift_size是否等于0,来决定是使用W-MSA与SW-MSA
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer    即:PatchMerging类
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # 将img_mask划分成一个一个窗口
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]           # 输出的是按照指定的window_size划分成一个一个窗口的数据
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]  使用了广播机制
        # [nW, Mh*Mw, Mh*Mw]
        # 因为需要求得的是自身注意力机制,所以,所以相同的区域使用0表示,;不同的区域不等于0,填入-100,这样,在求得
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))   # 即对于不等于0的位置,赋值为-100;否则为0
        return attn_mask

    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]   # 制作mask蒙版
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W


5.SW-MSA或者W-MSA模块代码:
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)    # 先经过层归一化处理

        # WindowAttention即为:SW-MSA或者W-MSA模块
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        # 把feature map给pad到window size的整数倍
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        # 判断是进行SW-MSA或者是W-MSA模块
        if self.shift_size > 0:
            # https://blog.csdn.net/ooooocj/article/details/126046858?ops_request_misc=&request_id=&biz_id=102&utm_term=torch.roll()%E7%94%A8%E6%B3%95&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-0-126046858.142^v73^control,201^v4^add_ask,239^v1^control&spm=1018.2226.3001.4187
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))    #进行数据移动操作
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        # 将窗口按照window_size的大小进行划分,得到一个个窗口
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        # 将数据进行展平操作
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        """
            # 进行多头自注意力机制操作
        """
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        # 将多窗口拼接回大的featureMap
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        # 将移位的数据进行还原
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        # 如果进行了padding操作,需要移出掉相应的pad
        if pad_r > 0 or pad_b > 0:
            # 把前面pad的数据移除掉
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

4. SW-MSA或者W-MSA模块代码

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)    # 先经过层归一化处理

        # WindowAttention即为:SW-MSA或者W-MSA模块
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        # 把feature map给pad到window size的整数倍
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        # 判断是进行SW-MSA或者是W-MSA模块
        if self.shift_size > 0:
            # https://blog.csdn.net/ooooocj/article/details/126046858?ops_request_misc=&request_id=&biz_id=102&utm_term=torch.roll()%E7%94%A8%E6%B3%95&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-0-126046858.142^v73^control,201^v4^add_ask,239^v1^control&spm=1018.2226.3001.4187
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))    #进行数据移动操作
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        # 将窗口按照window_size的大小进行划分,得到一个个窗口
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        # 将数据进行展平操作
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        """
            # 进行多头自注意力机制操作
        """
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        # 将多窗口拼接回大的featureMap
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        # 将移位的数据进行还原
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        # 如果进行了padding操作,需要移出掉相应的pad
        if pad_r > 0 or pad_b > 0:
            # 把前面pad的数据移除掉
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

5. 整体流程代码实现 

""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030

Code/weights from https://github.com/microsoft/Swin-Transformer

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional


def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)

"""
    将窗口按照window_size的大小进行划分,得到一个个窗口
"""
def window_partition(x, window_size: int):
    """
    将feature map按照window_size划分成一个个没有重叠的window
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)   # 输出的是按照指定的window_size划分成一个一个窗口的数据
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """
    将一个个window还原成一个feature map
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    split image into non-overlapping patches   即将图片划分成一个个没有重叠的patch
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],   # 表示宽度方向右侧填充数
                          0, self.patch_size[0] - H % self.patch_size[0],   # 表示高度方向底部填充数
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
        步长为2,间隔采样
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C    即输入x的通道排列顺序
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # 以2为间隔进行采样
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  #  ————————>  [B, H/2, W/2, 4*C]   在channael维度上进行拼接
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

"""
MLP模块
"""
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

"""
WindowAttention即为:SW-MSA或者W-MSA模块
"""
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        # 创建偏置bias项矩阵
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]    其元素的个数===>>[(2*Mh-1) * (2*Mw-1)]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])  # 如果此处的self.window_size[0]为2的话,则生成的coords_h为[0,1]
        coords_w = torch.arange(self.window_size[1])  # 同理得
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0  行标+(M-1)
        relative_coords[:, :, 1] += self.window_size[1] - 1     # 列表标+(M-1)
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)   # 将relative_position_index放入到模型的缓存当中

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)

        # 进行mask,相同区域使用0表示;不同区域使用-100表示
        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

"""
    SwinTransformerBlock
"""
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)    # 先经过层归一化处理

        # WindowAttention即为:SW-MSA或者W-MSA模块
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        # 把feature map给pad到window size的整数倍
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        # 判断是进行SW-MSA或者是W-MSA模块
        if self.shift_size > 0:
            # https://blog.csdn.net/ooooocj/article/details/126046858?ops_request_misc=&request_id=&biz_id=102&utm_term=torch.roll()%E7%94%A8%E6%B3%95&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-0-126046858.142^v73^control,201^v4^add_ask,239^v1^control&spm=1018.2226.3001.4187
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))    #进行数据移动操作
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        # 将窗口按照window_size的大小进行划分,得到一个个窗口
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        # 将数据进行展平操作
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        """
            # 进行多头自注意力机制操作
        """
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        # 将多窗口拼接回大的featureMap
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        # 将移位的数据进行还原
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        # 如果进行了padding操作,需要移出掉相应的pad
        if pad_r > 0 or pad_b > 0:
            # 把前面pad的数据移除掉
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2  # 表示向右和向下偏移的窗口大小   即窗口大小除以2,然后向下取整

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,   # 通过判断shift_size是否等于0,来决定是使用W-MSA与SW-MSA
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer    即:PatchMerging类
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # 将img_mask划分成一个一个窗口
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]           # 输出的是按照指定的window_size划分成一个一个窗口的数据
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]  使用了广播机制
        # [nW, Mh*Mw, Mh*Mw]
        # 因为需要求得的是自身注意力机制,所以,所以相同的区域使用0表示,;不同的区域不等于0,填入-100,这样,在求得
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))   # 即对于不等于0的位置,赋值为-100;否则为0
        return attn_mask

    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]   # 制作mask蒙版
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size. Default: 4   表示通过Patch Partition层后,下采样几倍
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, patch_size=4,  # 表示通过Patch Partition层后,下采样几倍
                 in_chans=3,           # 输入图像通道
                 num_classes=1000,     # 类别数
                 embed_dim=96,         # Patch partition层后的LinearEmbedding层映射后的维度,之后的几层都是该数的整数倍  分别是 C、2C、4C、8C
                 depths=(2, 2, 6, 2),  # 表示每一个Stage模块内,Swin Transformer Block重复的次数
                 num_heads=(3, 6, 12, 24),  # 表示每一个Stage模块内,Swin Transformer Block中采用的Multi-Head self-Attention的head的个数
                 window_size=7,         # 表示W-MSA与SW-MSA所采用的window的大小
                 mlp_ratio=4.,          # 表示MLP模块中,第一个全连接层增大的倍数
                 qkv_bias=True,
                 drop_rate=0.,          # 对应的PatchEmbed层后面的
                 attn_drop_rate=0.,     # 对应于Multi-Head self-Attention模块中对应的dropRate
                 drop_path_rate=0.1,    # 对应于每一个Swin-Transformer模块中采用的DropRate   其是慢慢的递增的,从0增长到drop_path_rate
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)  # depths:表示重复的Swin Transoformer Block模块的次数  表示每一个Stage模块内,Swin Transformer Block重复的次数
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # stage4输出特征矩阵的channels
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches   即将图片划分成一个个没有重叠的patch
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)   # PatchEmbed层后面的Dropout层

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # 注意这里构建的stage和论文图中有些差异
            # 这里的stage不包含该stage的patch_merging层,包含的是下个stage的
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),  # 传入特征矩阵的维度,即channel方向的深度
                                depth=depths[i_layer],              # 表示当前stage中需要堆叠的多少Swin Transformer Block
                                num_heads=num_heads[i_layer],       # 表示每一个Stage模块内,Swin Transformer Block中采用的Multi-Head self-Attention的head的个数
                                window_size=window_size,            # 表示W-MSA与SW-MSA所采用的window的大小
                                mlp_ratio=self.mlp_ratio,           # 表示MLP模块中,第一个全连接层增大的倍数
                                qkv_bias=qkv_bias,
                                drop=drop_rate,                     # 对应的PatchEmbed层后面的
                                attn_drop=attn_drop_rate,           # 对应于Multi-Head self-Attention模块中对应的dropRate
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],     # 对应于每一个Swin-Transformer模块中采用的DropRate   其是慢慢的递增的,从0增长到drop_path_rate
                                norm_layer=norm_layer,
                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,   # 判断是否是第四个,因为第四个Stage是没有PatchMerging层的
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)   # 自适应的全局平均池化
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: [B, L, C]
        x, H, W = self.patch_embed(x)  # 对图像下采样4倍
        x = self.pos_drop(x)

        # 依次传入各个stage中
        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)   # 经过全连接层,得到输出
        return x


def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐