基本概念


  • 重参数 (Reparameterization) 实际上是处理如下期望形式的目标函数的一种技巧:
    在这里插入图片描述上面的期望式可能在如下情形中出现 (e.g. VAE):假设我们在模型的前向传播过程中得到了随机变量




    Z



    Z


    Z
    的概率分布





    p


    θ



    (


    z


    )



    p_\theta(z)


    pθ(z)
    ,其中




    θ



    \theta


    θ
    为模型参数,然后需要根据





    p


    θ



    (


    z


    )



    p_\theta(z)


    pθ(z)
    对随机变量




    Z



    Z


    Z
    进行采样,再根据采样得到的值




    z



    z


    z
    完成后续的前向传播过程,例如计算训练损失




    f


    (


    z


    )



    f(z)


    f(z)
    . 此时,训练损失





    L


    θ




    L_\theta


    Lθ
    即可写为上述期望的形式。然而,这里存在一个很大的问题,就是采样操作是不可导的,虽然我们可以完成模型的前向传播,但反向传播时却无法计算出梯度








    L


    θ



    /





    θ



    \partial L_\theta/\partial \theta


    Lθ/θ
    ,也就无法进行模型的训练。而 Reparameterization 则是提供了一种变换,使得我们可以直接从





    p


    θ



    (


    z


    )



    p_θ(z)


    pθ(z)
    中采样,并且保留




    θ



    θ


    θ
    的梯度
    ,也就是将采样操作由不可导变为可导
  • 重参数假设从分布





    p


    θ



    (


    z


    )



    p_θ(z)


    pθ(z)
    中采样可以分解为两个步骤:(1) 从无参数分布




    q


    (


    ε


    )



    q(ε)


    q(ε)
    中采样一个




    ε



    ε


    ε
    ;(2) 通过变换




    z


    =



    g


    θ



    (


    ε


    )



    z=g_θ(ε)


    z=gθ(ε)
    生成




    z



    z


    z
    。那么,上述期望就变成了
    在这里插入图片描述这时候被采样的分布就没有任何参数了,全部被转移到




    f



    f


    f
    内部了,因此可以采样若干个点,当成普通的 loss 那样写下来了 (上述重参数过程假定





    p


    θ



    (


    z


    )


    =








    g


    θ



    (


    ε


    )


    =


    z




    q


    (


    ε


    )


    d


    ε


    =





    δ


    (


    z






    g


    θ



    (


    ε


    )


    )


    q


    (


    ε


    )


    d


    ε



    p_θ(z)=∫_{g_θ(ε)=z}q(ε)dε=∫δ(z−g_θ(ε))q(ε)dε


    pθ(z)=gθ(ε)=zq(ε)dε=δ(zgθ(ε))q(ε)dε





    δ


    (





    )



    δ(⋅)


    δ()
    是狄拉克函数,因此有





    L


    θ



    =



    E



    z






    p


    θ



    (


    z


    )




    [


    f


    (


    z


    )


    ]


    =





    q


    (


    ε


    )


    δ


    (


    z






    g


    θ



    (


    ε


    )


    )


    f


    (


    z


    )


    d


    ε


    d


    z


    =





    q


    (


    ε


    )


    f


    (



    g


    θ



    (


    ε


    )


    )


    d


    ε


    =



    E



    ε





    q


    (


    ε


    )




    [


    f


    (



    g


    θ



    (


    ε


    )


    )


    ]



    L_\theta=\mathbb E_{z∼p_θ(z)}[f(z)]=\iint q(\varepsilon)\delta(z - g_{\theta}(\varepsilon)) f(z)d\varepsilon dz=\int q(\varepsilon) f(g_{\theta}(\varepsilon))d\varepsilon=\mathbb{E}_{\varepsilon\sim q(\varepsilon)}[f(g_{\theta}(\varepsilon))]


    Lθ=Ezpθ(z)[f(z)]=q(ε)δ(zgθ(ε))f(z)dεdz=q(ε)f(gθ(ε))dε=Eεq(ε)[f(gθ(ε))]
    )

连续情形


  • 简单起见,我们先考虑




    z



    z


    z
    为连续随机变量的情形:
    在这里插入图片描述在 VAE 中常见的是正态分布





    p


    θ



    (


    z


    )


    =


    N



    (


    z


    ;



    μ


    θ



    ,



    σ


    θ


    2



    )




    p_{\theta}(z)=\mathcal{N}\left(z;\mu_{\theta},\sigma_{\theta}^2\right)


    pθ(z)=N(z;μθ,σθ2)
  • 总的来说,连续情形的重参数还是比较简单的。从数学本质来看,重参数是一种积分变换,即原来是关于




    z



    z


    z
    积分,通过




    z


    =



    g


    θ



    (


    ε


    )



    z=g_θ(ε)


    z=gθ(ε)
    变换之后得到新的积分形式。一个最简单的例子就是正态分布:对于正态分布来说,重参数就是 “从




    N


    (


    z


    ;



    μ


    θ



    ,



    σ


    θ


    2



    )



    N(z;μ_θ,σ^2_θ)


    N(z;μθ,σθ2)
    中采样一个




    z



    z


    z
    ” 变成 “从




    N


    (


    ε


    ;


    0


    ,


    1


    )



    N(ε;0,1)


    N(ε;0,1)
    中采样一个




    ε



    ε


    ε
    ,然后计算




    ε


    ×



    σ


    θ



    +



    μ


    θ




    ε×σ_θ+μ_θ


    ε×σθ+μθ
    ”,所以
    在这里插入图片描述

离散情形


  • 为了突出 “离散”,我们将随机变量




    z



    z


    z
    换成




    y



    y


    y
    ,即对于离散情形要面对的目标函数是
    在这里插入图片描述此时,





    p


    θ



    (


    y


    )



    p_\theta(y)


    pθ(y)
    是一个




    k



    k


    k
    分类模型:
    在这里插入图片描述
  • 看到上述期望项中的求和,第一反应可能是 “求和?那就求呗,又不是求不了”。的确,对于离散的随机变量,其期望只不过是有限项求和,理论上确实可以直接完成求和再去梯度下降。但是,如果




    k



    k


    k
    特别大呢?举个例子,假设




    y



    y


    y
    是一个 100 维的向量,每个元素不是 0 就是 1,那么所有不同的




    y



    y


    y
    的总数目就是





    2


    100




    2^{100}


    2100
    ,要对这样的





    2


    100




    2^{100}


    2100
    个单项进行求和,计算量是难以接受的 (每一项都需要计算前向传播过程




    f


    (


    y


    )



    f(y)


    f(y)
    )。所以,还是需要回到采样上去,如果能够采样若干个点就能得到期望的有效估计,并且还不损失梯度信息,那自然是最好了

Gumbel Max


  • 为此,需要先引入 Gumbel Max。假设每个类别的概率是





    p


    1



    ,



    p


    2



    ,





    ,



    p


    k




    p_1,p_2,…,p_k


    p1,p2,,pk
    ,那么 Gumbel Max 提供了一种依概率采样类别的方案:
    在这里插入图片描述也就是说,先算出各个概率的对数




    log






    p


    i




    \log p_i


    logpi
    ,然后从均匀分布




    U


    [


    0


    ,


    1


    ]



    U[0,1]


    U[0,1]
    中采样




    k



    k


    k
    个随机数





    ε


    1



    ,





    ,



    ε


    k




    ε_1,…,ε_k


    ε1,,εk
    ,把





    g


    i



    =





    log





    (





    log






    ε


    i



    )





    Gumbel(0,1)



    g_i=−\log(−\log ε_i)\sim\text{Gumbel(0,1)}


    gi=log(logεi)Gumbel(0,1)
    加到




    log






    p


    i




    \log p_i


    logpi
    上去,最后把最大值对应的类别抽取出来就行了。由于现在的随机性已经转移到




    U


    [


    0


    ,


    1


    ]



    U[0,1]


    U[0,1]
    上去了,并且




    U


    [


    0


    ,


    1


    ]



    U[0,1]


    U[0,1]
    不带有未知参数,因此 Gumbel Max 就是离散分布的一个重参数过程
  • 可以证明,这样的过程精确等价于依概率





    p


    1



    ,



    p


    2



    ,





    ,



    p


    k




    p_1,p_2,…,p_k


    p1,p2,,pk
    采样一个类别,换句话说,在 Gumbel Max 中,输出




    i



    i


    i
    的概率正好是





    p


    i




    p_i


    pi
    . 不失一般性,这里我们证明输出 1 的概率是





    p


    1




    p_1


    p1
    . 注意,输出 1 意味着




    log






    p


    1






    l


    o


    g


    (





    l


    o


    g



    ε


    1



    )



    \log p_1−log(−logε_1)


    logp1log(logε1)
    是最大的,这又意味着:















    log






    p


    1






    log





    (





    log






    ε


    1



    )


    >


    log






    p


    2






    log





    (





    log






    ε


    2



    )
















    log






    p


    1






    log





    (





    log






    ε


    1



    )


    >


    log






    p


    3






    log





    (





    log






    ε


    3



    )





































    log






    p


    1






    log





    (





    log






    ε


    1



    )


    >


    log






    p


    k






    log





    (





    log






    ε


    k



    )







    logp1log(logε1)>logp2log(logε2)logp1log(logε1)>logp3log(logε3)logp1log(logε1)>logpklog(logεk)



    logp1log(logε1)>logp2log(logε2)logp1log(logε1)>logp3log(logε3)logp1log(logε1)>logpklog(logεk)
    不失一般性,我们只分析第一个不等式,化简后得到:






    ε


    2



    <



    ε


    1




    p


    2



    /



    p


    1








    1



    \varepsilon_2 < \varepsilon_1^{p_2 / p_1}\leq 1


    ε2<ε1p2/p11
    由于





    ε


    2






    U


    [


    0


    ,


    1


    ]



    ε_2∼U[0,1]


    ε2U[0,1]
    ,所以





    ε


    2



    <



    ε


    1




    p


    2



    /



    p


    1






    ε_2<ε^{p_2/p_1}_1


    ε2<ε1p2/p1
    的概率就是





    ε


    1




    p


    2



    /



    p


    1






    ε^{p_2/p_1}_1


    ε1p2/p1
    ,这就是固定





    ε


    1




    ε_1


    ε1
    的情况下,第一个不等式成立的概率。那么,所有不等式同时成立的概率是






    ε


    1




    p


    2



    /



    p


    1






    ε


    1




    p


    3



    /



    p


    1









    ε


    1




    p


    k



    /



    p


    1





    =



    ε


    1



    (



    p


    2



    +



    p


    3



    +





    +



    p


    k



    )


    /



    p


    1





    =



    ε


    1



    (


    1


    /



    p


    1



    )





    1





    \varepsilon_1^{p_2 / p_1}\varepsilon_1^{p_3 / p_1}\dots \varepsilon_1^{p_k / p_1}=\varepsilon_1^{(p_2 + p_3 + \dots + p_k) / p_1}=\varepsilon_1^{(1/p_1)-1}


    ε1p2/p1ε1p3/p1ε1pk/p1=ε1(p2+p3++pk)/p1=ε1(1/p1)1
    然后对所有





    ε


    1




    ε_1


    ε1
    求平均,就是









    0


    1




    ε


    1



    (


    1


    /



    p


    1



    )





    1




    d



    ε


    1



    =



    p


    1




    \int_0^1 \varepsilon_1^{(1/p_1)-1}d\varepsilon_1 = p_1


    01ε1(1/p1)1dε1=p1

Gumbel Softmax


  • 我们希望重参数不丢失梯度信息,但是 Gumbel Max 做不到,因为





    arg max







    \argmax


    argmax
    不可导,为此,需要做进一步的近似
    。首先,留意到在神经网络中,处理离散输入的基本方法是转化为 one hot 形式,包括 Embedding 层的本质也是 one hot 全连接,因此





    arg max







    \argmax


    argmax
    实际上是




    one_hot


    (



    arg max






    )



    \text{one_hot}(\argmax)


    one_hot(argmax)
    ,然后,我们寻求




    one_hot


    (



    arg max






    )



    \text{one_hot}(\argmax)


    one_hot(argmax)
    的光滑近似,它就是




    s


    o


    f


    t


    m


    a


    x



    softmax


    softmax
    . 由此,我们得到 Gumbel Max 的光滑近似版本——Gumbel Softmax
    在这里插入图片描述其中参数




    τ


    >


    0



    τ>0


    τ>0
    称为退火参数,它越小输出结果就越接近 one hot 形式 (但同时梯度消失就越严重)。提示一个小技巧,如果





    p


    i




    p_i


    pi





    s


    o


    f


    t


    m


    a


    x



    softmax


    softmax
    的输出,那么大可不必先算出





    p


    i




    p_i


    pi
    再取对数,直接将




    log






    p


    i




    \log p_i


    logpi
    替换为





    o


    i




    o_i


    oi
    即可:
    在这里插入图片描述
  • 跟连续情形一样,Gumbel Softmax 就是用在需要求





    E



    y






    p


    θ



    (


    y


    )




    [


    f


    (


    y


    )


    ]



    \mathbb{E}_{y\sim p_{\theta}(y)}[f(y)]


    Eypθ(y)[f(y)]
    、且无法直接完成对




    y



    y


    y
    求和的场景,这时候我们算出





    p


    θ



    (


    y


    )



    p_θ(y)


    pθ(y)
    (或者





    o


    i




    o_i


    oi
    ),然后选定一个




    τ


    >


    0



    τ>0


    τ>0
    ,用 Gumbel Softmax 算出一个随机向量来





    y


    ~




    \tilde y


    y~
    ,代入计算得到




    f


    (



    y


    ~



    )



    f(\tilde y)


    f(y~)
    ,它就是





    E



    y






    p


    θ



    (


    y


    )




    [


    f


    (


    y


    )


    ]



    \mathbb{E}_{y\sim p_{\theta}(y)}[f(y)]


    Eypθ(y)[f(y)]
    的一个好的近似,且保留了梯度信息
  • 注意,Gumbel Softmax 不是类别采样的等价形式,Gumbel Max 才是。而 Gumbel Max 可以看成是 Gumbel Softmax 在




    τ





    0



    τ→0


    τ0
    时的极限。当




    τ



    τ


    τ
    比较小时,Gumbel Softmax 采样得到的样本接近 one-hot vector,也就比较接近实际的采样情况,但梯度的方差比较大;当




    τ



    τ


    τ
    比较大时,Gumbel Softmax 采样得到的样本比较平滑 (一个平滑的概率向量,向量的各个分量的值都差不多),但梯度的方差比较小。所以在应用 Gumbel Softmax 时,开始可以选择较大的




    τ



    τ


    τ
    (比如 1),然后慢慢退火到一个接近于 0 的数(比如 0.01),这样才能得到比较好的结果



Gumbel Softmax v.s. Softmax


  • Gumbel Softmax 通过




    τ





    0



    τ→0


    τ0
    的退火来逐渐逼近 one hot,相比直接用原始的 Softmax 进行退火,区别在于原始 Softmax 退火只能得到最大值位置为 1 的 one hot 向量,而 Gumbel Softmax 有概率得到非最大值位置的 one hot 向量,增加了随机性,会使得基于采样的训练更充分一些

Straight-Through Gumbel-Softmax Estimator


  • 由 Gumbel Softmax 得到的采样样本是实际采样样本的一个近似,它甚至都不在离散变量的取值范围之内,即使




    τ



    τ


    τ
    比较小,Gumbel Softmax 采样得到的样本也只是接近 one-hot vector,而非真正离散化的 one-hot vector. 但总存在那么一些场景,我们只想采样离散值而非连续值 (e.g. RL 中从离散的动作空间中采样)
  • 假设 Gumbel Softmax 输出的采样向量为




    y



    y


    y
    ,为了利用 Gumbel Softmax 采样离散值,我们可以在前向传播时使用




    z


    =


    one_hot


    (



    arg max






    y


    )



    z=\text{one_hot}(\argmax y)


    z=one_hot(argmaxy)
    得到离散的采样值,在反向传播时利用








    θ



    z









    θ



    y



    \nabla_\theta z\approx \nabla_\theta y


    θzθy
    ,对








    θ



    y



    \nabla_\theta y


    θy
    进行梯度回传:





    z


    =


    y


    +


    s


    g


    (


    one_hot


    (



    arg max






    y


    )





    y


    )



    z=y+sg(\text{one_hot}(\argmax y)-y)


    z=y+sg(one_hot(argmaxy)y)
    其中,




    s


    g



    sg


    sg
    为 stop gradient 操作

背后的故事: 梯度估计 (gradient estimator)


  • 重参数就这样介绍完了吗?远远没有,重参数的背后,实际上是一个称为 “梯度估计”的 大家族,而重参数只不过是这个大家族中的一员。每年的 ICLR、ICML 等顶会上搜索gradient estimatorREINFORCE 等关键词,可以搜索到不少文章,说明这是个大家还在钻研的课题。要想说清重参数的来龙去脉,也要说些梯度估计的故事

SF 估计 (Score Function Estimator)


  • 前面我们分别讲了连续型和离散型的重参数,都是在 “loss 层面” 讲述的,也就是说都是想办法把 loss 显式地定义好,剩下的交给框架自动求导、自动优化就是了。而事实上,就算不能显式地写出 loss 函数,也不妨碍我们对它求导,自然也不妨碍我们去用梯度下降了。比如 Score Function Estimator
    在这里插入图片描述这是对原来损失函数的最朴素的估计,在强化学习中




    z



    z


    z
    代表着策略,那么上式就是一个最基本的策略梯度,所以有时候也直接称上述估计为叫 REINFORCE。现在我们可以直接从





    p


    θ



    (


    z


    )



    p_θ(z)


    pθ(z)
    中采样若干个点来估算








    L


    θ



    /





    θ



    \partial L_\theta/\partial \theta


    Lθ/θ
    的值
    了,不用担心会不会没梯度
  • 同时注意到,重参数技巧要求




    f



    f


    f
    可导,但是在诸如强化学习的场景下,




    f


    (


    z


    )



    f(z)


    f(z)
    对应着奖励函数,很难做到光滑可导,此时就必须使用 SF 估计

梯度方差


  • SF 估计看上去很美好,得到了一个连续和离散变量都适用的估计式,那为什么还需要重参数呢?主要的原因是:SF 估计的方差太大。SF 估计是函数




    f


    (


    z


    )










    θ




    log






    p


    θ



    (


    z


    )



    f(z) \frac{\partial}{\partial\theta} \log p_{\theta}(z)


    f(z)θlogpθ(z)
    在分布





    p


    θ



    (


    z


    )



    p_θ(z)


    pθ(z)
    下的期望,我们要采样几个点来算 (理想情况下,希望只采样一个点),换句话说,我们想用下面的近似
    在这里插入图片描述于是问题就来了:这样的梯度估计方差很大,这导致了我们用梯度下降优化的时候相当不稳定,非常容易崩

降方差


  • 重参数就是一种降方差技巧,为此,我们写出重参数后的梯度表达式:
    在这里插入图片描述对比 SF 估计,我们可以直观感知为什么上式方差更小了 (只是一般情况下,并不是绝对成立):(1) SF 估计中包含了




    log






    p


    θ



    (


    z


    )



    \log p_θ(z)


    logpθ(z)
    ,我们知道,作为一个合理的概率分布,一般都在无穷远处 (即







    z












    ∥z∥→∞


    z
    )都会有





    p


    θ



    (


    z


    )





    0



    p_θ(z)→0


    pθ(z)0
    ,取了




    log






    \log


    log
    之后反而会趋于负无穷,换句话说,




    log






    p


    θ



    (


    z


    )



    \log p_θ(z)


    logpθ(z)
    这一项实际上放大了无穷远处的波动,从而一定程度上增加了方差;(2) SF 估计中包含的是




    f



    f


    f
    而重参数之后变成了







    f


    /





    g



    ∂f/∂g


    f/g





    f



    f


    f
    一般是神经网络,而通常我们定义的神经网络模型其实都是




    O


    (


    z


    )



    \mathscr O(z)


    O(z)
    级别的模型,从而我们可以预期它的梯度是




    O


    (


    1


    )



    \mathscr O(1)


    O(1)
    级别的(不严格成立,只能说在平均意义下基本成立),所以相对情况下更平稳一些,因此




    f



    f


    f
    的方差也比







    f


    /





    g



    ∂f/∂g


    f/g
    的方差要大

References