【优化】近端梯度下降(Proximal Gradient Descent)求解Lasso线性回归问题
近端梯度下降近端梯度下降(Proximal Gradient Descent, PGD)是众多梯度下降算法中的一种,与传统的梯度下降算法以及随机梯度下降算法相比,近端梯度下降算法的使用范围相对狭窄,对于凸优化问题,PGD常用与目标函数中包含不可微分项时,如L1L1L1范数、迹范数或者全变正则项等。常见线性回归问题很多优化问题都可以转换为线性回归问题,假设线性回归的表达式是y=Xwy = Xw...
近端梯度下降的背景
近端梯度下降(Proximal Gradient Descent, PGD)是众多梯度下降算法中的一种,与传统的梯度下降算法以及随机梯度下降算法相比,近端梯度下降算法的使用范围相对狭窄,对于凸优化问题,PGD常用与目标函数中包含不可微分项时,如 L 1 L1 L1范数、迹范数或者全变正则项等。
常见线性回归问题
很多优化问题都可以转换为线性回归问题,假设线性回归的表达式是 y = X w y = Xw y=Xw其中 y ∈ R m y \in R^{m} y∈Rm, X ∈ R m × n X \in R^{m \times n} X∈Rm×n,是已知的, w ∈ R n w \in R^{n} w∈Rn表示参数向量,是未知的。根据应用场景不同,变量和参数具有的意义也不同。
最常见的线性回归模型的目标函数即可表示为: f ( w ) = 1 2 ∣ ∣ y − X w ∣ ∣ 2 2 f(w) = \frac{1}{2}||y - Xw||^{2}_{2} f(w)=21∣∣y−Xw∣∣22求解上述优化问题可通过最小二乘法或者梯度下降的方法。在实际情况中,我们通常会对参数向量 w w w进行限制,如为了减小模型的复杂度,会要求参数向量为稀疏的形式,此时会加入 L 1 L1 L1正则项;为了提高模型的泛化能力,会要求参数比较小,此时会加入 L 2 L2 L2正则项,则得到的回归模型分别为Lasso回归和Ridge回归模型。
Lasso回归模型的目标函数:
f
(
w
)
=
1
2
∣
∣
y
−
X
w
∣
∣
2
2
+
λ
∣
∣
w
∣
∣
1
f(w) = \frac{1}{2}||y - Xw||^{2}_{2}+\lambda||w||_{1}
f(w)=21∣∣y−Xw∣∣22+λ∣∣w∣∣1Ridge回归模型的目标函数:
f
(
w
)
=
f
(
w
)
=
1
2
∣
∣
y
−
X
w
∣
∣
2
2
+
β
∣
∣
w
∣
∣
2
2
f(w) = f(w) = \frac{1}{2}||y - Xw||^{2}_{2} + \beta||w||_{2}^{2}
f(w)=f(w)=21∣∣y−Xw∣∣22+β∣∣w∣∣22
对于Lasso回归模型的目标函数, ∣ ∣ w ∣ ∣ 1 ||w||_{1} ∣∣w∣∣1是一个凸函数,并且是不可微的,传统的梯度下降则通常要求目标函数是可微的,所以为了解决含有不可微凸函数项的目标函数优化问题,近端梯度下降算法就此提出。近端梯度下降主要解决的问题可表示为: m i n f ( w ) = m i n { g ( w ) + h ( w ) } min f(w) = min \{g(w) + h(w)\} minf(w)=min{g(w)+h(w)}其中 g ( w ) g(w) g(w)是凸函数,并且可微; h ( w ) h(w) h(w)也是凸函数,但是在某些地方不可微,对应于Lasso回归模型中就是 ∣ ∣ w ∣ ∣ 1 ||w||_{1} ∣∣w∣∣1项。
近端算子(Proximal Operator)
在介绍近端梯度下降之前,我们需要先引入近端算子的概念。近端算子是一种映射,并且它只和不可微的凸函数 h ( w ) h(w) h(w)有关,它的表现形式是: p r o x h ( w ) = a r g m i n u { h ( u ) + 1 2 ∣ ∣ u − w ∣ ∣ 2 2 } prox_{h}(w) = arg \mathop{min} \limits_{u}\{h(u) + \frac{1}{2}||u - w||_{2}^{2}\} proxh(w)=argumin{h(u)+21∣∣u−w∣∣22}其中 p r o x h ( w ) prox_{h}(w) proxh(w)表示变量 w w w和函数 h ( . ) h(.) h(.)的近端算子。上面的公式的意义是:对于任意给定的 w ∈ R n w \in R^{n} w∈Rn,我们希望找到使得 h ( u ) + 1 2 ∣ ∣ u − w ∣ ∣ 2 2 h(u) + \frac{1}{2}||u - w||_{2}^{2} h(u)+21∣∣u−w∣∣22最小化的解。若 u = p r o x h ( w ) u = prox_{h}(w) u=proxh(w)为最优解,则这个解的意义是,当我们知道存在不可微点的函数 h ( w ) h(w) h(w)在点 w w w处不可微时,则我们就去找一个点 u u u,这个点 u u u不仅仅使得函数 h ( w ) h(w) h(w)取得较小的值,还非常接近不可微分点 w w w。
通常在通过近端算子进行迭代递推时,会引入一个迭代步长
t
t
t,即:
p
r
o
x
h
(
.
)
,
t
(
w
)
=
a
r
g
m
i
n
u
{
h
(
u
)
+
1
2
t
∣
∣
u
−
w
∣
∣
2
2
}
prox_{h(.),t}(w) = arg \mathop{min} \limits_{u}\{h(u) + \frac{1}{2t}||u - w||_{2}^{2}\}
proxh(.),t(w)=argumin{h(u)+2t1∣∣u−w∣∣22}
特别地,当 h ( w ) = λ ∣ ∣ x ∣ ∣ 1 h(w) = \lambda ||x||_{1} h(w)=λ∣∣x∣∣1时, p r o x h ( w ) prox_{h}(w) proxh(w)就是所谓的软阈值函数(soft thresholding function),即 p r o x h ( w ) = s o f t λ ( w ) prox_{h}(w) = soft_{\lambda}(w) proxh(w)=softλ(w),其中 s o f t λ ( w ) = s g n ( w ) ( ∣ w ∣ − λ ) + = { w − λ , w ≥ λ 0 , ∣ w ∣ ≤ λ w + λ , w ≤ − λ soft_{\lambda}(w) = sgn(w)(|w| - \lambda)_{+} = \left \{ \begin{aligned} &w - \lambda, &w \geq \lambda \\ &0, &|w| \leq \lambda \\ &w+\lambda, &w\leq -\lambda \end{aligned} \right. softλ(w)=sgn(w)(∣w∣−λ)+=⎩⎪⎨⎪⎧w−λ,0,w+λ,w≥λ∣w∣≤λw≤−λ加入迭代步长 t t t之后的形式是: s o f t λ , t ( w ) = s g n ( w ) ( ∣ w ∣ − λ t ) + = { w − λ t , w ≥ λ t 0 , ∣ w ∣ ≤ λ t w + λ t , w ≤ − λ t soft_{\lambda, t}(w) = sgn(w)(|w| - \lambda t)_{+} = \left \{ \begin{aligned} &w - \lambda t, &w \geq \lambda t \\ &0, &|w| \leq \lambda t \\ &w+\lambda t, &w\leq -\lambda t \end{aligned} \right. softλ,t(w)=sgn(w)(∣w∣−λt)+=⎩⎪⎨⎪⎧w−λt,0,w+λt,w≥λt∣w∣≤λtw≤−λt
软阈值算子计算时针对的是向量
w
w
w的分量形式。软阈值函数的图像形式是:
近端梯度下降迭代递推方法
对于问题优化 a r g m i n w f ( w ) = g ( w ) + h ( w ) arg\mathop{min} \limits_{w} f(w) = g(w) + h(w) argwminf(w)=g(w)+h(w),通过近端梯度下降算法进行迭代求解时,变量 w w w的迭代递推公式是: w k = p r o x t , h ( . ) ( w k − 1 − t ∇ g ( w k − 1 ) ) w_{k} = prox_{t, h(.)}(w_{k-1} - t \nabla g(w_{k-1})) wk=proxt,h(.)(wk−1−t∇g(wk−1))其中, w w w的下标表示迭代次数, t t t表示迭代步长。
下面简单介绍如何进行证明。首先,在每一步进行迭代中,近端梯度下降将点 w k − 1 w_{k-1} wk−1处的近似函数取得最小值的点作为下一次迭代的起始点 w k w_{k} wk。对于 f ( w ) f(w) f(w)在点 w k − 1 w_{k-1} wk−1处的近似函数可以通过泰勒公式以及Lipschitz continuous gradient进行二阶近似,即 Q ( w , w k − 1 ) = g ( w k − 1 ) + < ∇ g ( w k − 1 ) , w − w k − 1 > + L 2 ∣ ∣ w − w k − 1 ∣ ∣ 2 2 + h ( w ) Q(w, w_{k-1}) = g(w_{k-1}) + <\nabla g(w_{k-1}), w - w_{k-1}> + \frac{L}{2}||w - w_{k-1}||_{2}^{2} + h(w) Q(w,wk−1)=g(wk−1)+<∇g(wk−1),w−wk−1>+2L∣∣w−wk−1∣∣22+h(w)
所以我们即是需要证明: w k = p r o x t , h ( . ) ( w k − 1 − t ∇ g ( w k − 1 ) ) = a r g m i n w Q ( w , w k − 1 ) w_{k} =prox_{t, h(.)}(w_{k-1} - t \nabla g(w_{k-1}))= arg\mathop{min} \limits_{w}Q(w, w_{k-1}) wk=proxt,h(.)(wk−1−t∇g(wk−1))=argwminQ(w,wk−1)
接着,我们将软阈值算子进行展开:
w
k
=
p
r
o
x
t
,
h
(
.
)
(
w
k
−
1
−
t
∇
g
(
w
k
−
1
)
)
=
a
r
g
m
i
n
w
h
(
w
)
+
1
2
t
∣
∣
w
−
(
w
k
−
1
−
t
∇
g
(
w
k
−
1
)
)
∣
∣
2
2
=
a
r
g
m
i
n
w
h
(
w
)
+
t
2
∣
∣
∇
g
(
w
k
−
1
)
∣
∣
2
2
+
<
∇
g
(
w
k
−
1
)
,
w
−
w
k
−
1
>
+
1
2
t
∣
∣
w
−
w
k
−
1
∣
∣
2
2
=
a
r
g
m
i
n
w
h
(
w
)
+
g
(
w
k
−
1
)
+
<
∇
g
(
w
k
−
1
)
,
w
−
w
k
−
1
>
+
1
2
t
∣
∣
w
−
w
k
−
1
∣
∣
2
2
\begin{aligned} w_{k} &= prox_{t, h(.)}(w_{k-1} - t \nabla g(w_{k-1}))\\&= arg\mathop{min} \limits_{w} h(w) + \frac{1}{2t}||w - (w_{k-1} - t \nabla g(w_{k-1}))||_{2}^{2} \\&=arg\mathop{min} \limits_{w} h(w)+ \frac{t}{2}||\nabla g(w_{k-1})||_{2}^{2}+<\nabla g(w_{k-1}), w - w_{k-1}> + \frac{1}{2t}||w - w_{k-1}||_{2}^{2} \\&=arg\mathop{min} \limits_{w} h(w)+ g(w_{k-1}) + <\nabla g(w_{k-1}), w - w_{k-1}> + \frac{1}{2t}||w - w_{k-1}||_{2}^{2}\end{aligned}
wk=proxt,h(.)(wk−1−t∇g(wk−1))=argwminh(w)+2t1∣∣w−(wk−1−t∇g(wk−1))∣∣22=argwminh(w)+2t∣∣∇g(wk−1)∣∣22+<∇g(wk−1),w−wk−1>+2t1∣∣w−wk−1∣∣22=argwminh(w)+g(wk−1)+<∇g(wk−1),w−wk−1>+2t1∣∣w−wk−1∣∣22因为
t
/
2
∣
∣
∇
g
(
w
k
−
1
)
∣
∣
2
2
t/2||\nabla g(w_{k-1})||_{2}^{2}
t/2∣∣∇g(wk−1)∣∣22是常数,与所求变量
w
w
w无关,所以最后两步是等价的。
又因为: w k = a r g m i n w Q ( w , w k − 1 ) = a r g m i n w h ( w ) + g ( w k − 1 ) + < ∇ g ( w k − 1 ) , w − w k − 1 > + L 2 ∣ ∣ w − w k − 1 ∣ ∣ 2 2 \begin{aligned} w_{k} &= arg\mathop{min} \limits_{w}Q(w, w_{k-1}) \\ &=arg\mathop{min} \limits_{w} h(w)+ g(w_{k-1}) + <\nabla g(w_{k-1}), w - w_{k-1}> + \frac{L}{2}||w - w_{k-1}||_{2}^{2} \end{aligned} wk=argwminQ(w,wk−1)=argwminh(w)+g(wk−1)+<∇g(wk−1),w−wk−1>+2L∣∣w−wk−1∣∣22所以得证。并且从结果看,两者区别只是在于迭代步长的选取。其中 t = 1 / L t = 1/L t=1/L在理论上迭代速度最快的。
以Lasso线性回归问题为例
对于Lasso线性回归问题,即是求解
a
r
g
m
i
n
w
f
(
w
)
=
g
(
w
)
+
h
(
w
)
arg\mathop{min} \limits_{w} f(w) = g(w) + h(w)
argwminf(w)=g(w)+h(w),其中
g
(
w
)
=
1
2
∣
∣
y
−
X
w
∣
∣
2
2
g(w) = \frac{1}{2}||y - Xw||_{2}^{2}
g(w)=21∣∣y−Xw∣∣22,
h
(
w
)
=
λ
∣
∣
w
∣
∣
1
h(w) = \lambda||w||_{1}
h(w)=λ∣∣w∣∣1。
由近端算子以及近端梯度算法递推公式可知变量
w
w
w的迭代递推公式是:
w
k
=
p
r
o
x
t
,
h
(
.
)
(
w
k
−
1
−
t
∇
g
(
w
k
−
1
)
)
=
s
o
f
t
t
,
λ
(
w
k
−
1
−
t
∇
g
(
w
k
−
1
)
)
w_{k} = prox_{t, h(.)}(w_{k-1} - t \nabla g(w_{k-1}))=soft_{t,\lambda}(w_{k-1} - t \nabla g(w_{k-1}))
wk=proxt,h(.)(wk−1−t∇g(wk−1))=softt,λ(wk−1−t∇g(wk−1))其中,
∇
g
(
w
k
−
1
)
=
X
T
(
X
w
−
y
)
\nabla g(w_{k-1}) = X^{T}(Xw - y)
∇g(wk−1)=XT(Xw−y),则上式即:
w
k
=
p
r
o
x
t
,
h
(
.
)
(
w
k
−
1
−
t
∇
g
(
w
k
−
1
)
)
=
s
o
f
t
t
,
λ
(
w
k
−
1
−
t
X
T
X
w
+
t
X
T
y
)
w_{k} = prox_{t, h(.)}(w_{k-1} - t \nabla g(w_{k-1}))=soft_{t,\lambda}(w_{k-1} - tX^{T}Xw + tX^{T}y )
wk=proxt,h(.)(wk−1−t∇g(wk−1))=softt,λ(wk−1−tXTXw+tXTy)这里每次迭代中通过一个软阈值(收缩)的操作来更新
w
w
w,实际上就是迭代软阈值算法 (Iterative Soft-Thresholding Algorithm, ISTA),或者称为迭代阈值收缩算法(Iterative Shrinkage Thresholding Algorithm, ISTA)。
参考资料
机器学习 | 近端梯度下降法 (proximal gradient descent)
LASSO回归与L1正则化 西瓜书
软阈值迭代算法(ISTA)和快速软阈值迭代算法(FISTA)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)