TD3 (Twin Delayed DDPG)

  • off-policy
  • only continuous action spaces
  • Actor-Critic structure Sequential Decision

Principle

尽管DDPG有时可以实现出色的性能,但它在超参数和其他类型的调整方面通常很脆弱。 DDPG的常见问题在于,学习到的Q函数对Q值的过估计。然后导致策略中断,因为它利用了Q函数中的错误。 双延迟DDPG(TD3)是一种通过引入三个关键tricks来解决此问题的算法:

  • Clipped Double-Q Learning:跟Double DQN 解决Q值过估计的做法一样,学习两个Q-functions而不是一个(这也是名字里"twin"的由来),并使用两个Q值中较小的一个来做Bellman损失函数中的target;
  • “Delayed” Policy Updates:TD3的策略以及target networks的更新频率低于Q-functions,由于一次策略更新会改变target,延缓更新有助于缓解DDPG中通常出现的波动性。建议每两个Q-func更新进行一次策略更新
  • Target Policy Smoothing:TD3会给target action增加噪声,从而通过沿动作变化平滑Q来使策略更难利用Q-func的error。

target policy smoothing:

构成Q-learning target的action是基于target policy

 [公式] 的,TD3在action的每个维度上都添加了clipped noise,从而使target action将被限定在有效动作范围内(all valid actions, a , satisfy

 [公式] )。因此,target actions写作:

[公式]

target policy smoothing实质上是算法的正则化器。 它解决了DDPG中可能发生的特定故障:如果Q函数逼近器为某些操作产生了不正确的尖峰,该策略将迅速利用该峰,并出现脆性或错误行为。 可以通过在类似action上使Q函数变得平滑来修正,即target policy smoothing。

clipped double-Q learning: 两个Q函数都使用一个目标,使用两个Q函数中的任何一个计算得出的目标值都较小:

[公式]

然后通过回归此目标来学习两者:

[公式]

使用较小的Q值作为目标值,然后逐步回归该值,有助于避免Q函数的过高估计。

Pseudocode

Implement

class Skylark_TD3():
    def __init__(
            self,
            env,
            gamma=0.99,
            tau=0.005,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2):

        self.env = env
        # Varies by environment
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.max_action = float(self.env.action_space.high[0])

        self.actor = Actor(self.state_dim, 256, self.action_dim, self.max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=3e-4)

        self.critic = Critic(self.state_dim, 256, self.action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=3e-4)

        self.discount = gamma
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.start_timesteps = 1e3  # Time steps initial random policy is used
        self.expl_noise = 0.1        # Std of Gaussian exploration noise

        self.total_iteration = 0

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def learn(self, replay_buffer, batch_size=100):
        self.total_iteration += 1

        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(
            batch_size)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)

            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + \
            F.mse_loss(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_iteration % self.policy_freq == 0:

            # Compute actor losse
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data)

    def train(self, num_episodes, batch_size = 256):
        replay_buffer = ReplayBuffer(self.state_dim, self.action_dim)

        episode_num = 0
        for i in range(1, num_episodes):
            state, done = self.env.reset(), False
            episode_reward = 0
            episode_timesteps = 0

            for t in range(1, 1000):
                episode_timesteps += 1

                # Select action randomly or according to policy
                if i * 1000 < self.start_timesteps:
                    action = self.env.action_space.sample()
                else:
                    action = (
                        self.select_action(np.array(state))
                        + np.random.normal(0, self.max_action * self.expl_noise, size=self.action_dim)
                    ).clip(-self.max_action, self.max_action)

                # Perform action
                next_state, reward, done, _ = self.env.step(action) 
                done_bool = float(done) if episode_timesteps < 1000 else 0

                # Store data in replay buffer
                replay_buffer.add(state, action, next_state, reward, done_bool)

                state = next_state
                episode_reward += reward

                # Train agent after collecting sufficient data
                if i * 1000 >= self.start_timesteps:
                    self.learn(replay_buffer, batch_size)

            print('Episode {} : {}'.format(i, episode_reward))

更多实现方式见本专栏关联Github

Reference

  1. TD3 Spinning Up
  2. TD3 pytorch实现