
TRPO (Trust Region Policy Optimization)

- on-policy
- works with either discrete or continuous action spaces

Principle

TRPO stands for Trust Region Policy Optimization. It was introduced to address a weakness of the vanilla policy gradient (VPG): VPG uses a fixed update step size, so a single update can easily "improve" from a mediocre policy to an even worse one.

This is reminiscent of step-size selection rules in numerical optimization, such as the Armijo-Goldstein and Wolfe-Powell conditions, although they are not directly related to TRPO.

TRPO starts from a bold requirement: every policy update should leave the expected return monotonically non-decreasing. A natural way to enforce this is to express the return of the new policy as the return of the old policy plus a correction term. The identity below is TRPO's starting point:

$$\eta(\pi) = \mathbb{E}_{\tau \sim \pi}\left[\sum_{t=0}^{\infty} \gamma^{t} r(s_t)\right], \qquad A_{\pi}(s, a) = Q_{\pi}(s, a) - V_{\pi}(s)$$

$$\eta(\hat{\pi}) = \eta(\pi) + \mathbb{E}_{\tau \sim \hat{\pi}}\left[\sum_{t=0}^{\infty} \gamma^{t} A_{\pi}(s_t, a_t)\right]$$

Proof (one can also work backwards by construction): since $A_{\pi}(s_t, a_t) = \mathbb{E}_{s_{t+1}}\left[r(s_t) + \gamma V_{\pi}(s_{t+1}) - V_{\pi}(s_t)\right]$,

$$\begin{aligned}
\mathbb{E}_{\tau \sim \hat{\pi}}\left[\sum_{t=0}^{\infty} \gamma^{t} A_{\pi}(s_t, a_t)\right]
&= \mathbb{E}_{\tau \sim \hat{\pi}}\left[\sum_{t=0}^{\infty} \gamma^{t} \big(r(s_t) + \gamma V_{\pi}(s_{t+1}) - V_{\pi}(s_t)\big)\right] \\
&= \mathbb{E}_{\tau \sim \hat{\pi}}\left[\sum_{t=0}^{\infty} \gamma^{t} r(s_t)\right] - \mathbb{E}_{s_0}\left[V_{\pi}(s_0)\right] \\
&= \eta(\hat{\pi}) - \eta(\pi)
\end{aligned}$$

(the value terms telescope, and $\mathbb{E}_{s_0}[V_{\pi}(s_0)] = \eta(\pi)$).

This achieves our goal of expressing the new policy's return in terms of the old policy's return.

With this identity in hand, how do we actually compute it in practice? In particular, how do we handle the expectation taken outside the advantage function?

Decompose the expectation into sums over states and actions:

$$\eta(\hat{\pi}) = \eta(\pi) + \sum_{t=0}^{\infty} \sum_{s} P(s_t = s \mid \hat{\pi}) \sum_{a} \hat{\pi}(a \mid s)\, \gamma^{t} A_{\pi}(s, a)$$

  • $\hat{\pi}(a \mid s)$ is the marginal distribution of action $a$ in state $s$
  • $P(s_t = s \mid \hat{\pi})$ is the marginal distribution of state $s$ at time $t$
  • $\sum_{t=0}^{\infty}$ sums over the whole time sequence.

Define the discounted visitation frequencies

$$\rho_{\pi}(s) = P(s_0 = s) + \gamma P(s_1 = s) + \gamma^2 P(s_2 = s) + \cdots$$

so that

$$\eta(\hat{\pi}) = \eta(\pi) + \sum_{s} \rho_{\hat{\pi}}(s) \sum_{a} \hat{\pi}(a \mid s) A_{\pi}(s, a)$$

This formula still cannot be used in practice, because the states $s$ are drawn from the new policy's distribution $\rho_{\hat{\pi}}$, and the new policy is exactly what we are trying to find; every term involving $\hat{\pi}$ is therefore unknown.

Tricks

  1. A simple idea: replace the new policy's state distribution $\rho_{\hat{\pi}}$ in the formula above with the old policy's $\rho_{\pi}$.
  2. Use importance sampling to handle the action distribution; this is the key step in TRPO (see the sketch below): 
    $$\sum_{a} \hat{\pi}(a \mid s) A_{\pi}(s, a) = \mathbb{E}_{a \sim \pi}\left[\frac{\hat{\pi}(a \mid s)}{\pi(a \mid s)} A_{\pi}(s, a)\right]$$
    This gives the surrogate objective $L_{\pi}(\hat{\pi})$, a first-order approximation in $\hat{\pi}$ that replaces the return function $\eta(\hat{\pi})$:
    $$L_{\pi}(\hat{\pi}) = \eta(\pi) + \sum_{s} \rho_{\pi}(s) \sum_{a} \hat{\pi}(a \mid s) A_{\pi}(s, a) = \eta(\pi) + \mathbb{E}_{s \sim \rho_{\pi},\, a \sim \pi}\left[\frac{\hat{\pi}(a \mid s)}{\pi(a \mid s)} A_{\pi}(s, a)\right]$$
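To make the importance-sampling estimator concrete, here is a minimal PyTorch sketch (not the implementation further below; the tensors are made-up sample values, and only the ratio-weighted part of $L_{\pi}(\hat{\pi})$ is shown):

import torch

# Monte Carlo estimate of the surrogate from sampled (s, a, A) triples:
# L ≈ mean_i [ pi_new(a_i|s_i) / pi_old(a_i|s_i) * A_i ]
log_prob_old = torch.tensor([-1.2, -0.7, -2.3])                      # log pi_old(a|s) for 3 samples
log_prob_new = torch.tensor([-1.0, -0.8, -2.0], requires_grad=True)  # log pi_new(a|s)
advantages   = torch.tensor([ 0.5, -0.2,  1.0])                      # advantage estimates A(s, a)

ratio = torch.exp(log_prob_new - log_prob_old)   # pi_new / pi_old
surrogate = (ratio * advantages).mean()          # estimate of L (up to the constant eta(pi))
(-surrogate).backward()                          # maximizing L == minimizing -L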


Now that the surrogate objective is in place, how is the step size actually chosen? TRPO uses the following lower bound:

$$\eta(\hat{\pi}) \ge L_{\pi}(\hat{\pi}) - C \cdot D_{KL}^{\max}(\pi, \hat{\pi})$$

    • penalty coefficient $C = \dfrac{4 \epsilon \gamma}{(1-\gamma)^2}$
    • $\epsilon = \max_{s,a} \lvert A_{\pi}(s, a) \rvert$, the largest absolute advantage over states and actions

This yields a monotonically improving sequence of policies. Define

$$M_i(\pi) = L_{\pi_i}(\pi) - C \cdot D_{KL}^{\max}(\pi_i, \pi)$$

Then

$$\eta(\pi_{i+1}) \ge M_i(\pi_{i+1}), \qquad \eta(\pi_i) = M_i(\pi_i)$$

$$\Rightarrow \quad \eta(\pi_{i+1}) - \eta(\pi_i) \ge M_i(\pi_{i+1}) - M_i(\pi_i)$$

So among candidate new policies we only need to find the one that maximizes $M_i$; the search for a policy becomes the optimization problem

$$\max_{\pi}\; \left[ L_{\pi_{old}}(\pi) - C \cdot D_{KL}^{\max}(\pi_{old}, \pi) \right]$$
In practice, however, the penalty coefficient C makes the step size far too small. The TRPO paper therefore recasts this as a constrained optimization problem:

$$\max_{\pi}\; L_{\pi_{old}}(\pi) \quad \text{s.t.} \quad D_{KL}^{\max}(\pi_{old}, \pi) \le \delta$$

3. Use the average KL divergence in place of the maximum KL divergence; the max-KL constraint is inconvenient for numerical optimization.

4. Approximate the constraint to second order and the objective to first order, a common transformation in convex optimization; TRPO then carries out the final optimization with the conjugate gradient method (see the reduction sketched below).
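For reference, the standard reduction used in the TRPO paper (here $g$ is the gradient of the surrogate objective and $H$ the Hessian of the average KL divergence, both evaluated at the old parameters $\theta_{\text{old}}$):

$$\begin{aligned}
&\max_{\theta}\; g^{\top}(\theta - \theta_{\text{old}}) \quad \text{s.t.} \quad \tfrac{1}{2}(\theta - \theta_{\text{old}})^{\top} H\, (\theta - \theta_{\text{old}}) \le \delta \\
&\Longrightarrow\quad \theta = \theta_{\text{old}} + \sqrt{\frac{2\delta}{g^{\top} H^{-1} g}}\; H^{-1} g
\end{aligned}$$

$H^{-1}g$ is computed with conjugate gradient (only Hessian-vector products are needed, never $H$ itself), and the resulting step is checked with a backtracking line search; this is what `conjugate_gradients`, `trpo_step`, and `linesearch` do in the implementation below.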

Q: Why does TRPO's exposition feel backwards to me? In my opinion it would be more natural to start from a constraint on the divergence between the new and old policies and then look for the $\hat{\pi}$ that maximizes the surrogate return $L_{\pi_{old}}(\hat{\pi})$, i.e. go straight to the constrained optimization problem. The detour through the penalty coefficient is confusing; readers who can shed light on this are welcome to explain in the comments.

Pseudocode

Implementation

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from collections import namedtuple
from scipy import optimize

# Helper functions such as set_flat_params_to, get_flat_params_from,
# get_flat_grad_from, normal_log_density and preprocess are assumed to be
# defined elsewhere (see the linked repo).

class Policy_Network(nn.Module):
    def __init__(self, obs_space, act_space):
        super(Policy_Network, self).__init__()
        self.affine1 = nn.Linear(obs_space, 64)
        self.affine2 = nn.Linear(64, 64)

        self.action_mean = nn.Linear(64, act_space)
        self.action_mean.weight.data.mul_(0.1)
        self.action_mean.bias.data.mul_(0.0)

        self.action_log_std = nn.Parameter(torch.zeros(1, act_space))

        self.saved_actions = []
        self.rewards = []
        self.final_value = 0

    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))

        action_mean = self.action_mean(x)
        action_log_std = self.action_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)

        return action_mean, action_log_std, action_std
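As a quick illustration (shapes only; the random observation and `act_space=2` here are arbitrary), the policy head returns the mean, log-std, and std of a diagonal Gaussian over actions:

policy = Policy_Network(obs_space=80 * 80, act_space=2)
mean, log_std, std = policy(torch.randn(1, 80 * 80))
print(mean.shape, log_std.shape, std.shape)   # each is torch.Size([1, 2])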

class Value_Network(nn.Module):
    def __init__(self, obs_space):
        super(Value_Network, self).__init__()
        self.affine1 = nn.Linear(obs_space, 64)
        self.affine2 = nn.Linear(64, 64)
        self.value_head = nn.Linear(64, 1)
        self.value_head.weight.data.mul_(0.1)
        self.value_head.bias.data.mul_(0.0)

    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))

        state_values = self.value_head(x)
        return state_values

Transition = namedtuple('Transition', ('state', 'action', 'mask',
                                       'reward', 'next_state'))
class Memory(object):
    def __init__(self):
        self.memory = []

    def push(self, *args):
        """Saves a transition."""
        self.memory.append(Transition(*args))

    def sample(self):
        return Transition(*zip(*self.memory))

    def __len__(self):
        return len(self.memory)
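A small usage sketch of this buffer (the values are placeholders; in practice `train` below pushes real transitions):

buf = Memory()
# push(state, action, mask, reward, next_state); mask = 0 marks the end of an episode
buf.push(np.zeros(80 * 80), np.array([[0.0, 0.1]]), 1, 1.0, np.zeros(80 * 80))
batch = buf.sample()            # a single Transition whose fields are tuples over all pushed steps
print(len(buf), batch.reward)   # 1 (1.0,)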

class Skylark_TRPO():
    def __init__(self, env, alpha = 0.1, gamma = 0.6, 
                    tau = 0.97, max_kl = 1e-2, l2reg = 1e-3, damping = 1e-1):
        self.obs_space = 80*80  # flattened 80x80 preprocessed frame (see preprocess in train)
        self.act_space = env.action_space.n
        self.policy = Policy_Network(self.obs_space, self.act_space)
        self.value = Value_Network(self.obs_space)
        self.env = env
        self.alpha = alpha      # learning rate
        self.gamma = gamma      # discount rate
        self.tau = tau          # GAE parameter (lambda)
        self.max_kl = max_kl
        self.l2reg = l2reg
        self.damping = damping

        self.replay_buffer = Memory()
        self.buffer_size = 1000
        self.total_step = 0


    def choose_action(self, state):
        state = torch.unsqueeze(torch.FloatTensor(state), 0)
        action_mean, _, action_std = self.policy(Variable(state))
        action = torch.normal(action_mean, action_std)   
        return action

    def conjugate_gradients(self, Avp, b, nsteps, residual_tol=1e-10):
        '''
        Solve H x = b with the conjugate gradient method; Avp(v) returns the
        matrix-vector product H v (here the Fisher / KL-Hessian vector product).
        '''
        x = torch.zeros(b.size())
        r = b.clone()
        p = b.clone()
        rdotr = torch.dot(r, r)
        for i in range(nsteps):
            _Avp = Avp(p)
            alpha = rdotr / torch.dot(p, _Avp)
            x += alpha * p
            r -= alpha * _Avp
            new_rdotr = torch.dot(r, r)
            betta = new_rdotr / rdotr
            p = r + betta * p
            rdotr = new_rdotr
            if rdotr < residual_tol:
                break
        return x


    def linesearch(self, model,
                f,
                x,
                fullstep,
                expected_improve_rate,
                max_backtracks=10,
                accept_ratio=.1):
        '''
        Backtracking line search: shrink the step until the surrogate loss improves
        by at least accept_ratio of the expected (linear) improvement.
        '''
        fval = f(True).data
        print("fval before", fval.item())
        for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
            xnew = x + stepfrac * fullstep
            set_flat_params_to(model, xnew)
            newfval = f(True).data
            actual_improve = fval - newfval
            expected_improve = expected_improve_rate * stepfrac
            ratio = actual_improve / expected_improve
            print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())

            if ratio.item() > accept_ratio and actual_improve.item() > 0:
                print("fval after", newfval.item())
                return True, xnew
        return False, x

    def trpo_step(self, model, get_loss, get_kl, max_kl, damping):
        '''
        One TRPO policy update: natural-gradient direction via conjugate gradient,
        scaled to the KL trust region, then accepted or rejected by a line search.
        '''
        loss = get_loss()
        grads = torch.autograd.grad(loss, model.parameters())
        loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

        def Fvp(v):
            kl = get_kl()
            kl = kl.mean() # average KL divergence (Trick 3)

            grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
            flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

            kl_v = (flat_grad_kl * Variable(v)).sum()
            grads = torch.autograd.grad(kl_v, model.parameters())
            flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data

            return flat_grad_grad_kl + v * damping

        stepdir = self.conjugate_gradients(Fvp, -loss_grad, 10)

        shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)

        lm = torch.sqrt(shs / max_kl)
        fullstep = stepdir / lm[0]

        neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
        print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))

        prev_params = get_flat_params_from(model)
        success, new_params = self.linesearch(model, get_loss, prev_params, fullstep,
                                        neggdotstepdir / lm[0])
        set_flat_params_to(model, new_params)
        return loss

    def learn(self, batch_size=128):
        batch = self.replay_buffer.sample()
        rewards = torch.Tensor(batch.reward)
        masks = torch.Tensor(batch.mask)
        actions = torch.Tensor(np.concatenate(batch.action, 0))
        states = torch.Tensor(batch.state)
        values = self.value(Variable(states))

        returns = torch.Tensor(actions.size(0),1)
        deltas = torch.Tensor(actions.size(0),1)
        advantages = torch.Tensor(actions.size(0),1)

        prev_return = 0
        prev_value = 0
        prev_advantage = 0

        for i in reversed(range(rewards.size(0))):
            returns[i] = rewards[i] + self.gamma * prev_return * masks[i] # discounted cumulative return
            deltas[i] = rewards[i] + self.gamma * prev_value * masks[i] - values.data[i] # TD residual: r + gamma*V(s') - V(s)
            advantages[i] = deltas[i] + self.gamma * self.tau * prev_advantage * masks[i] # GAE(lambda) advantage estimate

            prev_return = returns[i, 0]
            prev_value = values.data[i, 0]
            prev_advantage = advantages[i, 0]

        targets = Variable(returns)

        # Original code uses the same LBFGS to optimize the value loss
        def get_value_loss(flat_params):
            '''
            Value-network loss: MSE between predicted values and empirical returns,
            plus an L2 weight penalty; minimized with L-BFGS below.
            '''
            set_flat_params_to(self.value, torch.Tensor(flat_params))
            for param in self.value.parameters():
                if param.grad is not None:
                    param.grad.data.fill_(0)

            values_ = self.value(Variable(states))

            value_loss = (values_ - targets).pow(2).mean() # (f(s)-r)^2

            # weight decay
            for param in self.value.parameters():
                value_loss += param.pow(2).sum() * self.l2reg # L2 weight decay
            value_loss.backward()
            return (value_loss.data.double().numpy(), get_flat_grad_from(self.value).data.double().numpy())

        # Optimize this unconstrained problem with scipy's L-BFGS-B
        flat_params, _, opt_info = optimize.fmin_l_bfgs_b(func=get_value_loss, x0=get_flat_params_from(self.value).double().numpy(), maxiter=25)
        set_flat_params_to(self.value, torch.Tensor(flat_params))

        # Normalize the advantages
        advantages = (advantages - advantages.mean()) / advantages.std()

        action_means, action_log_stds, action_stds =  self.policy(Variable(states))
        fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

        def get_loss(volatile=False):
            '''
            Surrogate policy loss: the importance-weighted advantage,
            negated because the optimizer minimizes.
            '''
            if volatile:
                with torch.no_grad():
                    action_means, action_log_stds, action_stds = self.policy(Variable(states))
            else:
                action_means, action_log_stds, action_stds = self.policy(Variable(states))

            log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
            # -A * pi_new(a|s) / pi_old(a|s), computed as exp(log pi_new - log pi_old)
            action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
            return action_loss.mean()


        def get_kl():
            mean1, log_std1, std1 = self.policy(Variable(states))

            mean0 = Variable(mean1.data)
            log_std0 = Variable(log_std1.data)
            std0 = Variable(std1.data)
            # KL(old || new) between diagonal Gaussians; it is zero at the current
            # parameters, but its curvature (the Fisher matrix) is what Fvp needs
            kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
            return kl.sum(1, keepdim=True)

        self.trpo_step(self.policy, get_loss, get_kl, self.max_kl, self.damping)


    def train(self, num_episodes, batch_size = 128, num_steps = 100):
        for i in range(num_episodes):
            state = self.env.reset()

            steps, reward, sum_rew = 0, 0, 0
            done = False
            while not done and steps < num_steps:
                state = preprocess(state)
                action = self.choose_action(state)
                action = action.data[0].numpy()
                action_ = np.argmax(action)
                # Interaction with Env
                next_state, reward, done, info = self.env.step(action_) 
                next_state_ = preprocess(next_state)
                mask = 0 if done else 1
                self.replay_buffer.push(state, np.array([action]), mask, reward, next_state_)
                if len(self.replay_buffer) > self.buffer_size:
                    self.learn(batch_size)
                    # TRPO is on-policy: discard the collected rollouts after each update
                    self.replay_buffer = Memory()

                sum_rew += reward
                state = next_state
                steps += 1
                self.total_step += 1
            print('Episode: {} | Avg_reward: {} | Length: {}'.format(i, sum_rew/steps, steps))
        print("Training finished.")

More implementation variants can be found in the GitHub repo linked to this column.
