在机器学习中,经常需要对为随机优化计算loss function的梯度,有时这些loss function会写成期望的形式。比如在变分推断中,需要计算ELBO loss(包含期望的项)的导数(derivative)。另外就是强化学习的Policy Gradient算法中的目标函数(也就是loss function)就是算期望的reward。
但是一般是不能直接计算期望的梯度的,而REINFORCE和Reparameterization就是两种常见的trick,用来计算这些期望函数的梯度,注意这两个trick基于假设以及用处是不同的。

问题设定
给定随机变量x 和参数分布p θ  ,即x ∼ p θ ( x ) 

 我们需要求出函数f ff的期望,即:

这里函数f ff通常就是目标函数或者损失函数(loss function),比较棘手的问题就是前面讲了我们很难直接算出这个期望然后求梯度,所以就可以用REINFORCE或者Reparameterization这两种trick来将问题转换求解。

REINFORCE
REINFORCE有时会也叫score function estimator或者likelihood ratio estimator,它主要用了一个对数微分的trick,如下:

注意这是一个无偏估计,也就是方差会很大(可以搜索trade-off bias and variance),也有一些方法可以减少方差比如 Importance Sampling或者Rao-Blackwellization,具体参考这本书的8,9,10章。

Torch中的REINFORCE以强化学习的Policy Gradient算法算法为例,这里的函数f 对应着reward function,参数分布
p θ 就对应着策略π θ 

 (输出为action),代码如下:

# from torch.distributions import Bernoulli
from torch.distributions import Categorical
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

其中policy_network是一个神经网络,输入为当前的state,输出一个当前动作的概率,动作action服从Categorical分布,当然也可以是Bernoulli分布或者正态分布,看实际需求。reward函数可以通过环境的env.step给出,然后根据公式(4)就可以算出loss fucntion,最后通过loss.backward()进行反向传播了。

Reparameterization trick
回顾公式(1):

其中ε ∼ N ( 0 , 1 ) \varepsilon \sim \mathcal{N}(0,1)ε∼N(0,1),如下图可以表示Reparameterization:

在这里插入图片描述

注意图中的z zz就是这里的x xx,通过引入一个stochastic的变量ϵ \epsilonϵ并结合以上公式就可以把期望的梯度写成:

然后就可以方便地计算梯度了,注意Reparameterization比REINFORCE拥有更低的方差(variance)但是只适用于可微分的模型。

Torch中的Reparameterization
同样针对强化学习中的Policy Gradient算法,如下

params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

这里的params指ϵ \epsilonϵ,Normal指上面的分布q,采样出ϵ \epsilonϵ后可以直接通过torch中的rsample()函数(对应g θ 
​)得到相应的action,最后代入到f(对应程序中的env.step)就可以以-reward为loss进行梯度下降了

Refs
torch.distributions
REINFORCE vs Reparameterization Trick