1、前言

本篇文章,我们讲NCSN,也就是噪声条件分数网络。这是宋飏老师在2019年提出的模型,思路与传统的生成模型大不相同,令人拍案叫绝!!!)

视频:[噪声条件得分(分数)网络——NCSN原理解析-哔哩哔哩]

Ps:这篇文章我简单讲一下思路就算了,过程并不严谨,因为这个内容并不是很重要

2、引入

回忆一下梯度下降,假设我们有一个二次函数
f ( x ) = ( 0.5 x − 3 ) 2 f(x)=(0.5x-3)^2 f(x)=(0.5x3)2
导数为 f ′ ( x ) = ( 0.5 x − 3 ) f'(x)=(0.5x-3) f(x)=(0.5x3),使用梯度下降
x t + 1 = x t − 0.1 f ′ ( x t ) (1) x_{t+1}=x_t-0.1f'(x_t)\tag{1} xt+1=xt0.1f(xt)(1)
其中 x t 、 x t + 1 x_t、x_{t+1} xtxt+1表示优化前和优化后的x对应的值, 0.1 0.1 0.1是步长。初始化蓝色点 x t = − 6 x_t=-6 xt=6,迭代100轮梯度下降,就可以得到下面的图(可以看到蓝色点逐渐向着函数最低点靠近)

在这里插入图片描述

为什么会这样?因为梯度总是指向函数值上升的方向。而Eq.(1),是减去梯度,相当于对梯度取反方向。于是x的值就沿着函数值下降的方向走了。如果换成梯度上升,则Eq.(1)改为
x t + 1 = x t + 0.1 f ′ ( x t ) (2) x_{t+1}=x_t+0.1f'(x_t)\tag{2} xt+1=xt+0.1f(xt)(2)
对应图像为

在这里插入图片描述

再回忆一下一维高斯分布的概率密度的图像

在这里插入图片描述

当y值(密度值)取到最高点,其对应样本点在均值处

此时我们注意到,高斯分布的图像,与Eq.(2)何其相像,那我们把Eq.(2)里面的 f ( x ) f(x) f(x)当作是高斯分布的密度函数,而 x x x则对应高斯分布的样本点
x t + 1 = x t + 0.1 f ′ ( x t ) x_{t+1}=x_t+0.1f'(x_t) xt+1=xt+0.1f(xt)
那么这个梯度上升的意思就变成了,对于一个样本 x t x_t xt,不断往概率密度函数 f ′ ( x t ) f'(x_t) f(xt)密度值高的地方靠近。如果优化到最优点,那么图像就会变成这样

在这里插入图片描述

也就是说,样本点 x t x_t xt,最终会走到概率值最高对应的点,那么此时的样本点 x t x_t xt,就可以认为是从高斯分布中采样出来的一个概率最高的样本。我们写成概率分布的一般形式
x t + 1 = x t + α ∇ x P ( x t ) x_{t+1}=x_t+\alpha \nabla_xP(x_t) xt+1=xt+αxP(xt)
α \alpha α表示步长,比如之前的0.1, ∇ x \nabla_x x是对x求梯度。

我们在 P ( x t ) P(x_t) P(xt)前面取一个log对数,不改变单调性,仍然会使 x t x_t xt收敛到最优值
x t + 1 = x t + α ∇ x log ⁡ P ( x t ) x_{t+1}=x_t+\alpha \nabla_x\log P(x_t) xt+1=xt+αxlogP(xt)
更一般的,从一个概率分布中采样,我们往往会存在一些偏差项,于是我们加上一个随机噪声
x t + 1 = x t + α ∇ x log ⁡ P ( x t ) + 2 α z t (3) x_{t+1}=x_t+\alpha \nabla_x\log P(x_t)+\sqrt{2\alpha}z_t\tag{3} xt+1=xt+αxlogP(xt)+2α zt(3)
2 α \sqrt{2\alpha} 2α 是缩放系数,而 z t z_t zt是标准高斯分布,加上一个噪声后, x t x_t xt的收敛值会在概率最高点处不断徘徊

图像表示为

在这里插入图片描述

现在,我们更进一步,我们把 x t x_t xt当作是一个随机初始化的图像,然后 P ( x ) P(x) P(x)是我们训练图像的所对应的分布,通过不断执行Eq.(3),便可以让随机初始化的图像,不断往 P ( x ) P(x) P(x)概率最高点周围靠近,那么就间接说明,经过了大T步Eq.(3),得到的 x t x_t xt,可以认为是从 P ( x ) P(x) P(x)中采样出来的。

仔细看一下,这不就是一个生成图像的过程吗?

这种方式,又被称为郎之万动力采样。emmmmm,不懂,物理学的东西。。。

我们看一个可视化的过程(图像来自参考①)

在这里插入图片描述

3、目标函数

既然Eq.(3)能够通过迭代的方式,生成图像,那自然只需要求解Eq.(3)就可以了。不幸的是,我们没办法求解

我们的训练图像,它们所服从的概率分布往往及其复杂,也就是说 P ( x ) P(x) P(x)是难以求解的​,好在我们的目标并不是求出 P ( x ) P(x) P(x),而是对应的梯度(也称为分数)
L = 1 2 E P d a t a ( x ) [ ∣ ∣ s θ ( x ) − ∇ x log ⁡ P d a t a ( x ) ∣ ∣ 2 2 ] (4) L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x)}\left[||s_\theta(x)-\nabla_x\log P_{data}(x)||_2^2\right]\tag{4} L=21EPdata(x)[∣∣sθ(x)xlogPdata(x)22](4)
P d a t a P_{data} Pdata表示训练数据所服从的分布

也就是通过最小化上式,便可得到 s θ ( x ) ≈ ∇ x log ⁡ P d a t a ( x ) s_\theta(x)\approx \nabla_x\log P_{data}(x) sθ(x)xlogPdata(x)

4、问题

理论上,我们直接求解Eq.(4)就可以了,但是,我们样本所服从的分布往往是服从,概率分布中往往存在一些低密度区域,那么对应的样本就很少。

而样本少,意味着对应为止的梯度分数,得不到很好的训练,那么神经网络在那些样本点就很容易估不准。作者博客给出了一张很形象的图像(图像来自参考①)

在这里插入图片描述

可以看到,数据的密度分别都在左下角和右上角,那么这些区域就能够用神经网络得到很好的拟合,对应Accurate区域。相反,低密度区域,没有得到很好的拟合,对应Inaccurate区域。

当我们使用郎之万动力采样的时候,随机初始化一个 x 0 x_0 x0,它落在低密度区域的概率非常之高。而低密度的区域没有经过很好的训练,所以郎之万动力采样在短时间内很难得到较好的结果。

那么,该如何解决这个问题呢?一个很好的方法就是——加噪声

我们通过对图像加入随机扰动噪声,会填充原本的低密度区域,从而让整个区域看起来较为的均匀(图像来自参考①)

在这里插入图片描述

也就是这样,让原本的密度点扩张开来。

加噪的过程我们可以表示为 x ~ = x + σ z \tilde x=x+\sigma z x~=x+σz x x x表示原始图像, x ~ \tilde x x~表示加噪后的图像。

我们用 q ( x ~ ∣ x ) ∼ N ( x , σ 2 I ) q(\tilde x|x)\sim N(x,\sigma^2I) q(x~x)N(x,σ2I)去表示这个加噪过程

于是Eq.(3)就可以变成
L = 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ s θ ( x + σ z ) − ∇ x ~ log ⁡ q ( x ~ ∣ x ) ∣ ∣ 2 2 ] (5) L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)-\nabla_{\tilde x}\log q(\tilde x|x)||_2^2\right]\tag{5} L=21EPdata(x),x~N(x,σ2I)[∣∣sθ(x+σz)x~logq(x~x)22](5)
emmmm,我感觉这样讲貌似挺合理的,但是它是需要证明的,也就是证明Eq.(4)、Eq.(5)的优化等价性。我就不证明了,证明过程在参考论文②,并不难,读者自己看一下就知道了

除此之外,真正导致需要加噪的,其实有其他原因,我只讲了其中一个。其他原因请看参考②,里面讲的非常之详细。我也懒得写了

如果我们加的噪声足够小,那么 P d a t a ( x ) ≈ q ( x ) P_{data}(x)\approx q(x) Pdata(x)q(x)

现在,我们预测的是加噪后的梯度分数,通过加噪的过程,也避免了直接求解 P ( x ) P(x) P(x)的问题。那我们来看一下这个等式可以变成什么吧

因为 q ( x ~ ∣ x ) q(\tilde x|x) q(x~x)是服从高斯分布的,是完全可以求出来的,所以梯度为
∇ x ~ log ⁡ q ( x ~ ∣ x ) = ∇ x ~ log ⁡ 1 2 π σ 2 d exp ⁡ { − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 } = ∇ x ~ ( log ⁡ 1 2 π σ 2 d − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 ) = − 2 ( x ~ − x ) 2 σ 2 = − x ~ − x σ 2 = − z σ \begin{aligned}\nabla_{\tilde x}\log q(\tilde x|x)=&\nabla_{\tilde x}\log \frac{1}{\sqrt{2\pi\sigma^2}^d}\exp \left\{-\frac{||\tilde x-x||^2}{2\sigma^2}\right\}\\=&\nabla_{\tilde x}\left(\log \frac{1}{\sqrt{2\pi\sigma^2}^d}-\frac{||\tilde x-x||^2}{2\sigma^2}\right)\\=&-\frac{2(\tilde x-x)}{2\sigma^2}\\=&-\frac{\tilde x -x}{\sigma^2}\\=&-\frac{z}{\sigma}\end{aligned} x~logq(x~x)=====x~log2πσ2 d1exp{2σ2∣∣x~x2}x~(log2πσ2 d12σ2∣∣x~x2)2σ22(x~x)σ2x~xσz
所以损失函数就可以变成
L = 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ s θ ( x + σ z ) + x ~ − x σ 2 ∣ ∣ 2 2 ] L=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)+\frac{\tilde x -x}{\sigma^2}||_2^2\right] L=21EPdata(x),x~N(x,σ2I)[∣∣sθ(x+σz)+σ2x~x22]

按理说,我们只需要最优化这个目标函数即可。

可问题又来了

我们该如何加入噪声呢?加多少?加的小了,低密度区域没有得到很好的填充。加多了,直接改变原本的数据分布了,这显然也不行。

我们干脆一不做二不休,我们加多个量级噪声,不同量级都进行训练。

当训练完成之后,就得到了不同噪声强度的噪声条件得分网络。

假设不同强度等级的噪声有S个, { σ i } i = 1 S \{\sigma_i\}_{i=1}^S {σi}i=1S,我们看一张图(里面显示了三个噪声强度的情况,图像来自参考①)

在这里插入图片描述

那么进行采样的时候,就可以从高强度的噪声,进行郎之万动力采样,然后慢慢降低噪声的强度。总而言之,就是每个噪声强度,都进行一轮郎之万动力采样,比如下图(图像来自参考①)(Gif图像太大,上传不了…看视频里面吧)

假设有S个噪声强度,那么就可以变成
L = 1 S ∑ i = 1 S λ i 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ i 2 I ) [ ∣ ∣ s θ ( x + σ i z , σ i ) + x ~ i − x σ i 2 ∣ ∣ 2 2 ] L=\frac{1}{S}\sum\limits_{i=1}^S\lambda_i\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma_i^2I)}\left[||s_\theta(x+\sigma_i z,\sigma_i)+\frac{\tilde x_i -x}{\sigma_i^2}||_2^2\right] L=S1i=1Sλi21EPdata(x),x~N(x,σi2I)[∣∣sθ(x+σiz,σi)+σi2x~ix22]
x ~ i \tilde x_i x~i表示在噪声强度为 σ i \sigma_i σi的加噪图像。 λ i \lambda_i λi代表的是一个加权系数.一般情况下,我们取 λ i = σ i 2 \lambda_i=\sigma^2_i λi=σi2

对于噪声强度数量,一般是数百到数千;噪声强度选择一般采用几何级数。

采样的时候正如前面所说,先在高强度噪声量级进行郎之万动力采样,而后慢慢降低,所以采样方法为

5、结束

好了,本篇文章到此为止,如有问题,还望指出,阿里嘎多!!!

在这里插入图片描述

6、参考

①Generative Modeling by Estimating Gradients of the Data Distribution | Yang Song (yang-song.net)

②基于分数的生成模型(Score-based generative models) — 张振虎的博客 张振虎 文档 (zhangzhenhu.com)

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐