提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


diffusion模型介绍

Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程。扩散模型主要分为两个步骤,一个是前向计算,一个是反向计算,前向计算就是对一张原始图片 x 0 x_0 x0,然后不断地向其中加入噪声,图片逐渐模糊,直至全是噪声。每个时间步加入的噪声都是服从正态分布的。反向过程是对一个服从正态分布的高斯噪声进行逐步去噪,直到得到一张清晰的图像。原始论文 https://arxiv.org/abs/2006.11239

在这里插入图片描述

在这里插入图片描述

一、模型数学原理简要介绍

首先定义一个 β t \beta_t βt,可以理解为加入噪声的权重, β t \beta_t βt是越要越来越大的,论文中是0.0001到0.002。因为随着一步步加入噪声,每一步必须要加入更大的噪声才能看出加了噪声,第一步加一点噪声就可以看出是加了噪的,但是越往后图片越模糊,必须加更大的噪声。
然后定义 α t = 1 − β t \alpha_t=1-\beta_t αt=1βt α t \alpha_t αt表示前一时刻的权重, β t \beta_t βt是加入噪声的权重,可见加入噪声权重越大,前一时刻图片的权重会随之变小。


首先列一个式子看看当前时刻t的图像 x t x_t xt和前一时刻图像 x t − 1 x_{t-1} xt1的关系:
x t = α t x t − 1 + 1 − α t z 1 \begin{equation} x_t=\sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z_1 \end{equation} xt=αt xt1+1αt z1
其中 α t \sqrt{\alpha_t} αt 是前一时刻图像的权重, 1 − α t \sqrt{1-\alpha_t} 1αt 是当前时刻加入噪音的权重,即 β t \sqrt{\beta_t} βt 。其中 z 1 z_1 z1是加入的噪声,服从标准正态分布。


初始时已有的是图像 x 0 x_0 x0,如何利用 x 0 x_0 x0得到当前时刻的 x t x_t xt,如果一步一步算,计算效率太慢了,这种串行计算方式类似于RNN模型,只有计算完了当前时间步才能计算下一时间步的信息。所以考虑能不能进行递推,找到一个关系,直接从 x 0 x_0 x0计算出 x t x_t xt。我们再来看一下 x t − 1 x_{t-1} xt1 x t − 2 x_{t-2} xt2的关系:
x t − 1 = α t − 1 x t − 2 + 1 − α t − 1 z 2 \begin{equation} x_{t-1}=\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_2 \end{equation} xt1=αt1 xt2+1αt1 z2
把(2)带入(1)中,可以得到 x t x_t xt x t − 2 x_{t-2} xt2的关系:
x t = α t ( α t − 1 x t − 2 + 1 − α t − 1 z 2 ) + 1 − α t z 1 = α t α t − 1 x t − 2 + ( α t ( 1 − α t − 1 ) z 2 + 1 − α t z 1 ) \begin{equation} \begin{split} x_t&=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_2) + \sqrt{1-\alpha_t}z_1 \\ &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + (\sqrt{\alpha_t(1-\alpha_{t-1})}z_2+\sqrt{1-\alpha_t}z_1) \end{split} \end{equation} xt=αt (αt1 xt2+1αt1 z2)+1αt z1=αtαt1 xt2+(αt(1αt1) z2+1αt z1)
其中加入的噪声 z 1 , z 2 . . . . z_1,z_2.... z1,z2....都属于标准正态分布。上式中后半部分的 α t ( 1 − α t − 1 ) z 2 \sqrt{\alpha_t(1-\alpha_{t-1})}z_2 αt(1αt1) z2,属于正态分布 N ( 0 , α t ( 1 − α t − 1 ) ) N(0,\alpha_t(1-\alpha_{t-1})) N(0,αt(1αt1)) , 1 − α t z 1 \sqrt{1-\alpha_t}z_1 1αt z1属于正态分布 N ( 0 , 1 − α t ) N(0,1-\alpha_t) N(0,1αt)。因为:
若 X ∼ N ( 0 , 1 ) 则 a X ∼ N ( 0 , a 2 ) 若X \sim N(0,1)则aX \sim N(0,a^2) XN(0,1)aXN(0,a2)

又由于:
N ( 0 , σ 1 2 ) + N ( 0 , σ 2 2 ) ∼ N ( 0 , σ 1 2 + σ 2 2 ) N(0,\sigma^2_1)+N(0,\sigma^2_2)\sim N(0,\sigma_1^2+\sigma^2_2) N(0,σ12)+N(0,σ22)N(0,σ12+σ22)
所以,基于(3)式继续化简得到:
x t = α t α t − 1 x t − 2 + 1 − α t α t − 1 z ˉ 2 x_t=\sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\bar{z}_2 xt=αtαt1 xt2+1αtαt1 zˉ2

可以发现规律,前一项是 α t α t − 1 \sqrt{\alpha_t\alpha_{t-1}} αtαt1 后一项是 1 − α t α t − 1 \sqrt{1-\alpha_t\alpha_{t-1}} 1αtαt1 ,可以看出是一个累乘的形式,如果继续把 x t − 2 x_{t-2} xt2再继续带入,一步步带入到 x 0 x_0 x0可以得到:
x t = α ˉ t x 0 + 1 − α ˉ t z t \begin{equation} x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}z_t \end{equation} xt=αˉt x0+1αˉt zt
其中 α ˉ t \bar{\alpha}_t αˉt是一个累乘。

计算出了 x t 和 x 0 x_t和x_0 xtx0的关系,因此不需要一步一步进行计算,而是直接根据一个时间步t以及初始状态 x 0 x_0 x0就可以得到t时刻的状态。


目前只是完成了前向过程,扩散模型的目的是根据噪声去还原原始的图像,最重要的部分是反向过程。
我们目前知道的是根据 x t − 1 x_{t-1} xt1求出 x t x_t xt的分布,但是根据 x t x_t xt去求出 x t − 1 x_{t-1} xt1的分布不太容易,需要求解逆向过程,不好求解,我们借助贝叶斯公式:
P ( A ∣ B ) = P ( B ∣ A ) ∗ P ( A ) P ( B ) P(A|B)=P(B|A)*\frac{P(A)}{P(B)} P(AB)=P(BA)P(B)P(A)

得到:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) \begin{equation} q(x_{t-1}|x_t,x_0)=q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \end{equation} q(xt1xt,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)

其中,根据(4),式子右边的三项都可以计算出来:
q ( x t − 1 ∣ x 0 ) = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 z ∼ N ( α ˉ t − 1 x 0 , 1 − α ˉ t − 1 ) q ( x t ∣ x 0 ) = α ˉ t x 0 + 1 − α ˉ t z ∼ N ( α ˉ t x 0 , 1 − α ˉ t ) q ( x t ∣ x t − 1 , x 0 ) = α t x t − 1 + 1 − α t z ∼ N ( α t x t − 1 , 1 − α t ) \begin{equation} \begin{split} q(x_{t-1}|x_0)=\sqrt{\bar{\alpha}_{t-1}}x_0+\sqrt{1-\bar{\alpha}_{t-1}}z \sim N(\sqrt{\bar{\alpha}_{t-1}}x_0,1-\bar{\alpha}_{t-1}) \\ q(x_t|x_0)=\sqrt{\bar{\alpha}_{t}}x_0+\sqrt{1-\bar{\alpha}_{t}}z \sim N(\sqrt{\bar{\alpha}_{t}}x_0,1-\bar{\alpha}_{t}) \\ q(x_t|x_{t-1},x_0) = \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z \sim N(\sqrt{\alpha_{t}}x_{t-1},1-\alpha_{t}) \end{split} \end{equation} q(xt1x0)=αˉt1 x0+1αˉt1 zN(αˉt1 x0,1αˉt1)q(xtx0)=αˉt x0+1αˉt zN(αˉt x0,1αˉt)q(xtxt1,x0)=αt xt1+1αt zN(αt xt1,1αt)

由于z服从标准正态分布,所以有了后边的部分服从正态分布。
正态分布公式为:
f ( x ) = 1 σ 2 π e x p ( − 1 2 ( x − μ ) 2 σ 2 ) \begin{equation} f(x)=\frac{1}{\sigma\sqrt{2\pi}} exp(-\frac{1}{2}\frac{(x-\mu)^2}{\sigma^2} ) \end{equation} f(x)=σ2π 1exp(21σ2(xμ)2)

将(6)式中三项服从的正态分布带入到(7) 再带入到(5),可以得到:
q ( x t − 1 ∣ x t , x 0 ) ∝ e x p ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) \begin{equation} q(x_{t-1}|x_t,x_0) \propto exp(-\frac{1}{2}(\frac{(x_t-\sqrt{\alpha_t}x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}x_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(x_t-\sqrt{\bar{\alpha}_t}x_0)^2}{1-\bar{\alpha}_t})) \end{equation} q(xt1xt,x0)exp(21(βt(xtαt xt1)2+1αˉt1(xt1αˉt1 x0)21αˉt(xtαˉt x0)2))
正态分布展开后,乘法相当于加,除法相当于减,把他们汇总得到上式。


把正态分布中exp的那项展开后得到:
e x p ( − ( x − μ ) 2 2 σ 2 ) = e x p ( − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) ) \begin{equation} exp(-\frac{(x-\mu)^2}{2\sigma^2})=exp(-\frac{1}{2}(\frac{1}{\sigma^2}x^2-\frac{2\mu}{\sigma^2}x+\frac{\mu^2}{\sigma^2})) \end{equation} exp(2σ2(xμ)2)=exp(21(σ21x2σ22μx+σ2μ2))

我们是要得到t-1时刻与t时刻是什么关系,所以我们将(8)式展开化简,保留 x t − 1 x_{t-1} xt1的项,化简成类似(9)的形式。过程就不一步步列了,(8)式化简后的结果为:
e x p ( − 1 2 ( ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α ˉ t − 1 1 − α ˉ t − 1 x 0 ) x t − 1 + C ( x t , x 0 ) ) ) \begin{equation} exp(-\frac{1}{2}((\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}})x^2_{t-1}-(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t+\frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_0)x_{t-1}+C(x_t,x_0))) \end{equation} exp(21((βtαt+1αˉt11)xt12(βt2αt xt+1αˉt12αˉt1 x0)xt1+C(xt,x0)))
其中C为与 x t − 1 x_{t-1} xt1无关的常数项
现在式(9)和(10)已经化简为一样的形式,进行对应后可以推出 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)所服从的正态分布的方差 σ \sigma σ和均值 μ \mu μ
σ = β t ( 1 − α ˉ t − 1 ) α t ( 1 − α ˉ t − 1 ) + β t \begin{equation} \sigma = \frac{\beta_t(1-\bar{\alpha}_{t-1})}{\alpha_t(1-\bar{\alpha}_{t-1})+\beta_t} \end{equation} σ=αt(1αˉt1)+βtβt(1αˉt1)

所有 α \alpha α β \beta β都已知,所以方差可以理解为是个常量。
μ = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \begin{equation} \mu=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 \end{equation} μ=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0
但是 x 0 x_0 x0是多少呢,上面推出, x t x_t xt可以根据 x 0 x_0 x0算出,即式(4),然后我们反推出 x 0 x_0 x0:
x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z t ) x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t}z_t) x0=αˉt 1(xt1αˉt zt)

x 0 x_0 x0带入式(12)得到
μ = 1 α t ( x t − β t 1 − α ˉ t z t ) \begin{equation} \mu=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}z_t) \end{equation} μ=αt 1(xt1αˉt βtzt)

现在已经求出了 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)服从的正态分布的均值和方差,但是式(13)中的 z t z_t zt是不可知的, z t z_t zt是t时刻往模型中加入的噪声,这个加入的噪声是不可知的,如果反向过程每一时刻加入的噪声是多少已知,直接在当前时刻减去噪声就得到了前一时刻的状态,显然是不合理的。既然推不出来这个 z t z_t zt,那就求一个近似解,利用深度学习模型去拟合这个噪声的近似解。一般采用UNet这个模型。

把当前时刻 x t x_t xt以及当前时间步t输入到模型中,让模型去拟合当前时间所加入的噪声 z t z_t zt,这个真实的 z t z_t zt是已知的,因为在前向过程中,每一步加入了多少损失都是可以记录下来的,所以就基于前向计算过程中记录的噪声为标签,去训练这个深度学习模型,简单表示为 z t = m o d e l ( x t , t ) z_t=model(x_t,t) zt=model(xt,t)。图解如下:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

接下来就是整体的流程
在这里插入图片描述
训练过程中,输入第t的时刻状态 x t x_t xt以及时刻t到模型中,让模型去拟合当前时刻所加入的噪声。推理阶段,直接输入一个噪声点,基于式(13)以及模型预测值推断出前一时刻状态,依次迭代到 x 0 x_0 x0即为生成的图像。

二、pytorch代码实现(主要部分代码演示)

使用fishon mnist这个数据集进行实验,这个数据集比较小,容易训练。

1.UNet结构

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)

2.定义前向过程

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
timesteps = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

3.拿一张图片演示

在这里插入图片描述
得到第t时刻的噪声图片

# take time step
t = torch.tensor([40])

get_noisy_image(x_start, t)

在这里插入图片描述

获取多个时刻的图片状态

plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

在这里插入图片描述

4.定义损失

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

5.构造数据,Dataloader

from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128
from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations (e.g. using torchvision)
transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

# define function
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

5.训练模型

from torchvision.utils import save_image

epochs = 20

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: sample t uniformally for every example in the batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 500 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # save generated images
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

在这里插入图片描述
损失逐渐降低

5.测试模型

在这里插入图片描述
生成结果还算可以,看一下中途时间步的生成结果

import matplotlib.animation as animation

random_index = 5

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])
    if i%20==0:
        plt.show()

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
随着时间步的回推,生成的图像逐渐清晰,模型效果还可以。

三、总结

本文简要介绍了一下diffusion模型的数学原理,以及演示了一下代码结果。但是这种模型的弊端也比较明显,就是生成的图像可控性太差,都是随机生成的,目前一些比较火的扩散模型都可以基于条件进行生成,比如根据一个文本生成一张图像,或者给定一张图像,生成另一种风格的图像。这种有条件的生成,也都是基于当前这个模型的,例如文本生成图像。训练数据是图像和文本对,训练过程中,输入UNet的不仅仅是第t时刻的图像以及t,而是加入了文本embedding,即第t时刻的图像以及t以及文本embedding 这三个输入到UNet中,去预测上一时刻状态。目前stable diffusion模型也是基于这个思想进行文本到图像生成。

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐