diffusion模型原理介绍以及pytorch代码实现
扩散模型 diffusion简单介绍
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
文章目录
diffusion模型介绍
Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程。扩散模型主要分为两个步骤,一个是前向计算,一个是反向计算,前向计算就是对一张原始图片 x 0 x_0 x0,然后不断地向其中加入噪声,图片逐渐模糊,直至全是噪声。每个时间步加入的噪声都是服从正态分布的。反向过程是对一个服从正态分布的高斯噪声进行逐步去噪,直到得到一张清晰的图像。原始论文 https://arxiv.org/abs/2006.11239
一、模型数学原理简要介绍
首先定义一个
β
t
\beta_t
βt,可以理解为加入噪声的权重,
β
t
\beta_t
βt是越要越来越大的,论文中是0.0001到0.002。因为随着一步步加入噪声,每一步必须要加入更大的噪声才能看出加了噪声,第一步加一点噪声就可以看出是加了噪的,但是越往后图片越模糊,必须加更大的噪声。
然后定义
α
t
=
1
−
β
t
\alpha_t=1-\beta_t
αt=1−βt,
α
t
\alpha_t
αt表示前一时刻的权重,
β
t
\beta_t
βt是加入噪声的权重,可见加入噪声权重越大,前一时刻图片的权重会随之变小。
首先列一个式子看看当前时刻t的图像
x
t
x_t
xt和前一时刻图像
x
t
−
1
x_{t-1}
xt−1的关系:
x
t
=
α
t
x
t
−
1
+
1
−
α
t
z
1
\begin{equation} x_t=\sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z_1 \end{equation}
xt=αtxt−1+1−αtz1
其中
α
t
\sqrt{\alpha_t}
αt是前一时刻图像的权重,
1
−
α
t
\sqrt{1-\alpha_t}
1−αt是当前时刻加入噪音的权重,即
β
t
\sqrt{\beta_t}
βt。其中
z
1
z_1
z1是加入的噪声,服从标准正态分布。
初始时已有的是图像
x
0
x_0
x0,如何利用
x
0
x_0
x0得到当前时刻的
x
t
x_t
xt,如果一步一步算,计算效率太慢了,这种串行计算方式类似于RNN模型,只有计算完了当前时间步才能计算下一时间步的信息。所以考虑能不能进行递推,找到一个关系,直接从
x
0
x_0
x0计算出
x
t
x_t
xt。我们再来看一下
x
t
−
1
x_{t-1}
xt−1和
x
t
−
2
x_{t-2}
xt−2的关系:
x
t
−
1
=
α
t
−
1
x
t
−
2
+
1
−
α
t
−
1
z
2
\begin{equation} x_{t-1}=\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_2 \end{equation}
xt−1=αt−1xt−2+1−αt−1z2
把(2)带入(1)中,可以得到
x
t
x_t
xt和
x
t
−
2
x_{t-2}
xt−2的关系:
x
t
=
α
t
(
α
t
−
1
x
t
−
2
+
1
−
α
t
−
1
z
2
)
+
1
−
α
t
z
1
=
α
t
α
t
−
1
x
t
−
2
+
(
α
t
(
1
−
α
t
−
1
)
z
2
+
1
−
α
t
z
1
)
\begin{equation} \begin{split} x_t&=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_2) + \sqrt{1-\alpha_t}z_1 \\ &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + (\sqrt{\alpha_t(1-\alpha_{t-1})}z_2+\sqrt{1-\alpha_t}z_1) \end{split} \end{equation}
xt=αt(αt−1xt−2+1−αt−1z2)+1−αtz1=αtαt−1xt−2+(αt(1−αt−1)z2+1−αtz1)
其中加入的噪声
z
1
,
z
2
.
.
.
.
z_1,z_2....
z1,z2....都属于标准正态分布。上式中后半部分的
α
t
(
1
−
α
t
−
1
)
z
2
\sqrt{\alpha_t(1-\alpha_{t-1})}z_2
αt(1−αt−1)z2,属于正态分布
N
(
0
,
α
t
(
1
−
α
t
−
1
)
)
N(0,\alpha_t(1-\alpha_{t-1}))
N(0,αt(1−αt−1)) ,
1
−
α
t
z
1
\sqrt{1-\alpha_t}z_1
1−αtz1属于正态分布
N
(
0
,
1
−
α
t
)
N(0,1-\alpha_t)
N(0,1−αt)。因为:
若
X
∼
N
(
0
,
1
)
则
a
X
∼
N
(
0
,
a
2
)
若X \sim N(0,1)则aX \sim N(0,a^2)
若X∼N(0,1)则aX∼N(0,a2)
又由于:
N
(
0
,
σ
1
2
)
+
N
(
0
,
σ
2
2
)
∼
N
(
0
,
σ
1
2
+
σ
2
2
)
N(0,\sigma^2_1)+N(0,\sigma^2_2)\sim N(0,\sigma_1^2+\sigma^2_2)
N(0,σ12)+N(0,σ22)∼N(0,σ12+σ22)
所以,基于(3)式继续化简得到:
x
t
=
α
t
α
t
−
1
x
t
−
2
+
1
−
α
t
α
t
−
1
z
ˉ
2
x_t=\sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\bar{z}_2
xt=αtαt−1xt−2+1−αtαt−1zˉ2
可以发现规律,前一项是
α
t
α
t
−
1
\sqrt{\alpha_t\alpha_{t-1}}
αtαt−1后一项是
1
−
α
t
α
t
−
1
\sqrt{1-\alpha_t\alpha_{t-1}}
1−αtαt−1,可以看出是一个累乘的形式,如果继续把
x
t
−
2
x_{t-2}
xt−2再继续带入,一步步带入到
x
0
x_0
x0可以得到:
x
t
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
z
t
\begin{equation} x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}z_t \end{equation}
xt=αˉtx0+1−αˉtzt
其中
α
ˉ
t
\bar{\alpha}_t
αˉt是一个累乘。
计算出了
x
t
和
x
0
x_t和x_0
xt和x0的关系,因此不需要一步一步进行计算,而是直接根据一个时间步t以及初始状态
x
0
x_0
x0就可以得到t时刻的状态。
目前只是完成了前向过程,扩散模型的目的是根据噪声去还原原始的图像,最重要的部分是反向过程。
我们目前知道的是根据
x
t
−
1
x_{t-1}
xt−1求出
x
t
x_t
xt的分布,但是根据
x
t
x_t
xt去求出
x
t
−
1
x_{t-1}
xt−1的分布不太容易,需要求解逆向过程,不好求解,我们借助贝叶斯公式:
P
(
A
∣
B
)
=
P
(
B
∣
A
)
∗
P
(
A
)
P
(
B
)
P(A|B)=P(B|A)*\frac{P(A)}{P(B)}
P(A∣B)=P(B∣A)∗P(B)P(A)
得到:
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
t
∣
x
0
)
\begin{equation} q(x_{t-1}|x_t,x_0)=q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \end{equation}
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)
其中,根据(4),式子右边的三项都可以计算出来:
q
(
x
t
−
1
∣
x
0
)
=
α
ˉ
t
−
1
x
0
+
1
−
α
ˉ
t
−
1
z
∼
N
(
α
ˉ
t
−
1
x
0
,
1
−
α
ˉ
t
−
1
)
q
(
x
t
∣
x
0
)
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
z
∼
N
(
α
ˉ
t
x
0
,
1
−
α
ˉ
t
)
q
(
x
t
∣
x
t
−
1
,
x
0
)
=
α
t
x
t
−
1
+
1
−
α
t
z
∼
N
(
α
t
x
t
−
1
,
1
−
α
t
)
\begin{equation} \begin{split} q(x_{t-1}|x_0)=\sqrt{\bar{\alpha}_{t-1}}x_0+\sqrt{1-\bar{\alpha}_{t-1}}z \sim N(\sqrt{\bar{\alpha}_{t-1}}x_0,1-\bar{\alpha}_{t-1}) \\ q(x_t|x_0)=\sqrt{\bar{\alpha}_{t}}x_0+\sqrt{1-\bar{\alpha}_{t}}z \sim N(\sqrt{\bar{\alpha}_{t}}x_0,1-\bar{\alpha}_{t}) \\ q(x_t|x_{t-1},x_0) = \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z \sim N(\sqrt{\alpha_{t}}x_{t-1},1-\alpha_{t}) \end{split} \end{equation}
q(xt−1∣x0)=αˉt−1x0+1−αˉt−1z∼N(αˉt−1x0,1−αˉt−1)q(xt∣x0)=αˉtx0+1−αˉtz∼N(αˉtx0,1−αˉt)q(xt∣xt−1,x0)=αtxt−1+1−αtz∼N(αtxt−1,1−αt)
由于z服从标准正态分布,所以有了后边的部分服从正态分布。
正态分布公式为:
f
(
x
)
=
1
σ
2
π
e
x
p
(
−
1
2
(
x
−
μ
)
2
σ
2
)
\begin{equation} f(x)=\frac{1}{\sigma\sqrt{2\pi}} exp(-\frac{1}{2}\frac{(x-\mu)^2}{\sigma^2} ) \end{equation}
f(x)=σ2π1exp(−21σ2(x−μ)2)
将(6)式中三项服从的正态分布带入到(7) 再带入到(5),可以得到:
q
(
x
t
−
1
∣
x
t
,
x
0
)
∝
e
x
p
(
−
1
2
(
(
x
t
−
α
t
x
t
−
1
)
2
β
t
+
(
x
t
−
1
−
α
ˉ
t
−
1
x
0
)
2
1
−
α
ˉ
t
−
1
−
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
)
)
\begin{equation} q(x_{t-1}|x_t,x_0) \propto exp(-\frac{1}{2}(\frac{(x_t-\sqrt{\alpha_t}x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}x_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_t-\sqrt{\bar{\alpha}_t}x_0)^2}{1-\bar{\alpha}_t})) \end{equation}
q(xt−1∣xt,x0)∝exp(−21(βt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))
正态分布展开后,乘法相当于加,除法相当于减,把他们汇总得到上式。
把正态分布中exp的那项展开后得到:
e
x
p
(
−
(
x
−
μ
)
2
2
σ
2
)
=
e
x
p
(
−
1
2
(
1
σ
2
x
2
−
2
μ
σ
2
x
+
μ
2
σ
2
)
)
\begin{equation} exp(-\frac{(x-\mu)^2}{2\sigma^2})=exp(-\frac{1}{2}(\frac{1}{\sigma^2}x^2-\frac{2\mu}{\sigma^2}x+\frac{\mu^2}{\sigma^2})) \end{equation}
exp(−2σ2(x−μ)2)=exp(−21(σ21x2−σ22μx+σ2μ2))
我们是要得到t-1时刻与t时刻是什么关系,所以我们将(8)式展开化简,保留
x
t
−
1
x_{t-1}
xt−1的项,化简成类似(9)的形式。过程就不一步步列了,(8)式化简后的结果为:
e
x
p
(
−
1
2
(
(
α
t
β
t
+
1
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
(
2
α
t
β
t
x
t
+
2
α
ˉ
t
−
1
1
−
α
ˉ
t
−
1
x
0
)
x
t
−
1
+
C
(
x
t
,
x
0
)
)
)
\begin{equation} exp(-\frac{1}{2}((\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}})x^2_{t-1}-(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t+\frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_0)x_{t-1}+C(x_t,x_0))) \end{equation}
exp(−21((βtαt+1−αˉt−11)xt−12−(βt2αtxt+1−αˉt−12αˉt−1x0)xt−1+C(xt,x0)))
其中C为与
x
t
−
1
x_{t-1}
xt−1无关的常数项
现在式(9)和(10)已经化简为一样的形式,进行对应后可以推出
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)所服从的正态分布的方差
σ
\sigma
σ和均值
μ
\mu
μ:
σ
=
β
t
(
1
−
α
ˉ
t
−
1
)
α
t
(
1
−
α
ˉ
t
−
1
)
+
β
t
\begin{equation} \sigma = \frac{\beta_t(1-\bar{\alpha}_{t-1})}{\alpha_t(1-\bar{\alpha}_{t-1})+\beta_t} \end{equation}
σ=αt(1−αˉt−1)+βtβt(1−αˉt−1)
所有
α
\alpha
α和
β
\beta
β都已知,所以方差可以理解为是个常量。
μ
=
α
t
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
x
t
+
α
ˉ
t
−
1
β
t
1
−
α
ˉ
t
x
0
\begin{equation} \mu=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 \end{equation}
μ=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0
但是
x
0
x_0
x0是多少呢,上面推出,
x
t
x_t
xt可以根据
x
0
x_0
x0算出,即式(4),然后我们反推出
x
0
x_0
x0:
x
0
=
1
α
ˉ
t
(
x
t
−
1
−
α
ˉ
t
z
t
)
x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t}z_t)
x0=αˉt1(xt−1−αˉtzt)
将
x
0
x_0
x0带入式(12)得到
μ
=
1
α
t
(
x
t
−
β
t
1
−
α
ˉ
t
z
t
)
\begin{equation} \mu=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}z_t) \end{equation}
μ=αt1(xt−1−αˉtβtzt)
现在已经求出了 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt−1∣xt,x0)服从的正态分布的均值和方差,但是式(13)中的 z t z_t zt是不可知的, z t z_t zt是t时刻往模型中加入的噪声,这个加入的噪声是不可知的,如果反向过程每一时刻加入的噪声是多少已知,直接在当前时刻减去噪声就得到了前一时刻的状态,显然是不合理的。既然推不出来这个 z t z_t zt,那就求一个近似解,利用深度学习模型去拟合这个噪声的近似解。一般采用UNet这个模型。
把当前时刻
x
t
x_t
xt以及当前时间步t输入到模型中,让模型去拟合当前时间所加入的噪声
z
t
z_t
zt,这个真实的
z
t
z_t
zt是已知的,因为在前向过程中,每一步加入了多少损失都是可以记录下来的,所以就基于前向计算过程中记录的噪声为标签,去训练这个深度学习模型,简单表示为
z
t
=
m
o
d
e
l
(
x
t
,
t
)
z_t=model(x_t,t)
zt=model(xt,t)。图解如下:
接下来就是整体的流程
训练过程中,输入第t的时刻状态
x
t
x_t
xt以及时刻t到模型中,让模型去拟合当前时刻所加入的噪声。推理阶段,直接输入一个噪声点,基于式(13)以及模型预测值推断出前一时刻状态,依次迭代到
x
0
x_0
x0即为生成的图像。
二、pytorch代码实现(主要部分代码演示)
使用fishon mnist这个数据集进行实验,这个数据集比较小,容易训练。
1.UNet结构
class Unet(nn.Module):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
with_time_emb=True,
resnet_block_groups=8,
use_convnext=True,
convnext_mult=2,
):
super().__init__()
# determine dimensions
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
if use_convnext:
block_klass = partial(ConvNextBlock, mult=convnext_mult)
else:
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
# time embeddings
if with_time_emb:
time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.ModuleList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(
nn.ModuleList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
)
def forward(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
# downsample
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
# bottleneck
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
# upsample
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
2.定义前向过程
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule as proposed in https://arxiv.org/abs/2102.09672
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
timesteps = 200
# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)
# define alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
3.拿一张图片演示
得到第t时刻的噪声图片
# take time step
t = torch.tensor([40])
get_noisy_image(x_start, t)
获取多个时刻的图片状态
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])
4.定义损失
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start)
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
predicted_noise = denoise_model(x_noisy, t)
if loss_type == 'l1':
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == 'l2':
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
raise NotImplementedError()
return loss
5.构造数据,Dataloader
from datasets import load_dataset
# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128
from torchvision import transforms
from torch.utils.data import DataLoader
# define image transformations (e.g. using torchvision)
transform = Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Lambda(lambda t: (t * 2) - 1)
])
# define function
def transforms(examples):
examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
del examples["image"]
return examples
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")
# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
5.训练模型
from torchvision.utils import save_image
epochs = 20
for epoch in range(epochs):
for step, batch in enumerate(dataloader):
optimizer.zero_grad()
batch_size = batch["pixel_values"].shape[0]
batch = batch["pixel_values"].to(device)
# Algorithm 1 line 3: sample t uniformally for every example in the batch
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
loss = p_losses(model, batch, t, loss_type="huber")
if step % 500 == 0:
print("Loss:", loss.item())
loss.backward()
optimizer.step()
# save generated images
if step != 0 and step % save_and_sample_every == 0:
milestone = step // save_and_sample_every
batches = num_to_groups(4, batch_size)
all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images = (all_images + 1) * 0.5
save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
损失逐渐降低
5.测试模型
生成结果还算可以,看一下中途时间步的生成结果
import matplotlib.animation as animation
random_index = 5
fig = plt.figure()
ims = []
for i in range(timesteps):
im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
ims.append([im])
if i%20==0:
plt.show()
animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()
随着时间步的回推,生成的图像逐渐清晰,模型效果还可以。
三、总结
本文简要介绍了一下diffusion模型的数学原理,以及演示了一下代码结果。但是这种模型的弊端也比较明显,就是生成的图像可控性太差,都是随机生成的,目前一些比较火的扩散模型都可以基于条件进行生成,比如根据一个文本生成一张图像,或者给定一张图像,生成另一种风格的图像。这种有条件的生成,也都是基于当前这个模型的,例如文本生成图像。训练数据是图像和文本对,训练过程中,输入UNet的不仅仅是第t时刻的图像以及t,而是加入了文本embedding,即第t时刻的图像以及t以及文本embedding 这三个输入到UNet中,去预测上一时刻状态。目前stable diffusion模型也是基于这个思想进行文本到图像生成。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)