基本概念
- 重参数 (Reparameterization) 实际上是处理如下期望形式的目标函数的一种技巧:
上面的期望式可能在如下情形中出现 (e.g. VAE):假设我们在模型的前向传播过程中得到了随机变量
Z
Z
Z 的概率分布
p
θ
(
z
)
p_\theta(z)
pθ(z),其中
θ
\theta
θ 为模型参数,然后需要根据
p
θ
(
z
)
p_\theta(z)
pθ(z) 对随机变量
Z
Z
Z 进行采样,再根据采样得到的值
z
z
z 完成后续的前向传播过程,例如计算训练损失
f
(
z
)
f(z)
f(z). 此时,训练损失
L
θ
L_\theta
Lθ 即可写为上述期望的形式。然而,这里存在一个很大的问题,就是采样操作是不可导的,虽然我们可以完成模型的前向传播,但反向传播时却无法计算出梯度
∂
L
θ
/
∂
θ
\partial L_\theta/\partial \theta
∂Lθ/∂θ,也就无法进行模型的训练。而 Reparameterization 则是提供了一种变换,使得我们可以直接从
p
θ
(
z
)
p_θ(z)
pθ(z) 中采样,并且保留
θ
θ
θ 的梯度,也就是将采样操作由不可导变为可导 - 重参数假设从分布
p
θ
(
z
)
p_θ(z)
pθ(z) 中采样可以分解为两个步骤:(1) 从无参数分布
q
(
ε
)
q(ε)
q(ε) 中采样一个
ε
ε
ε;(2) 通过变换
z
=
g
θ
(
ε
)
z=g_θ(ε)
z=gθ(ε) 生成
z
z
z。那么,上述期望就变成了
这时候被采样的分布就没有任何参数了,全部被转移到
f
f
f 内部了,因此可以采样若干个点,当成普通的 loss 那样写下来了 (上述重参数过程假定
p
θ
(
z
)
=
∫
g
θ
(
ε
)
=
z
q
(
ε
)
d
ε
=
∫
δ
(
z
−
g
θ
(
ε
)
)
q
(
ε
)
d
ε
p_θ(z)=∫_{g_θ(ε)=z}q(ε)dε=∫δ(z−g_θ(ε))q(ε)dε
pθ(z)=∫gθ(ε)=zq(ε)dε=∫δ(z−gθ(ε))q(ε)dε,
δ
(
⋅
)
δ(⋅)
δ(⋅) 是狄拉克函数,因此有
L
θ
=
E
z
∼
p
θ
(
z
)
[
f
(
z
)
]
=
∬
q
(
ε
)
δ
(
z
−
g
θ
(
ε
)
)
f
(
z
)
d
ε
d
z
=
∫
q
(
ε
)
f
(
g
θ
(
ε
)
)
d
ε
=
E
ε
∼
q
(
ε
)
[
f
(
g
θ
(
ε
)
)
]
L_\theta=\mathbb E_{z∼p_θ(z)}[f(z)]=\iint q(\varepsilon)\delta(z - g_{\theta}(\varepsilon)) f(z)d\varepsilon dz=\int q(\varepsilon) f(g_{\theta}(\varepsilon))d\varepsilon=\mathbb{E}_{\varepsilon\sim q(\varepsilon)}[f(g_{\theta}(\varepsilon))]
Lθ=Ez∼pθ(z)[f(z)]=∬q(ε)δ(z−gθ(ε))f(z)dεdz=∫q(ε)f(gθ(ε))dε=Eε∼q(ε)[f(gθ(ε))])
连续情形
- 简单起见,我们先考虑
z
z
z 为连续随机变量的情形:
在 VAE 中常见的是正态分布
p
θ
(
z
)
=
N
(
z
;
μ
θ
,
σ
θ
2
)
p_{\theta}(z)=\mathcal{N}\left(z;\mu_{\theta},\sigma_{\theta}^2\right)
pθ(z)=N(z;μθ,σθ2) - 总的来说,连续情形的重参数还是比较简单的。从数学本质来看,重参数是一种积分变换,即原来是关于
z
z
z 积分,通过
z
=
g
θ
(
ε
)
z=g_θ(ε)
z=gθ(ε) 变换之后得到新的积分形式。一个最简单的例子就是正态分布:对于正态分布来说,重参数就是 “从
N
(
z
;
μ
θ
,
σ
θ
2
)
N(z;μ_θ,σ^2_θ)
N(z;μθ,σθ2) 中采样一个
z
z
z” 变成 “从
N
(
ε
;
0
,
1
)
N(ε;0,1)
N(ε;0,1) 中采样一个
ε
ε
ε,然后计算
ε
×
σ
θ
+
μ
θ
ε×σ_θ+μ_θ
ε×σθ+μθ”,所以
离散情形
- 为了突出 “离散”,我们将随机变量
z
z
z 换成
y
y
y,即对于离散情形要面对的目标函数是
此时,
p
θ
(
y
)
p_\theta(y)
pθ(y) 是一个
k
k
k 分类模型: - 看到上述期望项中的求和,第一反应可能是 “求和?那就求呗,又不是求不了”。的确,对于离散的随机变量,其期望只不过是有限项求和,理论上确实可以直接完成求和再去梯度下降。但是,如果
k
k
k 特别大呢?举个例子,假设
y
y
y 是一个 100 维的向量,每个元素不是 0 就是 1,那么所有不同的
y
y
y 的总数目就是
2
100
2^{100}
2100,要对这样的
2
100
2^{100}
2100 个单项进行求和,计算量是难以接受的 (每一项都需要计算前向传播过程
f
(
y
)
f(y)
f(y))。所以,还是需要回到采样上去,如果能够采样若干个点就能得到期望的有效估计,并且还不损失梯度信息,那自然是最好了
Gumbel Max
- 为此,需要先引入 Gumbel Max。假设每个类别的概率是
p
1
,
p
2
,
…
,
p
k
p_1,p_2,…,p_k
p1,p2,…,pk,那么 Gumbel Max 提供了一种依概率采样类别的方案:
也就是说,先算出各个概率的对数
log
p
i
\log p_i
logpi,然后从均匀分布
U
[
0
,
1
]
U[0,1]
U[0,1] 中采样
k
k
k 个随机数
ε
1
,
…
,
ε
k
ε_1,…,ε_k
ε1,…,εk,把
g
i
=
−
log
(
−
log
ε
i
)
∼
Gumbel(0,1)
g_i=−\log(−\log ε_i)\sim\text{Gumbel(0,1)}
gi=−log(−logεi)∼Gumbel(0,1) 加到
log
p
i
\log p_i
logpi 上去,最后把最大值对应的类别抽取出来就行了。由于现在的随机性已经转移到
U
[
0
,
1
]
U[0,1]
U[0,1] 上去了,并且
U
[
0
,
1
]
U[0,1]
U[0,1] 不带有未知参数,因此 Gumbel Max 就是离散分布的一个重参数过程 - 可以证明,这样的过程精确等价于依概率
p
1
,
p
2
,
…
,
p
k
p_1,p_2,…,p_k
p1,p2,…,pk 采样一个类别,换句话说,在 Gumbel Max 中,输出
i
i
i 的概率正好是
p
i
p_i
pi. 不失一般性,这里我们证明输出 1 的概率是
p
1
p_1
p1. 注意,输出 1 意味着
log
p
1
−
l
o
g
(
−
l
o
g
ε
1
)
\log p_1−log(−logε_1)
logp1−log(−logε1) 是最大的,这又意味着:
log
p
1
−
log
(
−
log
ε
1
)
>
log
p
2
−
log
(
−
log
ε
2
)
log
p
1
−
log
(
−
log
ε
1
)
>
log
p
3
−
log
(
−
log
ε
3
)
⋮
log
p
1
−
log
(
−
log
ε
1
)
>
log
p
k
−
log
(
−
log
ε
k
)
logp1−log(−logε1)>logp2−log(−logε2)logp1−log(−logε1)>logp3−log(−logε3)⋮logp1−log(−logε1)>logpk−log(−logεk)不失一般性,我们只分析第一个不等式,化简后得到:
ε
2
<
ε
1
p
2
/
p
1
≤
1
\varepsilon_2 < \varepsilon_1^{p_2 / p_1}\leq 1
ε2<ε1p2/p1≤1由于
ε
2
∼
U
[
0
,
1
]
ε_2∼U[0,1]
ε2∼U[0,1],所以
ε
2
<
ε
1
p
2
/
p
1
ε_2<ε^{p_2/p_1}_1
ε2<ε1p2/p1 的概率就是
ε
1
p
2
/
p
1
ε^{p_2/p_1}_1
ε1p2/p1,这就是固定
ε
1
ε_1
ε1 的情况下,第一个不等式成立的概率。那么,所有不等式同时成立的概率是
ε
1
p
2
/
p
1
ε
1
p
3
/
p
1
…
ε
1
p
k
/
p
1
=
ε
1
(
p
2
+
p
3
+
⋯
+
p
k
)
/
p
1
=
ε
1
(
1
/
p
1
)
−
1
\varepsilon_1^{p_2 / p_1}\varepsilon_1^{p_3 / p_1}\dots \varepsilon_1^{p_k / p_1}=\varepsilon_1^{(p_2 + p_3 + \dots + p_k) / p_1}=\varepsilon_1^{(1/p_1)-1}
ε1p2/p1ε1p3/p1…ε1pk/p1=ε1(p2+p3+⋯+pk)/p1=ε1(1/p1)−1然后对所有
ε
1
ε_1
ε1 求平均,就是
∫
0
1
ε
1
(
1
/
p
1
)
−
1
d
ε
1
=
p
1
\int_0^1 \varepsilon_1^{(1/p_1)-1}d\varepsilon_1 = p_1
∫01ε1(1/p1)−1dε1=p1
Gumbel Softmax
- 我们希望重参数不丢失梯度信息,但是 Gumbel Max 做不到,因为
arg max
\argmax
argmax 不可导,为此,需要做进一步的近似。首先,留意到在神经网络中,处理离散输入的基本方法是转化为 one hot 形式,包括 Embedding 层的本质也是 one hot 全连接,因此
arg max
\argmax
argmax 实际上是
one_hot
(
arg max
)
\text{one_hot}(\argmax)
one_hot(argmax),然后,我们寻求
one_hot
(
arg max
)
\text{one_hot}(\argmax)
one_hot(argmax) 的光滑近似,它就是
s
o
f
t
m
a
x
softmax
softmax. 由此,我们得到 Gumbel Max 的光滑近似版本——Gumbel Softmax:
其中参数
τ
>
0
τ>0
τ>0 称为退火参数,它越小输出结果就越接近 one hot 形式 (但同时梯度消失就越严重)。提示一个小技巧,如果
p
i
p_i
pi 是
s
o
f
t
m
a
x
softmax
softmax 的输出,那么大可不必先算出
p
i
p_i
pi 再取对数,直接将
log
p
i
\log p_i
logpi 替换为
o
i
o_i
oi 即可: - 跟连续情形一样,Gumbel Softmax 就是用在需要求
E
y
∼
p
θ
(
y
)
[
f
(
y
)
]
\mathbb{E}_{y\sim p_{\theta}(y)}[f(y)]
Ey∼pθ(y)[f(y)]、且无法直接完成对
y
y
y 求和的场景,这时候我们算出
p
θ
(
y
)
p_θ(y)
pθ(y)(或者
o
i
o_i
oi),然后选定一个
τ
>
0
τ>0
τ>0,用 Gumbel Softmax 算出一个随机向量来
y
~
\tilde y
y~,代入计算得到
f
(
y
~
)
f(\tilde y)
f(y~),它就是
E
y
∼
p
θ
(
y
)
[
f
(
y
)
]
\mathbb{E}_{y\sim p_{\theta}(y)}[f(y)]
Ey∼pθ(y)[f(y)] 的一个好的近似,且保留了梯度信息 - 注意,Gumbel Softmax 不是类别采样的等价形式,Gumbel Max 才是。而 Gumbel Max 可以看成是 Gumbel Softmax 在
τ
→
0
τ→0
τ→0 时的极限。当
τ
τ
τ 比较小时,Gumbel Softmax 采样得到的样本接近 one-hot vector,也就比较接近实际的采样情况,但梯度的方差比较大;当
τ
τ
τ 比较大时,Gumbel Softmax 采样得到的样本比较平滑 (一个平滑的概率向量,向量的各个分量的值都差不多),但梯度的方差比较小。所以在应用 Gumbel Softmax 时,开始可以选择较大的
τ
τ
τ(比如 1),然后慢慢退火到一个接近于 0 的数(比如 0.01),这样才能得到比较好的结果
Gumbel Softmax v.s. Softmax
- Gumbel Softmax 通过
τ
→
0
τ→0
τ→0 的退火来逐渐逼近 one hot,相比直接用原始的 Softmax 进行退火,区别在于原始 Softmax 退火只能得到最大值位置为 1 的 one hot 向量,而 Gumbel Softmax 有概率得到非最大值位置的 one hot 向量,增加了随机性,会使得基于采样的训练更充分一些
Straight-Through Gumbel-Softmax Estimator
- 由 Gumbel Softmax 得到的采样样本是实际采样样本的一个近似,它甚至都不在离散变量的取值范围之内,即使
τ
τ
τ 比较小,Gumbel Softmax 采样得到的样本也只是接近 one-hot vector,而非真正离散化的 one-hot vector. 但总存在那么一些场景,我们只想采样离散值而非连续值 (e.g. RL 中从离散的动作空间中采样) - 假设 Gumbel Softmax 输出的采样向量为
y
y
y,为了利用 Gumbel Softmax 采样离散值,我们可以在前向传播时使用
z
=
one_hot
(
arg max
y
)
z=\text{one_hot}(\argmax y)
z=one_hot(argmaxy) 得到离散的采样值,在反向传播时利用
∇
θ
z
≈
∇
θ
y
\nabla_\theta z\approx \nabla_\theta y
∇θz≈∇θy,对
∇
θ
y
\nabla_\theta y
∇θy 进行梯度回传:
z
=
y
+
s
g
(
one_hot
(
arg max
y
)
−
y
)
z=y+sg(\text{one_hot}(\argmax y)-y)
z=y+sg(one_hot(argmaxy)−y)其中,
s
g
sg
sg 为 stop gradient 操作
背后的故事: 梯度估计 (gradient estimator)
- 重参数就这样介绍完了吗?远远没有,重参数的背后,实际上是一个称为 “梯度估计”的 大家族,而重参数只不过是这个大家族中的一员。每年的 ICLR、ICML 等顶会上搜索gradient estimator、REINFORCE 等关键词,可以搜索到不少文章,说明这是个大家还在钻研的课题。要想说清重参数的来龙去脉,也要说些梯度估计的故事
SF 估计 (Score Function Estimator)
- 前面我们分别讲了连续型和离散型的重参数,都是在 “loss 层面” 讲述的,也就是说都是想办法把 loss 显式地定义好,剩下的交给框架自动求导、自动优化就是了。而事实上,就算不能显式地写出 loss 函数,也不妨碍我们对它求导,自然也不妨碍我们去用梯度下降了。比如 Score Function Estimator:
这是对原来损失函数的最朴素的估计,在强化学习中
z
z
z 代表着策略,那么上式就是一个最基本的策略梯度,所以有时候也直接称上述估计为叫 REINFORCE。现在我们可以直接从
p
θ
(
z
)
p_θ(z)
pθ(z) 中采样若干个点来估算
∂
L
θ
/
∂
θ
\partial L_\theta/\partial \theta
∂Lθ/∂θ 的值了,不用担心会不会没梯度 - 同时注意到,重参数技巧要求
f
f
f 可导,但是在诸如强化学习的场景下,
f
(
z
)
f(z)
f(z) 对应着奖励函数,很难做到光滑可导,此时就必须使用 SF 估计
梯度方差
- SF 估计看上去很美好,得到了一个连续和离散变量都适用的估计式,那为什么还需要重参数呢?主要的原因是:SF 估计的方差太大。SF 估计是函数
f
(
z
)
∂
∂
θ
log
p
θ
(
z
)
f(z) \frac{\partial}{\partial\theta} \log p_{\theta}(z)
f(z)∂θ∂logpθ(z) 在分布
p
θ
(
z
)
p_θ(z)
pθ(z) 下的期望,我们要采样几个点来算 (理想情况下,希望只采样一个点),换句话说,我们想用下面的近似
于是问题就来了:这样的梯度估计方差很大,这导致了我们用梯度下降优化的时候相当不稳定,非常容易崩
降方差
- 重参数就是一种降方差技巧,为此,我们写出重参数后的梯度表达式:
对比 SF 估计,我们可以直观感知为什么上式方差更小了 (只是一般情况下,并不是绝对成立):(1) SF 估计中包含了
log
p
θ
(
z
)
\log p_θ(z)
logpθ(z),我们知道,作为一个合理的概率分布,一般都在无穷远处 (即
∥
z
∥
→
∞
∥z∥→∞
∥z∥→∞)都会有
p
θ
(
z
)
→
0
p_θ(z)→0
pθ(z)→0,取了
log
\log
log 之后反而会趋于负无穷,换句话说,
log
p
θ
(
z
)
\log p_θ(z)
logpθ(z) 这一项实际上放大了无穷远处的波动,从而一定程度上增加了方差;(2) SF 估计中包含的是
f
f
f 而重参数之后变成了
∂
f
/
∂
g
∂f/∂g
∂f/∂g,
f
f
f 一般是神经网络,而通常我们定义的神经网络模型其实都是
O
(
z
)
\mathscr O(z)
O(z) 级别的模型,从而我们可以预期它的梯度是
O
(
1
)
\mathscr O(1)
O(1) 级别的(不严格成立,只能说在平均意义下基本成立),所以相对情况下更平稳一些,因此
f
f
f 的方差也比
∂
f
/
∂
g
∂f/∂g
∂f/∂g 的方差要大
评论(0)
您还未登录,请登录后发表或查看评论