概述

attention-unet主要的贡献就是提出了attention gate,它即插即用,可以直接集成到unet模型当中,作用在于抑制输入图像中的不相关区域,同时突出特定局部区域的显著特征,并且它用soft-attention 代替hard-attention,所以attention权重可以由网络学习,并且不需要额外的label,只增加少量的计算量。

细节

结构

核心还是unet的结构,但是在做skip-connection的时候,中间加了一个attention gate,经过这个ag之后,再进行concat操作。因为encoder中的细粒度信息相对多一点,但是很多是不需要的冗余的,ag相当于是对encoder的当前层进行了一个过滤,抑制图像中的无关信息,突出局部的重要特征。
在这里插入图片描述

attention gate

两个输入分别是encoder的当前层 x l x^l xl和decoder的下一层 g g g,他们经过1x1的卷积(将通道数变为一致之后),再做逐元素的相加,然后经过relu,1x1的卷积(将通道数降为1)和sigmoid得到注意力系数,然后再经过一个resample模块将尺寸还原回来,最后就可以使用注意力系数对特征图进行加权了。
注:这里是3D的,2D理解的话,直接去掉最后一个维度就好了。

在这里插入图片描述
一些解释
为什么要两个输入做加法而不是直接根据encoder的当前层得到注意力系数呢?
可能是因为,首先处理完成之后的两张相同尺寸和通道数的特征图,提取的特征是不同的。那么这么操作能够使相同的感兴趣区域的信号加强,同时各自不同的区域也能作为辅助,两份加起来辅助信息也会更多。或者说是对核心信息的进一步强调,同时又不忽视那些细节信息。
在这里插入图片描述
为什么需要resample呢?
因为 x l 与 g x^l与g xlg的尺寸是不相同的,显然 g g g的尺寸是 x l x^l xl的一半,他们不可能进行逐元素的相加,所以需要将两个尺寸变得一致,要么大的下采样要么小的上采样,实验出来是大的下采样效果好。但是这个操作之后得到的就是注意力系数了,要和 x l x^l xl做加权肯定要尺寸相同的,所以还得重新上采样。

attention

Attention函数的本质可以被描述为一个查询(query)到一系列(键key-值value)对的映射
在计算attention时主要分为三步:

  • 第一步是将query和每个key进行相似度计算得到权重,常用的相似度函数有点积,拼接,感知机等;
  • 第二步一般是使用一个softmax函数对这些权重进行归一化;
  • 最后将权重和相应的键值value进行加权求和得到最后的attention。

hard-attention:一次选择一个图像的一个区域作为注意力,设成1,其他设为0。他是不能微分的,无法进行标准的反向传播,因此需要蒙特卡洛采样来计算各个反向传播阶段的精度。 考虑到精度取决于采样的完成程度,因此需要其 他技术(例如强化学习)。

soft-attention:加权图像的每个像素。 高相关性区域乘以较大的权重,而低相关性区域标记为较小的权重。权重范围是(0-1)。他是可微的,可以正常进行反向传播。

简单实现

import paddle
import paddle.nn as nn


# 两次卷积操作
# 卷积计算公式:
# 输出大小 = (输入大小 − Filter + 2Padding )/Stride+1
class VGGBlock(nn.Layer):
    def __init__(self,in_channels,out_channels):
        super(VGGBlock, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2D(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2D(out_channels),
            nn.LeakyReLU(),

            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2D(out_channels),
            nn.LeakyReLU()
        )
    def forward(self,x):
        return self.layer(x)

# 先让encoder当前层通过attention_gate
# 然后再将decoder当前层上采样并且和encoder当前层 做concat
# 反卷积(转置卷积)计算公式:
# 输出大小 = (输入大小 − 1) * Stride + Filter - 2 * Padding
# 当前这种设置使得输入输出尺寸相同
class Up(nn.Layer):
    def __init__(self,in_channels,out_channels):
        super(Up, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2DTranspose(in_channels, out_channels, 4, 2, 1)
        )

    def forward(self,x1,x2,attention_gate):
        x1=self.layer(x1)
        x2=attention_gate(x1,x2)
        # 因为tensor是ncwh的 我们需要在c维度上concat 所以axis是1
        return paddle.concat([x2,x1],axis=1)

# 软注意力机制 相当于是一个筛子 对encoder的当前层进行筛选
class AttentionGate(nn.Layer):
    # 因为我们是将encoder的当前层 以及decoder的下一层上采样之后的结果送入Attention_Block的
    # 所以他们的尺寸以及通道数适相同的
    def __init__(self,in_channels,out_channels):
        super(AttentionGate, self).__init__()
        self.w_g=nn.Sequential(
            nn.Conv2D(in_channels,out_channels,1,1,0),
            nn.BatchNorm2D(out_channels)
        )
        self.w_x=nn.Sequential(
            nn.Conv2D(in_channels, out_channels, 1, 1, 0),
            nn.BatchNorm2D(out_channels)
        )
        self.relu=nn.LeakyReLU()
        self.psi=nn.Sequential(
            nn.Conv2D(out_channels, 1, 1, 1, 0),
            nn.BatchNorm2D(1),
            nn.Sigmoid()
        )


    def forward(self, g,x):
        g1=self.w_g(g)
        x1=self.w_x(x)
        psi=self.relu(g1+x1)
        psi=self.psi(psi)
        return x*psi

class AttentionUNet(nn.Layer):
    def __init__(self,num_classes=2):
        super(AttentionUNet, self).__init__()
        filters = [64, 128, 256, 512, 1024]
        self.pool = nn.MaxPool2D(2)
        ## -------------encoder-------------
        self.encoder_1 = VGGBlock(3, filters[0])
        self.encoder_2 = VGGBlock(filters[0], filters[1])
        self.encoder_3 = VGGBlock(filters[1], filters[2])
        self.encoder_4 = VGGBlock(filters[2], filters[3])
        self.encoder_5 = VGGBlock(filters[3], filters[4])

        ## -------------decoder-------------
        self.up_4 = Up(filters[4], filters[3])
        self.up_3 = Up(filters[3], filters[2])
        self.up_2 = Up(filters[2], filters[1])
        self.up_1 = Up(filters[1], filters[0])

        self.decoder_4 = VGGBlock(filters[4], filters[3])
        self.decoder_3 = VGGBlock(filters[3], filters[2])
        self.decoder_2 = VGGBlock(filters[2], filters[1])
        self.decoder_1 = VGGBlock(filters[1], filters[0])

        self.attention_gate4=AttentionGate(512,256)
        self.attention_gate3=AttentionGate(256,128)
        self.attention_gate2=AttentionGate(128,64)
        self.attention_gate1=AttentionGate(64,32)

        self.final = nn.Sequential(
            nn.Conv2D(64,num_classes,3,1,1),
        )
    def forward(self,x):
        ## -------------encoder-------------
        encoder_1 = self.encoder_1(x)
        encoder_2 = self.encoder_2(self.pool(encoder_1))
        encoder_3 = self.encoder_3(self.pool(encoder_2))
        encoder_4 = self.encoder_4(self.pool(encoder_3))
        encoder_5 = self.encoder_5(self.pool(encoder_4))

        ## -------------decoder-------------
        decoder_4 = self.up_4(encoder_5, encoder_4,self.attention_gate4)
        decoder_4 = self.decoder_4(decoder_4)

        decoder_3 = self.up_3(decoder_4, encoder_3,self.attention_gate3)
        decoder_3 = self.decoder_3(decoder_3)

        decoder_2 = self.up_2(decoder_3, encoder_2,self.attention_gate2)
        decoder_2 = self.decoder_2(decoder_2)

        decoder_1 = self.up_1(decoder_2, encoder_1,self.attention_gate1)
        decoder_1 = self.decoder_1(decoder_1)

        output = self.final(decoder_1)
        return output

if __name__ == '__main__':
    # x=paddle.randn(shape=[2,3,256,256])
    unet=AttentionUNet()
    # print(net(x).shape)
    paddle.summary(unet, (1,3,256,256))

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐