基于改进注意力机制的U-Net模型实现及应用(keras框架实现)
1.摘要上节我们基于U-Net模型设计并实现了在医学细胞分割上的应用(ISBI 挑战数据集),并给出了模型的详细代码解释,在上个博客中,我们为了快速训练U-Net模型对其进行了缩减,将庞大的U-Net的转换为很小&的结构,导致其准确率才达到75%左右。为了进一步提高U-Net模型在细胞分割上的准确率,本文将主要研究两个方面:一是基于U-Net的原始模型结构进行改进,引入卷积注意力机制模块(
1.摘要
上节我们基于U-Net模型设计并实现了在医学细胞分割上的应用(ISBI 挑战数据集),并给出了模型的详细代码解释,在上个博客中,我们为了快速训练U-Net模型对其进行了缩减,将庞大的U-Net的转换为很小&的结构,导致其准确率才达到75%左右。为了进一步提高U-Net模型在细胞分割上的准确率,本文将主要研究两个方面:一是基于U-Net的原始模型结构进行改进,引入卷积注意力机制模块(CBAM)和Focal Tversky损失函数;二是引入深监督方法(DEEP SUPERVISION)及多尺度输入作为U-Net模型的原始输入,该模型被命名为DAMU-Net。为了进一步验证该模型的性能,我们同样在ISBI 挑战数据集上进行实验,并给出相应的实验结果。
2.相关技术概述
2.1 Focal Tversky损失函数
医学影像中存在很多的数据不平衡现象,使用不平衡数据进行训练会导致严重偏向高精度但低召回率(sensitivity)的预测,这是我们不希望的,特别是在医学应用中,假阴性比假阳性多更难容忍。而Tversky广义损失函数可以有效解决了三维全卷积深神经网络训练中数据不平衡的问题,在精度和召回率之间找到更好的平衡。与Focal loss相似,Focal Tversky Loss着重于通过通过调整超参数α和β,我们可以控制假阳性和假阴性之间的权衡。较大的β会使召回的准确性高于精确度(通过更加强调假阴性)。其公式如下:
2.2 深监督方法
所谓深监督(Deep Supervision),就是在深度神经网络的某些中间隐藏层加了一个辅助的分类器作为一种网络分支来对主干网络进行监督的技巧,用来解决深度神经网络训练梯度消失和收敛速度过慢等问题。 深监督作为一个训练trick在2014年就已经通过DSN(Deeply-Supervised Nets)提出来了.
通常而言,增加神经网络的深度可以一定程度上提高网络的表征能力,但随着深度加深,会逐渐出现神经网络难以训练的情况,其中就包括像梯度消失和梯度爆炸等现象。为了更好的训练深度网络,人们尝试给神经网络的某些层添加一些辅助的分支分类器来解决这个问题。这种辅助的分支分类器能够起到一种判断隐藏层特征图质量好坏的作用。其结构如下:
其中各个模块含义如下:
可以看到,图中在第四个卷积块之后添加了一个监督分类器作为分支。Conv4输出的特征图除了随着主网络进入Conv5之外,也作为输入进入了分支分类器。往往分支与主网络一起训练。
3.模型实现
3.1 基于卷积注意力机制的U-Net模型
3.2 基于卷积注意力机制和深监督的U-Net模型
其具体代码实现可以查看上篇博客:https://haosen.blog.csdn.net/article/details/117756027;
3.3 模型代码实现
def attn_reg(opt,input_size, lossfxn):
img_input = Input(shape=input_size, name='input_scale1')
scale_img_2 = AveragePooling2D(pool_size=(2, 2), name='input_scale2')(img_input)
scale_img_3 = AveragePooling2D(pool_size=(2, 2), name='input_scale3')(scale_img_2)
scale_img_4 = AveragePooling2D(pool_size=(2, 2), name='input_scale4')(scale_img_3)
conv1 = UnetConv2D(img_input, 32, is_batchnorm=True, name='conv1')
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
input2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='conv_scale2')(scale_img_2)
input2 = concatenate([input2, pool1], axis=3)
conv2 = UnetConv2D(input2, 64, is_batchnorm=True, name='conv2')
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
input3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='conv_scale3')(scale_img_3)
input3 = concatenate([input3, pool2], axis=3)
conv3 = UnetConv2D(input3, 128, is_batchnorm=True, name='conv3')
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
input4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='conv_scale4')(scale_img_4)
input4 = concatenate([input4, pool3], axis=3)
conv4 = UnetConv2D(input4, 64, is_batchnorm=True, name='conv4')
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
center = UnetConv2D(pool4, 512, is_batchnorm=True, name='center')
g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')
attn1 = AttnGatingBlock(conv4, g1, 128, '_1')
up1 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')
g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')
attn2 = AttnGatingBlock(conv3, g2, 64, '_2')
up2 = concatenate([Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')
g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')
attn3 = AttnGatingBlock(conv2, g3, 32, '_3')
up3 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')
up4 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')
conv6 = UnetConv2D(up1, 256, is_batchnorm=True, name='conv6')
conv7 = UnetConv2D(up2, 128, is_batchnorm=True, name='conv7')
conv8 = UnetConv2D(up3, 64, is_batchnorm=True, name='conv8')
conv9 = UnetConv2D(up4, 32, is_batchnorm=True, name='conv9')
out6 = Conv2D(1, (1, 1), activation='sigmoid', name='pred1')(conv6)
out7 = Conv2D(1, (1, 1), activation='sigmoid', name='pred2')(conv7)
out8 = Conv2D(1, (1, 1), activation='sigmoid', name='pred3')(conv8)
out9 = Conv2D(1, (1, 1), activation='sigmoid', name='final')(conv9)
model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9])
loss = {'pred1':lossfxn,
'pred2':lossfxn,
'pred3':lossfxn,
'final': losses.tversky_loss}
loss_weights = {'pred1':1,
'pred2':1,
'pred3':1,
'final':1}
model.compile(optimizer=opt, loss=loss, loss_weights=loss_weights,
metrics=[losses.dsc])
model.summary()
from keras.utils.vis_utils import plot_model
plot_model(model, to_file='model1.png', show_shapes=True)
return model
模型参数结构图(点击观看)
4. 实验结果
模型 | DES |
U-Net | 0.878 |
ATT-U-Net | ? |
DATT-U-Net | ? |
DAMU-Net | ? |
还有一些结果正在用CPU运行,太慢了.....
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)