The Double DQN Algorithm
The Problem
Standard DQN computes its learning target with a max over the target network's Q-values. Because the same max operator both selects and evaluates the next action, any upward estimation noise is propagated into the target, so the learned Q-values are systematically overestimated.
How Double DQN Works
Double DQN decouples action selection from action evaluation: the online network picks the best next action, and the delayed (target) network scores that action. Everything else stays the same as in DQN, so the change amounts to a few lines in the target computation.
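Written out, the two targets differ only in which network chooses the next action $a'$ (here $Q_\theta$ is the online network, $Q_{\theta^-}$ the delayed copy, and $\gamma$ the discount factor, 0.99 in the code below):

$$y_{\text{DQN}} = r + \gamma \max_{a'} Q_{\theta^-}(s', a')$$

$$y_{\text{Double DQN}} = r + \gamma\, Q_{\theta^-}\!\left(s', \arg\max_{a'} Q_{\theta}(s', a')\right)$$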
代码
1. The game environment
import gym

# Define the environment
class MyWrapper(gym.Wrapper):

    def __init__(self):
        env = gym.make('CartPole-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        over = terminated or truncated

        # Cap the episode length
        self.step_n += 1
        if self.step_n >= 200:
            over = True

        # Penalize episodes that end before the step cap
        if over and self.step_n < 200:
            reward = -1000

        return state, reward, over

    # Render the game as an image
    def show(self):
        from matplotlib import pyplot as plt
        plt.figure(figsize=(3, 3))
        plt.imshow(self.env.render())
        plt.show()

env = MyWrapper()

env.reset()

env.show()
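As a quick sanity check (this snippet is my addition, not part of the original post), a random policy should rarely balance the pole for the full 200 steps, so the final reward will usually include the -1000 penalty:

# Roll out a random policy to verify the reward shaping.
state = env.reset()
over = False
steps = 0
while not over:
    state, reward, over = env.step(env.action_space.sample())
    steps += 1
print(steps, reward)  # reward is -1000 unless the pole survived all 200 steps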
2. The Q-value function
import torch

# Model that estimates the value of each action in a given state
model = torch.nn.Sequential(
    torch.nn.Linear(4, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)

# Delayed copy of the model, used to compute the target
model_delay = torch.nn.Sequential(
    torch.nn.Linear(4, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)

# Copy the parameters
model_delay.load_state_dict(model.state_dict())

model, model_delay
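The hard copy above swaps in the online weights all at once every few epochs. A common alternative, not used in this post, is a Polyak (soft) update that blends the two sets of weights a little on every step. A minimal sketch, with tau as an assumed smoothing coefficient:

# Hypothetical soft update: model_delay slowly tracks model.
def soft_update(tau=0.005):
    with torch.no_grad():
        for p, p_delay in zip(model.parameters(), model_delay.parameters()):
            p_delay.mul_(1 - tau).add_(tau * p)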
3. A single trajectory
from IPython import display
import random

# Play one episode and record the transitions
def play(show=False):
    data = []
    reward_sum = 0

    state = env.reset()
    over = False
    while not over:
        # Greedy action from the model; no gradients are needed during rollout
        with torch.no_grad():
            action = model(torch.FloatTensor(state).reshape(1, 4)).argmax().item()
        # Epsilon-greedy exploration with a fixed rate of 0.1
        if random.random() < 0.1:
            action = env.action_space.sample()

        next_state, reward, over = env.step(action)

        data.append((state, action, reward, next_state, over))
        reward_sum += reward

        state = next_state

        if show:
            display.clear_output(wait=True)
            env.show()

    return data, reward_sum

play()[-1]
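Exploration here is epsilon-greedy with a fixed rate of 0.1. A common refinement (my addition, not in the original) is to decay epsilon over training, exploring heavily early on and exploiting later:

# Hypothetical linear epsilon schedule from 1.0 down to 0.05;
# its value would replace the hard-coded 0.1 inside play().
def epsilon_by_epoch(epoch, start=1.0, end=0.05, decay_epochs=500):
    return max(end, start - (start - end) * epoch / decay_epochs)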
4. The replay buffer
# Replay buffer
class Pool:

    def __init__(self):
        self.pool = []

    def __len__(self):
        return len(self.pool)

    def __getitem__(self, i):
        return self.pool[i]

    # Refill the buffer with fresh transitions
    def update(self):
        # Add at least 200 new transitions per update
        old_len = len(self.pool)
        while len(self.pool) - old_len < 200:
            self.pool.extend(play()[0])

        # Keep only the newest 20,000 transitions
        self.pool = self.pool[-20_000:]

    # Sample a batch of transitions
    def sample(self):
        data = random.sample(self.pool, 64)
        state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 4)
        action = torch.LongTensor([i[1] for i in data]).reshape(-1, 1)
        reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)
        next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, 4)
        over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1)
        return state, action, reward, next_state, over

pool = Pool()
pool.update()
pool.sample()

len(pool), pool[0]
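To confirm the batch layout the training loop expects (this check is my addition), sample once and print the shapes; every tensor should have 64 rows:

state, action, reward, next_state, over = pool.sample()
print(state.shape, action.shape, reward.shape, next_state.shape, over.shape)
# torch.Size([64, 4]) torch.Size([64, 1]) torch.Size([64, 1]) torch.Size([64, 4]) torch.Size([64, 1])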
5. Training
# Training
def train():
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    loss_fn = torch.nn.MSELoss()

    # Collect and train for N rounds
    for epoch in range(1000):
        pool.update()

        # After each data refresh, run N gradient steps
        for i in range(200):
            # Sample a batch of transitions
            state, action, reward, next_state, over = pool.sample()

            # Q-value of the action that was actually taken
            value = model(state).gather(dim=1, index=action)

            # Double DQN target: the online model selects the next action,
            # the delayed model evaluates it
            with torch.no_grad():
                next_action = model(next_state).argmax(dim=1, keepdim=True)
                target = model_delay(next_state).gather(dim=1, index=next_action)
                target = target * 0.99 * (1 - over) + reward

            loss = loss_fn(value, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Sync the delayed model every few epochs
        if (epoch + 1) % 5 == 0:
            model_delay.load_state_dict(model.state_dict())

        if epoch % 100 == 0:
            test_result = sum([play()[-1] for _ in range(20)]) / 20
            print(epoch, len(pool), test_result)

train()
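After training, the agent can be watched for one episode using the show() helper defined earlier (this assumes a notebook, since the matplotlib figure is redrawn each step); a well-trained agent should return a reward sum near 200:

# Render one greedy episode and report its total reward.
play(show=True)[-1]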