【技术追踪】HiDiff:医学图像分割的混合扩散框架(TMI-2024)
HiDiff:一种用于医学图像分割的新型混合扩散框架,它可以协同现有判别分割模型和新型生成扩散模型的优势,在腹部器官、脑肿瘤、息肉和视网膜血管分割数据集上性能表现 SOTA !
传统分割方法与扩散分割方法结合,做大做强~
HiDiff:一种用于医学图像分割的新型混合扩散框架,它可以协同现有判别分割模型和新型生成扩散模型的优势,在腹部器官、脑肿瘤、息肉和视网膜血管分割数据集上性能表现 SOTA !
论文:HiDiff: Hybrid Diffusion Framework for Medical Image Segmentation
代码:https://github.com/takimailto/HiDiff
0、摘要
随着深度学习(DL)技术的快速发展,医学图像分割取得了显著进展。现有的基于 DL 的分割模型通常是判别性的;也就是说,他们的目标是学习从输入图像到分割掩码的映射。然而,这些判别方法忽略了底层数据的分布和固有的类特征,存在特征空间不稳定的问题。
本文建议用生成模型的底层数据分布知识来补充判别分割:
本文提出一种新的混合扩散框架,称为 HiDiff,它可以协同现有的判别分割模型和生成扩散模型的优势,HiDiff 包括两个关键组件:判别分割器(discriminative segmentor)和扩散细化器(diffusion refiner);
(1)利用任意传统的训练好的分割模型作为判别分割器,它可以为扩散细化器提供先验的分割掩码;
(2)提出一种新的二元伯努利扩散模型(binary Bernoulli diffusion model,BBDM)作为扩散细化器,该模型可以通过对底层数据分布进行建模,有效、高效、交互式地细化分割掩码;(更细,更强,再创辉煌~)
(3)以交替协作的方式训练分割器和 BBDM,相互促进;
在腹部器官、脑肿瘤、息肉和视网膜血管分割数据集上的大量实验结果,涵盖了四种广泛使用的影像模态,证明了 HiDiff 优于现有的医学分割算法,包括最先进的基于 transformer 和 diffusion 的分割算法;此外,HiDiff 擅长分割小目标和泛化到新的数据集。(好、很好、非常好~)
1、引言
1.1、现有基于 DL 的分割方法的局限
(1)基于 CNN 和 ViT 变体的分割方法,通常使用交叉熵或 Dice 损失来学习从输入医学图像到分割掩码的映射函数,该范式被称为直接学习图像像素分类概率的判别方法;
(2)这种判别方式,只关注学习像素特征空间中类之间的决策边界,而没有捕获底层数据分布,无法捕获内在的类特征;
(3)此外,它们学习了一个不稳定的特征空间,当远离决策边界时,会导致性能迅速下降,这使得处理模糊边界和精细目标变得具有挑战性;
1.2、基于生成方法的优势与局限
(1)基于生成的方法首先对输入数据和分割掩码的联合概率进行建模,然后利用学习到的联合概率来评估给定输入图像的分割掩码的条件分布,最后输出预测 mask;
(2)基于生成的方法有可能减轻与判别方法相关的局限性,因为它直接对底层数据分布进行建模;
(3)然而,生成模型具有训练不稳定和推理缓慢的局限;
(4)大多数基于 DPM 的分割方法都依赖于高斯噪声作为扩散核,而忽略了分割任务固有的离散性,此外,DPM 的迭代去噪过程非常耗时;
1.3、本文贡献
本文探索了如何有效地、高效地、交互式地协同现有的判别分割模型和生成式 DPM 的优势。为此,提出了一种新的用于医学图像分割的混合扩散框架,称为 HiDiff,如图1所示:
HiDiff 的概念说明:
(1)提出一种新的混合扩散框架 (HiDiff) 用于医学图像分割,可将现有的判别式分割模型和生成式扩散模型的优点协同起来;
(2)提出一种新的二元伯努利扩散模型 (BBDM) 作为扩散细化器,通过对底层数据分布进行建模,可以有效、高效、交互地细化分割掩码;
(3)引入一种交替协同训练策略来训练判别分割器和扩散细化器,它们在训练过程中可以相互提高;
(4)在腹部器官 (Synapse)、脑肿瘤 (BraTS-2021)、息肉 (Kvasir-SEG 和 CVC-ClinicDB) 和视网膜血管 (DRIVE 和 CHASE_DB1) 分割数据集上的大量实验结果表明,HiDiff 比现有的医学分割算法具有更优越的性能,并且在分割小物体和推广到新数据集方面表现出色;
BBDM的新颖之处:
(1)基于伯努利扩散核增强了扩散模型对分割任务离散目标的建模能力;
(2)二进制化扩散细化器显著提高了推理的效率,计算成本可以忽略不计;
(3)交叉 transformer 使扩散生成特征和鉴别特征之间交互交换;
2、方法
2.1、现有的判别分割器
设 x ∈ R H × W {x \in \mathbb {R}^{H×W}} x∈RH×W 为输入医学图像, y 0 ∈ { 0 , 1 } H × W × C {y_0 \in \{0,1\}^{H×W×C}} y0∈{0,1}H×W×C 为 ground-truth 掩码, C {C} C 代表类别数。给定一个图像掩码对 ( x , y 0 ) {(x, y_0)} (x,y0) ,现有判别分割模型通过使用神经网络 f ( ⋅ ) {f(·)} f(⋅) 来预测分割 mask 的可能性, f ( x ) ∈ R H × W × C {f(x) \in \mathbb {R}^{H×W×C}} f(x)∈RH×W×C,每个元素在 ( 0 , 1 ) {(0,1)} (0,1) 的范围内,大多数方法都是通过最小化交叉熵损失或 Dice 损失来进行端到端训练。
尽管这种判别分割表现尚可,但它们无法捕获底层数据分布和内在类特征,导致特征空间不稳定,使得处理模糊边界和精细目标变得困难。为了解决这些限制,本文提出了一个混合扩散框架,以协同现有判别分割模型和 BBDM 的优势。
为预训练 HiDiff 中的判别分割器,使用交叉熵损失和 Dice 损失的组合作为目标函数:
2.2、二元伯努利扩散细化器(Binary Bernoulli Diffusion Refiner)
2.2.1 基于伯努利的扩散模型(BBDM)
采用 U-Net 的一个变体作为扩散细化器
g
(
⋅
)
{g(·)}
g(⋅),迭代地细化由任意判别分割器
f
(
⋅
)
{f(·)}
f(⋅) 生成的先验掩码
f
(
x
)
{f(x)}
f(x),
f
(
x
)
{f(x)}
f(x) 在正向过程中的添加噪声,并作为反向过程的起始采样点。BBDM的整个扩散过程可以表示为:
(1)扩散正向过程
在扩散正向过程中,扩散细化器使用余弦噪声调度
β
0
,
.
.
.
,
β
T
{\beta_0,...,\beta_T}
β0,...,βT 逐渐增加伯努利噪声(伯努利噪声取值为0或1,均值为0,方差为1/2,是离散的),伯努利正向过程如下:
式中
B
(
(
1
−
β
t
)
y
t
−
1
+
β
t
f
(
x
)
)
{\mathcal B((1−β_t)y_{t−1} + β_t f(x))}
B((1−βt)yt−1+βtf(x)) 表示伯努利分布,其概率质量函数定义为:
令
α
t
=
1
−
β
t
{α_t = 1 − β_t}
αt=1−βt,
α
‾
t
=
∏
τ
=
1
t
α
τ
{\overline α_t = \prod_{\tau = 1}^{t} α_\tau}
αt=∏τ=1tατ,可用任意时间步长
t
{t}
t 以采样
y
t
{y_t}
yt :
随时间步长增加,该伯努利分布的平均参数可以看作是先验掩码
f
(
x
)
{f(x)}
f(x) 与真实掩码
y
0
{y_0}
y0之间的插值,可进一步使用伯努利采样噪声
ϵ
∼
B
(
(
1
−
α
‾
t
)
∣
f
(
x
)
−
y
0
∣
)
{\epsilon \sim \mathcal B((1−\overline α_t) |f(x)-y_0|)}
ϵ∼B((1−αt)∣f(x)−y0∣) 来重参数化式(6)中的
y
t
{y_t}
yt 为
y
0
⊕
ϵ
{y_0⊕\epsilon}
y0⊕ϵ,其中
∣
⋅
∣
{|·|}
∣⋅∣ 表示绝对值运算,
⊕
{⊕}
⊕ 表示“异或”运算。
伯努利后验可以表示为:
式中:
ϕ
p
o
s
t
(
y
t
,
y
0
,
f
(
x
)
)
=
∥
{
α
t
[
1
−
y
t
,
y
t
]
+
(
1
−
α
t
)
∣
1
−
y
t
−
f
(
x
)
∣
}
⊙
{
α
‾
t
−
1
[
1
−
y
0
,
y
0
]
+
(
1
−
α
‾
t
−
1
)
[
1
−
f
(
x
)
,
f
(
x
)
]
}
∥
1
{ϕ_{post}(y_t, y_0, f(x)) = ∥\{α_t [1 − y_t,y_t] + (1 − α_t)|1 − y_t − f (x)|\} ⊙ \{\overlineα_{t-1} [1 − y_0,y_0] + (1 − \overlineα_{t-1})[1-f (x),f (x)]\}∥_1}
ϕpost(yt,y0,f(x))=∥{αt[1−yt,yt]+(1−αt)∣1−yt−f(x)∣}⊙{αt−1[1−y0,y0]+(1−αt−1)[1−f(x),f(x)]}∥1
其中,
⊙
{⊙}
⊙ 表示逐元素积,
[
⋅
,
⋅
]
{[·,·]}
[⋅,⋅] 为沿着通道维度的矩阵拼接,
∥
⋅
∥
1
{∥ · ∥_1}
∥⋅∥1 为沿着通道维度的
l
1
{\mathscr l_1}
l1 标准化。
在第一个加法运算中,矩阵维数不匹配;广播第二个矩阵来匹配第一个矩阵的维数。
(2)扩散逆向过程
逆向过程由先验掩码
y
T
{y_T}
yT 开始,从一个由预训练的分割模型
f
(
x
)
∈
R
H
×
W
×
C
{f(x) \in \mathbb {R}^{H×W×C}}
f(x)∈RH×W×C 或
y
T
∼
B
(
f
(
x
)
)
{y_T \sim \mathcal B(f(x))}
yT∼B(f(x)) 参数化的伯努利分布中采样,并通过受先验掩码
f
(
x
)
{f(x)}
f(x) 约束的中间潜在变量来学习底层数据分布:
扩散细化器
g
(
⋅
)
{g(·)}
g(⋅) 估计在第
t
{t}
t 个时间步长下的伯努利噪声
ϵ
^
(
y
t
,
t
,
f
(
x
)
)
{\hat \epsilon (y_t, t, f (x))}
ϵ^(yt,t,f(x)),通过校准函数
F
C
{\mathcal {F}_C}
FC 重新参数化
μ
^
(
y
t
,
t
,
f
(
x
)
)
{\hat \mu (y_t, t, f (x))}
μ^(yt,t,f(x)),如下所示:
F
C
{\mathcal {F}_C}
FC 旨在通过两个步骤将潜在变量
y
t
{y_t}
yt 校准为噪声较小的潜在变量
y
t
−
1
{y_{t-1}}
yt−1:
① 通过计算
y
t
{y_t}
yt 与估计噪声
ϵ
^
{\hat \epsilon}
ϵ^ 之间的绝对偏差来估计掩码
y
0
{y_0}
y0;
② 利用等式(7),通过计算伯努利后验
q
(
y
t
−
1
∣
y
t
,
y
0
,
f
(
x
)
)
{q (y_{t−1} | y_t, y_0, f (x))}
q(yt−1∣yt,y0,f(x)) 估计
y
t
−
1
{y_{t-1}}
yt−1 的分布;
(3)扩散目标函数
基于之前 DPM 中负对数似然的变分上界,给定一个图像掩模对
(
x
,
y
0
)
{(x, y_0)}
(x,y0) 和第
t
{t}
t 个潜在变量
y
t
{y_t}
yt,采用Kullback-Leibler (KL) 散度和 focal loss 对扩散细化器进行如下优化:
扩散目标函数定义为:
2.2.2、有效细化的二进制模块
为减轻迭代扩散过程的计算负担,提出使用定制的时间依赖二进制化模块 (TB) 和时间依赖激活 (TA) 模块将扩散细化器二进制化,使其轻量,以可忽略的资源进行有效的细化。值得注意的是,上标 b b b 和 r r r 分别表示二进制和实值。
(1)二进制计算
首先将实值输入张量
U
r
{U^r}
Ur 和权值
W
r
{W^r}
Wr 二进制化为
U
b
{U^b}
Ub 和
W
b
{W^b}
Wb,实值输入张量
U
r
{U^r}
Ur 和权值
W
r
{W^r}
Wr 之间计算量大的浮点矩阵乘法可以被二进制
U
b
{U^b}
Ub 和
W
b
{W^b}
Wb之间轻量级的按位 XNOR 和 popcount 运算所取代,定义如下:
(2)时间依赖二进制模块
为有效适应时间步长条件下 DPM 的迭代性质,受自适应实例归一化的启发,本文设计了 TB 模块对输入张量进行二进制化,TA模块动态激活输入张量:
TB 模块采用通道级时间依赖的二进制阈值
α
i
{α_i}
αi 来实现:
其中,
u
i
b
{u^b_i}
uib 和
u
i
r
{u^r_i}
uir 分别是第
i
{i}
i 个通道上相同输入张量元素的二进制表示和实值表示。此外,
α
i
{α_i}
αi 是由一个轻量级的全连接层生成的,它以时间步长
t
{t}
t 作为输入。
TA 模块的实现如下:
其中
γ
i
{γ_i}
γi 和
ζ
i
{ζ_i}
ζi 是第
i
{i}
i 个通道上可学习的移位参数,从时间步长
t
{t}
t 线性变换。
β
i
{β_i}
βi 是可学习的缩放系数。
2.2.3、交互式增强交叉Transformer
为了交互交换扩散生成特征与鉴别特征进行增强,提出一种新的交叉变压器,称为 X-Frormer,X-Frormer 由两个交叉 transformer 块(CTB)组成。
第一个块利用 U 形扩散细化器的 bottleneck 提取的生成知识将特征编码到判别分割器的中间位置:
f
d
,
f
p
→
f
p
′
{f_d, f_p → f^′_p}
fd,fp→fp′ ;
第二个块在相反的编码方向上操作,将具有判别知识的特征注入扩散细化器中:
f
p
′
,
f
d
→
f
d
′
{f^′_p, f_d → f^′_d}
fp′,fd→fd′ ;
二进制交叉 transformer 块与二进制交叉多头注意:
X-Former 可以使用定制的 TA 和 TB 模块进行二进制化,二进制化的 X-Former 或 BX-Former,由两个二进制化的交叉 transformer 块 (BCTB) 组成,如图3所示,这种二进制化可以减轻transformer 块的计算负担。
2.3、混合扩散框架
提出混合扩散框架,将现有分割器的判别能力与扩散细化器的生成能力结合起来,以改进医学图像分割,如图2所示:
HiDiff 总览:
在训练过程中,以交替协作的方式优化扩散细化器和判别分割器。具体而言,在优化扩散细化器时,冻结了判别分割器,并使用式 (13) 中的扩散目标函数,在优化判别分割器时,冻结扩散细化器,并协同使用判别和扩散目标函数,如下所示:
HiDiff 训练:
推理过程中,首先利用判别分割器生成先验掩码
f
(
x
)
{f(x)}
f(x),然后 HiDiff 从先验掩码中采样初始潜在变量
y
T
{y_T}
yT,经过迭代细化以获得更好的掩码,HiDiff 可使用 DDIM 采样策略。
HiDiff 推理:
3、实验与结果
3.1、实验设置
3.1.1、数据集
(1)Synapse数据集:CT图像腹部器官分割
① 30例腹部CT增强扫描,3779张轴向图像。每次CT扫描约有85到198个切片,分辨率为512×512;
② 18例训练(2211张轴向图像),12例测试;
③ 分割8个腹部器官:主动脉、胆囊(GB)、左肾(KL)、右肾(KR)、肝、胰腺(PC)、脾(SP)和胃(SM);
(2)BraTS数据集:MRI图像脑肿瘤分割
① MRI序列(T1、T2、FLAIR、T1CE)全部拼接在一起作为输入,分辨率为224×224;
② 1126例训练(55174张2D切片),125例测试(3991张2D切片);
③ 分割:坏死性肿瘤核心(NT)、瘤周水肿(ED)和增强肿瘤(ET);
(3)Kvasir-SEG 和 CVC-ClinicDB 数据集:内镜息肉分割
① Kvasir-SEG 1000张,CVC-ClinicDB 612张;
② 训练、验证、测试划分:80%:10%:10%;
(4)DRIVE 和 CHASE_DB1 数据集:视网膜血管分割
① DRIVE 40张,分辨率为 565×584;
② CHASE_DB1 28张,分辨率为 999×960;
3.1.2、实施细节
(1)显卡:1张 NVIDIA V100 GPU;
(2)AdamW 优化器,batch size 32,学习率
1
×
1
0
−
4
{1×10^{−4}}
1×10−4;
(3)超参数:
γ
=
2.0
{ γ= 2.0}
γ=2.0,
λ
D
i
c
e
=
1.0
{ λ_{Dice} = 1.0}
λDice=1.0,
λ
F
o
c
a
l
=
1.0
{ λ_{Focal} = 1.0}
λFocal=1.0,
λ
D
i
f
f
=
1.0
{ λ_{Diff} = 1.0}
λDiff=1.0;
(4)余弦噪声调度:T=10;
(5)
s
=
0.008
{ s= 0.008}
s=0.008
3.1.3、评价指标
(1)Dice系数,95%豪斯多夫距离(HD95);
(2)IoU、recall 和 accuracy;
3.2、与 SOTA 方法比较
比较类别:传统的判别法、生成扩散法和综合方法;
(1)在 Synapse 数据集上的结果
定量结果:
0038 和 0008 两例不同分割方法的定性结果:
(2)在 BraTS 数据集上的结果
定量结果:
4例不同分割方法的定性结果:
(3)在 Kvasir-SEG 和 CVC-ClinicDB 数据集上的结果
定量结果:
4例不同分割方法的定性结果:
(4)在 DRIVE 和 CHASE_DB1 数据集上的结果
定量结果:
4例不同分割方法的定性结果:
3.3、跨数据集评估
通过跨数据集评估,以进一步评估 HiDiff 的通用性:
(1)腹部器官分割任务: Synapse上训练,医学十项分割全能 Medical Segmentation Decathlon (MSD) 测试;
(2)息肉分割任务:Kvasir-SEG训练,ClinicDB测试,或,ClinicDB训练,Kvasir-SEG测试;
腹部器官分割:
息肉分割:
3.4、小目标分割评估
评估 HiDiff 对小目标分割的能力,对 Synapse 和 BraTS 测试集的小目标子集进行了评估:
Synapse小目标子集:
BraTS小目标子集:
3.5、消融实验
除特殊说明,所有消融实验都是在 Synapse 数据集上进行的;
(1)扩散核影响(Gaussian v.s. Bernoulli): 与 BerDiff 比较,伯努利扩散优于高斯扩散,突出了离散核在分割任务中的优势;
(2)HiDiff 的兼容性: HiDiff 可兼容任何 SOTA 判别分割器;
(3)扩散细化过程的有效性: 用判别式改进网络取代所提出的 BBDM 来验证所提出的扩散细化的有效性;
(4)交替协同训练策略: (V1和V2)交替协同训练策略促进了判别分割器和扩散细化器之间的双向知识蒸馏;
(5)Focal Loss: (V2和V3)HiDiff 基于先验掩码引入了伯努利噪声,与传统的BerDiff相比,产生了更小的扰动,这进一步引入了训练过程中的类不平衡挑战;
(6)X-Former: (V3和V4)引入 X-Former 可以通过判别和扩散生成特征之间的双向注入来增强分割结果;
(7)Binarization: (V4和V5)二进制化的引入对分割性能的影响可以忽略不计;二进制化时,HiDiff 进一步为扩散细化过程获得了22×的加速,为整个推理过程得到了10×的加速,有效地减少了计算负担;
太强了,这一抹多的实验~
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)