噪声条件分数网络——NCSN原理解析
本篇文章,我们讲NCSN,也就是噪声条件分数网络。这是宋飏老师在2019年提出的模型,思路与传统的生成模型大不相同,令人拍案叫绝!!![噪声条件得分(分数)网络——NCSN原理解析-哔哩哔哩]Ps:这篇文章我简单讲一下思路就算了,过程并不严谨,因为这个内容并不是很重要。
1、前言
本篇文章,我们讲NCSN,也就是噪声条件分数网络。这是宋飏老师在2019年提出的模型,思路与传统的生成模型大不相同,令人拍案叫绝!!!)
视频:[噪声条件得分(分数)网络——NCSN原理解析-哔哩哔哩]
Ps:这篇文章我简单讲一下思路就算了,过程并不严谨,因为这个内容并不是很重要
2、引入
回忆一下梯度下降,假设我们有一个二次函数
f
(
x
)
=
(
0.5
x
−
3
)
2
f(x)=(0.5x-3)^2
f(x)=(0.5x−3)2
导数为
f
′
(
x
)
=
(
0.5
x
−
3
)
f'(x)=(0.5x-3)
f′(x)=(0.5x−3),使用梯度下降
x
t
+
1
=
x
t
−
0.1
f
′
(
x
t
)
(1)
x_{t+1}=x_t-0.1f'(x_t)\tag{1}
xt+1=xt−0.1f′(xt)(1)
其中
x
t
、
x
t
+
1
x_t、x_{t+1}
xt、xt+1表示优化前和优化后的x对应的值,
0.1
0.1
0.1是步长。初始化蓝色点
x
t
=
−
6
x_t=-6
xt=−6,迭代100轮梯度下降,就可以得到下面的图(可以看到蓝色点逐渐向着函数最低点靠近)
为什么会这样?因为梯度总是指向函数值上升的方向。而Eq.(1),是减去梯度,相当于对梯度取反方向。于是x的值就沿着函数值下降的方向走了。如果换成梯度上升,则Eq.(1)改为
x
t
+
1
=
x
t
+
0.1
f
′
(
x
t
)
(2)
x_{t+1}=x_t+0.1f'(x_t)\tag{2}
xt+1=xt+0.1f′(xt)(2)
对应图像为
再回忆一下一维高斯分布的概率密度的图像
当y值(密度值)取到最高点,其对应样本点在均值处
此时我们注意到,高斯分布的图像,与Eq.(2)何其相像,那我们把Eq.(2)里面的
f
(
x
)
f(x)
f(x)当作是高斯分布的密度函数,而
x
x
x则对应高斯分布的样本点
x
t
+
1
=
x
t
+
0.1
f
′
(
x
t
)
x_{t+1}=x_t+0.1f'(x_t)
xt+1=xt+0.1f′(xt)
那么这个梯度上升的意思就变成了,对于一个样本
x
t
x_t
xt,不断往概率密度函数
f
′
(
x
t
)
f'(x_t)
f′(xt)密度值高的地方靠近。如果优化到最优点,那么图像就会变成这样
也就是说,样本点
x
t
x_t
xt,最终会走到概率值最高对应的点,那么此时的样本点
x
t
x_t
xt,就可以认为是从高斯分布中采样出来的一个概率最高的样本。我们写成概率分布的一般形式
x
t
+
1
=
x
t
+
α
∇
x
P
(
x
t
)
x_{t+1}=x_t+\alpha \nabla_xP(x_t)
xt+1=xt+α∇xP(xt)
α
\alpha
α表示步长,比如之前的0.1,
∇
x
\nabla_x
∇x是对x求梯度。
我们在
P
(
x
t
)
P(x_t)
P(xt)前面取一个log对数,不改变单调性,仍然会使
x
t
x_t
xt收敛到最优值
x
t
+
1
=
x
t
+
α
∇
x
log
P
(
x
t
)
x_{t+1}=x_t+\alpha \nabla_x\log P(x_t)
xt+1=xt+α∇xlogP(xt)
更一般的,从一个概率分布中采样,我们往往会存在一些偏差项,于是我们加上一个随机噪声
x
t
+
1
=
x
t
+
α
∇
x
log
P
(
x
t
)
+
2
α
z
t
(3)
x_{t+1}=x_t+\alpha \nabla_x\log P(x_t)+\sqrt{2\alpha}z_t\tag{3}
xt+1=xt+α∇xlogP(xt)+2αzt(3)
2
α
\sqrt{2\alpha}
2α是缩放系数,而
z
t
z_t
zt是标准高斯分布,加上一个噪声后,
x
t
x_t
xt的收敛值会在概率最高点处不断徘徊
图像表示为
现在,我们更进一步,我们把 x t x_t xt当作是一个随机初始化的图像,然后 P ( x ) P(x) P(x)是我们训练图像的所对应的分布,通过不断执行Eq.(3),便可以让随机初始化的图像,不断往 P ( x ) P(x) P(x)概率最高点周围靠近,那么就间接说明,经过了大T步Eq.(3),得到的 x t x_t xt,可以认为是从 P ( x ) P(x) P(x)中采样出来的。
仔细看一下,这不就是一个生成图像的过程吗?
这种方式,又被称为郎之万动力采样。emmmmm,不懂,物理学的东西。。。
我们看一个可视化的过程(图像来自参考①)
3、目标函数
既然Eq.(3)能够通过迭代的方式,生成图像,那自然只需要求解Eq.(3)就可以了。不幸的是,我们没办法求解
我们的训练图像,它们所服从的概率分布往往及其复杂,也就是说
P
(
x
)
P(x)
P(x)是难以求解的,好在我们的目标并不是求出
P
(
x
)
P(x)
P(x),而是对应的梯度(也称为分数)
L
=
1
2
E
P
d
a
t
a
(
x
)
[
∣
∣
s
θ
(
x
)
−
∇
x
log
P
d
a
t
a
(
x
)
∣
∣
2
2
]
(4)
L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x)}\left[||s_\theta(x)-\nabla_x\log P_{data}(x)||_2^2\right]\tag{4}
L=21EPdata(x)[∣∣sθ(x)−∇xlogPdata(x)∣∣22](4)
P
d
a
t
a
P_{data}
Pdata表示训练数据所服从的分布
也就是通过最小化上式,便可得到 s θ ( x ) ≈ ∇ x log P d a t a ( x ) s_\theta(x)\approx \nabla_x\log P_{data}(x) sθ(x)≈∇xlogPdata(x)。
4、问题
理论上,我们直接求解Eq.(4)就可以了,但是,我们样本所服从的分布往往是服从,概率分布中往往存在一些低密度区域,那么对应的样本就很少。
而样本少,意味着对应为止的梯度分数,得不到很好的训练,那么神经网络在那些样本点就很容易估不准。作者博客给出了一张很形象的图像(图像来自参考①)
可以看到,数据的密度分别都在左下角和右上角,那么这些区域就能够用神经网络得到很好的拟合,对应Accurate区域。相反,低密度区域,没有得到很好的拟合,对应Inaccurate区域。
当我们使用郎之万动力采样的时候,随机初始化一个 x 0 x_0 x0,它落在低密度区域的概率非常之高。而低密度的区域没有经过很好的训练,所以郎之万动力采样在短时间内很难得到较好的结果。
那么,该如何解决这个问题呢?一个很好的方法就是——加噪声
我们通过对图像加入随机扰动噪声,会填充原本的低密度区域,从而让整个区域看起来较为的均匀(图像来自参考①)
也就是这样,让原本的密度点扩张开来。
加噪的过程我们可以表示为 x ~ = x + σ z \tilde x=x+\sigma z x~=x+σz。 x x x表示原始图像, x ~ \tilde x x~表示加噪后的图像。
我们用 q ( x ~ ∣ x ) ∼ N ( x , σ 2 I ) q(\tilde x|x)\sim N(x,\sigma^2I) q(x~∣x)∼N(x,σ2I)去表示这个加噪过程
于是Eq.(3)就可以变成
L
=
1
2
E
P
d
a
t
a
(
x
)
,
x
~
∼
N
(
x
,
σ
2
I
)
[
∣
∣
s
θ
(
x
+
σ
z
)
−
∇
x
~
log
q
(
x
~
∣
x
)
∣
∣
2
2
]
(5)
L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)-\nabla_{\tilde x}\log q(\tilde x|x)||_2^2\right]\tag{5}
L=21EPdata(x),x~∼N(x,σ2I)[∣∣sθ(x+σz)−∇x~logq(x~∣x)∣∣22](5)
emmmm,我感觉这样讲貌似挺合理的,但是它是需要证明的,也就是证明Eq.(4)、Eq.(5)的优化等价性。我就不证明了,证明过程在参考论文②,并不难,读者自己看一下就知道了
除此之外,真正导致需要加噪的,其实有其他原因,我只讲了其中一个。其他原因请看参考②,里面讲的非常之详细。我也懒得写了
如果我们加的噪声足够小,那么 P d a t a ( x ) ≈ q ( x ) P_{data}(x)\approx q(x) Pdata(x)≈q(x)
现在,我们预测的是加噪后的梯度分数,通过加噪的过程,也避免了直接求解 P ( x ) P(x) P(x)的问题。那我们来看一下这个等式可以变成什么吧
因为
q
(
x
~
∣
x
)
q(\tilde x|x)
q(x~∣x)是服从高斯分布的,是完全可以求出来的,所以梯度为
∇
x
~
log
q
(
x
~
∣
x
)
=
∇
x
~
log
1
2
π
σ
2
d
exp
{
−
∣
∣
x
~
−
x
∣
∣
2
2
σ
2
}
=
∇
x
~
(
log
1
2
π
σ
2
d
−
∣
∣
x
~
−
x
∣
∣
2
2
σ
2
)
=
−
2
(
x
~
−
x
)
2
σ
2
=
−
x
~
−
x
σ
2
=
−
z
σ
\begin{aligned}\nabla_{\tilde x}\log q(\tilde x|x)=&\nabla_{\tilde x}\log \frac{1}{\sqrt{2\pi\sigma^2}^d}\exp \left\{-\frac{||\tilde x-x||^2}{2\sigma^2}\right\}\\=&\nabla_{\tilde x}\left(\log \frac{1}{\sqrt{2\pi\sigma^2}^d}-\frac{||\tilde x-x||^2}{2\sigma^2}\right)\\=&-\frac{2(\tilde x-x)}{2\sigma^2}\\=&-\frac{\tilde x -x}{\sigma^2}\\=&-\frac{z}{\sigma}\end{aligned}
∇x~logq(x~∣x)=====∇x~log2πσ2d1exp{−2σ2∣∣x~−x∣∣2}∇x~(log2πσ2d1−2σ2∣∣x~−x∣∣2)−2σ22(x~−x)−σ2x~−x−σz
所以损失函数就可以变成
L
=
1
2
E
P
d
a
t
a
(
x
)
,
x
~
∼
N
(
x
,
σ
2
I
)
[
∣
∣
s
θ
(
x
+
σ
z
)
+
x
~
−
x
σ
2
∣
∣
2
2
]
L=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)+\frac{\tilde x -x}{\sigma^2}||_2^2\right]
L=21EPdata(x),x~∼N(x,σ2I)[∣∣sθ(x+σz)+σ2x~−x∣∣22]
按理说,我们只需要最优化这个目标函数即可。
可问题又来了
我们该如何加入噪声呢?加多少?加的小了,低密度区域没有得到很好的填充。加多了,直接改变原本的数据分布了,这显然也不行。
我们干脆一不做二不休,我们加多个量级噪声,不同量级都进行训练。
当训练完成之后,就得到了不同噪声强度的噪声条件得分网络。
假设不同强度等级的噪声有S个, { σ i } i = 1 S \{\sigma_i\}_{i=1}^S {σi}i=1S,我们看一张图(里面显示了三个噪声强度的情况,图像来自参考①)
那么进行采样的时候,就可以从高强度的噪声,进行郎之万动力采样,然后慢慢降低噪声的强度。总而言之,就是每个噪声强度,都进行一轮郎之万动力采样,比如下图(图像来自参考①)(Gif图像太大,上传不了…看视频里面吧)
假设有S个噪声强度,那么就可以变成
L
=
1
S
∑
i
=
1
S
λ
i
1
2
E
P
d
a
t
a
(
x
)
,
x
~
∼
N
(
x
,
σ
i
2
I
)
[
∣
∣
s
θ
(
x
+
σ
i
z
,
σ
i
)
+
x
~
i
−
x
σ
i
2
∣
∣
2
2
]
L=\frac{1}{S}\sum\limits_{i=1}^S\lambda_i\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma_i^2I)}\left[||s_\theta(x+\sigma_i z,\sigma_i)+\frac{\tilde x_i -x}{\sigma_i^2}||_2^2\right]
L=S1i=1∑Sλi21EPdata(x),x~∼N(x,σi2I)[∣∣sθ(x+σiz,σi)+σi2x~i−x∣∣22]
x
~
i
\tilde x_i
x~i表示在噪声强度为
σ
i
\sigma_i
σi的加噪图像。
λ
i
\lambda_i
λi代表的是一个加权系数.一般情况下,我们取
λ
i
=
σ
i
2
\lambda_i=\sigma^2_i
λi=σi2。
对于噪声强度数量,一般是数百到数千;噪声强度选择一般采用几何级数。
采样的时候正如前面所说,先在高强度噪声量级进行郎之万动力采样,而后慢慢降低,所以采样方法为
5、结束
好了,本篇文章到此为止,如有问题,还望指出,阿里嘎多!!!
6、参考
①Generative Modeling by Estimating Gradients of the Data Distribution | Yang Song (yang-song.net)
②基于分数的生成模型(Score-based generative models) — 张振虎的博客 张振虎 文档 (zhangzhenhu.com)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)