1.摘要

上节我们基于U-Net模型设计并实现了在医学细胞分割上的应用(ISBI 挑战数据集),并给出了模型的详细代码解释,在上个博客中,我们为了快速训练U-Net模型对其进行了缩减,将庞大的U-Net的转换为很小&的结构,导致其准确率才达到75%左右。为了进一步提高U-Net模型在细胞分割上的准确率,本文将主要研究两个方面:一是基于U-Net的原始模型结构进行改进,引入卷积注意力机制模块(CBAM)和Focal Tversky损失函数;二是引入深监督方法(DEEP SUPERVISION)及多尺度输入作为U-Net模型的原始输入,该模型被命名为DAMU-Net。为了进一步验证该模型的性能,我们同样在ISBI 挑战数据集上进行实验,并给出相应的实验结果。

2.相关技术概述

2.1 Focal Tversky损失函数

医学影像中存在很多的数据不平衡现象,使用不平衡数据进行训练会导致严重偏向高精度但低召回率(sensitivity)的预测,这是我们不希望的,特别是在医学应用中,假阴性比假阳性多更难容忍。而Tversky广义损失函数可以有效解决了三维全卷积深神经网络训练中数据不平衡的问题,在精度和召回率之间找到更好的平衡。与Focal loss相似,Focal Tversky Loss着重于通过通过调整超参数α和β,我们可以控制假阳性和假阴性之间的权衡。较大的β会使召回的准确性高于精确度(通过更加强调假阴性)。其公式如下:

preview

2.2  深监督方法

  所谓深监督(Deep Supervision),就是在深度神经网络的某些中间隐藏层加了一个辅助的分类器作为一种网络分支来对主干网络进行监督的技巧,用来解决深度神经网络训练梯度消失和收敛速度过慢等问题。 深监督作为一个训练trick在2014年就已经通过DSN(Deeply-Supervised Nets)提出来了.

 通常而言,增加神经网络的深度可以一定程度上提高网络的表征能力,但随着深度加深,会逐渐出现神经网络难以训练的情况,其中就包括像梯度消失和梯度爆炸等现象。为了更好的训练深度网络,人们尝试给神经网络的某些层添加一些辅助的分支分类器来解决这个问题。这种辅助的分支分类器能够起到一种判断隐藏层特征图质量好坏的作用。其结构如下:

其中各个模块含义如下:

 可以看到,图中在第四个卷积块之后添加了一个监督分类器作为分支。Conv4输出的特征图除了随着主网络进入Conv5之外,也作为输入进入了分支分类器。往往分支与主网络一起训练。

3.模型实现

为了在精确性和召回性之间实现进一步的平衡,本文设计实现一种基于卷积注意力机制的U-Net模型, 该体 系结构基于流行的UNet,并将输入图像的多尺寸特征张量作为输入,以便更好的提取局部特征。其模型结构如图所示:

为了进一步细化模型实验,我们将分三个步骤实现上述最终模型。

3.1 基于卷积注意力机制的U-Net模型

该模型只是单纯的将注意力机制引入U-Net模型中,目的是将输入图像的低级特征映射中识别相关的空间信息,并将其传播到解码阶段,以达到真正地提取出积极有效的特征。其具体代码实现可以查看 上篇博客https://haosen.blog.csdn.net/article/details/117755633。在该博客中有模型的具体结构图及代码实现。
 

3.2 基于卷积注意力机制和深监督的U-Net模型

其具体代码实现可以查看上篇博客https://haosen.blog.csdn.net/article/details/117756027

3.3 模型代码实现

def attn_reg(opt,input_size, lossfxn):
    
    img_input = Input(shape=input_size, name='input_scale1')
    scale_img_2 = AveragePooling2D(pool_size=(2, 2), name='input_scale2')(img_input)
    scale_img_3 = AveragePooling2D(pool_size=(2, 2), name='input_scale3')(scale_img_2)
    scale_img_4 = AveragePooling2D(pool_size=(2, 2), name='input_scale4')(scale_img_3)

    conv1 = UnetConv2D(img_input, 32, is_batchnorm=True, name='conv1')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    input2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='conv_scale2')(scale_img_2)
    input2 = concatenate([input2, pool1], axis=3)
    conv2 = UnetConv2D(input2, 64, is_batchnorm=True, name='conv2')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    input3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='conv_scale3')(scale_img_3)
    input3 = concatenate([input3, pool2], axis=3)
    conv3 = UnetConv2D(input3, 128, is_batchnorm=True, name='conv3')
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    input4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='conv_scale4')(scale_img_4)
    input4 = concatenate([input4, pool3], axis=3)
    conv4 = UnetConv2D(input4, 64, is_batchnorm=True, name='conv4')
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
        
    center = UnetConv2D(pool4, 512, is_batchnorm=True, name='center')
    
    g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')
    attn1 = AttnGatingBlock(conv4, g1, 128, '_1')
    up1 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')

    g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')
    attn2 = AttnGatingBlock(conv3, g2, 64, '_2')
    up2 = concatenate([Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')

    g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')
    attn3 = AttnGatingBlock(conv2, g3, 32, '_3')
    up3 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')

    up4 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')
    
    conv6 = UnetConv2D(up1, 256, is_batchnorm=True, name='conv6')
    conv7 = UnetConv2D(up2, 128, is_batchnorm=True, name='conv7')
    conv8 = UnetConv2D(up3, 64, is_batchnorm=True, name='conv8')
    conv9 = UnetConv2D(up4, 32, is_batchnorm=True, name='conv9')

    out6 = Conv2D(1, (1, 1), activation='sigmoid', name='pred1')(conv6)
    out7 = Conv2D(1, (1, 1), activation='sigmoid', name='pred2')(conv7)
    out8 = Conv2D(1, (1, 1), activation='sigmoid', name='pred3')(conv8)
    out9 = Conv2D(1, (1, 1), activation='sigmoid', name='final')(conv9)

    model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9])
 
    loss = {'pred1':lossfxn,
            'pred2':lossfxn,
            'pred3':lossfxn,
            'final': losses.tversky_loss}
    
    loss_weights = {'pred1':1,
                    'pred2':1,
                    'pred3':1,
                    'final':1}
    model.compile(optimizer=opt, loss=loss, loss_weights=loss_weights,
                  metrics=[losses.dsc])
    model.summary()

    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='model1.png', show_shapes=True)
    return model

模型参数结构图(点击观看) 

4. 实验结果

模型

  DES

U-Net

0.878

ATT-U-Net

DATT-U-Net

DAMU-Net

 

还有一些结果正在用CPU运行,太慢了.....

 

 

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐