PPO-KL散度近端策略优化玩cartpole游戏

 

其实KL散度在这个游戏里的作用不大：游戏的action比较简单（只有两个离散动作），不像LM里的action是一个很大的向量，所以可以直接用surr1，即只最大化surr1，实验测试也确实如此。另外，KL项的系数不能给太大，否则惩罚力度过大；action model和ref model产生的action分布其实差距并不大。

 

"""PPO with a KL penalty (or optional clipping) on CartPole, rendered with pygame."""
import sys
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class PolicyNetwork(nn.Module):
    """Policy network: maps a state to a probability distribution over actions."""

    def __init__(self, state_dim=4, hidden_dim=2, action_dim=2):
        """Build a one-hidden-layer tanh MLP ending in a softmax.

        Args:
            state_dim: size of the observation vector (4 for CartPole).
            hidden_dim: width of the hidden layer.
            action_dim: number of discrete actions (2 for CartPole).
        """
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return action probabilities; the last dimension sums to 1."""
        return self.fc(x)


class ValueNetwork(nn.Module):
    """Value network: maps a state to a scalar state-value estimate V(s)."""

    def __init__(self, state_dim=4, hidden_dim=2):
        """Build a one-hidden-layer tanh MLP with a single linear output."""
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Return V(s) with shape (..., 1)."""
        return self.fc(x)


class RolloutBuffer:
    """On-policy storage for one batch of transitions collected by the current policy."""

    def __init__(self):
        self.clear()

    def __len__(self):
        return len(self.rewards)

    def store(self, state, action, reward, done, log_prob):
        """Append one transition.

        ``log_prob`` may be a float or a one-element tensor; it is converted
        to a plain float in :meth:`get_batch`.
        """
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.log_probs.append(log_prob)

    def clear(self):
        """Drop all stored transitions."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

    def get_batch(self):
        """Return (states, actions, rewards, dones, log_probs) as tensors."""
        return (
            # np.asarray first: torch.tensor on a list of ndarrays is very slow.
            torch.tensor(np.asarray(self.states), dtype=torch.float),
            torch.tensor(self.actions, dtype=torch.long),
            torch.tensor(self.rewards, dtype=torch.float),
            torch.tensor(self.dones, dtype=torch.bool),
            # float() detaches any stored autograd graph from the old log-probs.
            torch.tensor([float(lp) for lp in self.log_probs], dtype=torch.float),
        )


def compute_returns_and_advantages(rewards, dones, values, gamma=0.99, lam=0.95):
    """Compute discounted returns and GAE advantages for a rollout.

    Args:
        rewards: per-step rewards in time order.
        dones: per-step terminal flags; ``dones[t]`` true means the episode
            ended after step ``t``, so nothing is bootstrapped past it.
        values: value estimates V(s_t) for the same steps.
        gamma: discount factor.
        lam: GAE lambda (1.0 -> discounted return minus baseline,
            0.0 -> one-step TD error).

    Returns:
        ``(returns, advantages)`` as 1-D float tensors, same length as ``rewards``.
    """
    rewards = torch.as_tensor(rewards, dtype=torch.float).tolist()
    not_done = (1.0 - torch.as_tensor(dones, dtype=torch.float)).tolist()
    values = torch.as_tensor(values, dtype=torch.float).tolist()

    n = len(rewards)
    returns = [0.0] * n
    advantages = [0.0] * n
    ret = 0.0
    adv = 0.0
    next_value = 0.0  # V(s_{t+1}); zero beyond the end of the buffer.
    for t in reversed(range(n)):
        mask = not_done[t]
        ret = rewards[t] + gamma * ret * mask  # Monte-Carlo return
        # TD error must bootstrap from the NEXT state's value, not V(s_t).
        delta = rewards[t] + gamma * next_value * mask - values[t]
        adv = delta + gamma * lam * adv * mask
        returns[t] = ret
        advantages[t] = adv
        next_value = values[t]
    return (
        torch.tensor(returns, dtype=torch.float),
        torch.tensor(advantages, dtype=torch.float),
    )


def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer,
               epochs=100, gamma=0.99, clip_param=0.2, *,
               lam=0.95, kl_coef=1.0, use_clip=False):
    """Run ``epochs`` PPO optimisation steps on the data in ``buffer``.

    By default the actor maximises ``E[ratio * A] - kl_coef * KL(pi_old || pi)``
    (the PPO KL-penalty objective). With ``use_clip=True`` it uses the clipped
    surrogate ``E[min(ratio * A, clip(ratio) * A)]`` instead. The critic
    regresses V(s) onto the Monte-Carlo returns. The buffer is not cleared.

    ``pi_old`` is the policy as it is when this function is entered, which
    assumes the buffer was filled by that same policy (true for the loop in
    :func:`main`).
    """
    if len(buffer) == 0:
        return
    states, actions, rewards, dones, old_log_probs = buffer.get_batch()

    with torch.no_grad():
        values = value_net(states).squeeze(-1)
        old_dist = torch.distributions.Categorical(probs=policy_net(states))
    returns, advantages = compute_returns_and_advantages(
        rewards, dones, values, gamma, lam)
    if advantages.numel() > 1:  # std() of a single element is NaN
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    for _ in range(epochs):
        dist = torch.distributions.Categorical(probs=policy_net(states))
        new_log_probs = dist.log_prob(actions)
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        if use_clip:
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
        else:
            # Exact KL between the full old and new action distributions.
            kl = torch.distributions.kl_divergence(old_dist, dist).mean()
            actor_loss = -surr1.mean() + kl_coef * kl

        optimizer_policy.zero_grad()
        actor_loss.backward()
        optimizer_policy.step()

        # squeeze(-1): (T,) - (T, 1) would broadcast to a (T, T) loss matrix.
        value_loss = (returns - value_net(states).squeeze(-1)).pow(2).mean()
        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()


def main(num_episodes=10000, render_after_steps=2000):
    """Train on CartPole-v1, one episode per PPO update, rendering once good.

    Rendering is switched on after the first episode lasting more than
    ``render_after_steps`` steps.
    """
    # Imported here so the networks and update code can be imported without
    # gym or a display being available.
    import gym
    import pygame

    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    value_net = ValueNetwork()
    optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
    optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
    buffer = RolloutBuffer()

    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()
    draw_on = False

    try:
        for episode in range(num_episodes):
            state, _ = env.reset()
            done = False
            step = 0
            while not done:
                step += 1
                state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
                with torch.no_grad():  # rollout under the current ("old") policy
                    action_probs = policy_net(state_tensor)
                dist = torch.distributions.Categorical(action_probs)
                action = dist.sample()
                log_prob = dist.log_prob(action)

                # NOTE: `truncated` is deliberately ignored, so episodes can run
                # past the 500-step time limit; the render trigger below relies on it.
                next_state, reward, done, _, _ = env.step(action.item())
                buffer.store(state, action.item(), reward, done, log_prob.item())
                state = next_state

                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        sys.exit()  # cleanup happens in the finally block

                if draw_on:
                    screen.fill((0, 0, 0))
                    cart_x = int(state[0] * 100 + 300)  # cart position -> screen x
                    pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))
                    pygame.draw.line(
                        screen, (255, 0, 0), (cart_x + 25, 300),
                        (cart_x + 25 - int(50 * np.sin(state[2])),
                         300 - int(50 * np.cos(state[2]))),
                        5,
                    )
                    pygame.display.flip()
                    clock.tick(60)

            if step > render_after_steps:
                draw_on = True
            ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)
            buffer.clear()
            print(f'Episode {episode} completed , reward:  {step}.')
    finally:
        env.close()
        pygame.quit()


if __name__ == "__main__":
    main()

 

效果:

PPO-KL散度近端策略优化玩cartpole游戏

 

发表评论

评论已关闭。

相关文章