[Zero-basics reinforcement learning] Teaching you to train the Gym lunar lander rocket with Q-learning

Time: 2022-1-25

More code:
Gitee home page: https://gitee.com/GZHzzz
Blog home page: https://blog.csdn.net/gzhzzaa

Before we start

  • As a novice, I am writing this reinforcement learning basics column to share my learning journey with you; I hope we can exchange ideas and improve together! On my Gitee I have collected classic papers on reinforcement learning and built typical agent models based on PyTorch; come chat and let's learn from each other! (●'◡'●)

Algorithm flow

  • The goal: get the small rocket to land on the landing pad accurately and gently (i.e., slowly)!
  • The flow: the agent interacts with the environment, stores transitions in a replay buffer (learning starts once it holds 500 of them), and updates its Q-network from minibatches of 128 sampled transitions until the average reward of the last 10 training episodes exceeds 250.

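Before diving into the training script, it helps to peek at the environment itself. A minimal sketch for inspecting LunarLander-v2 (assuming gym < 0.26 with Box2D installed, the same API the code below relies on):

import gym

env = gym.make('LunarLander-v2')
# 8-dimensional observation: x, y, vx, vy, angle, angular velocity, left/right leg contact
print(env.observation_space)
# 4 discrete actions: do nothing, fire left engine, fire main engine, fire right engine
print(env.action_space)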

Show me the code, no chit-chat

import sys
import logging
import itertools
import copy

import numpy as np
np.random.seed(0)
import pandas as pd
import gym
import matplotlib.pyplot as plt
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.optim as optim
import torch.distributions as distributions

logging.basicConfig(level=logging.DEBUG,
        format='%(asctime)s [%(levelname)s] %(message)s',
        stream=sys.stdout, datefmt='%H:%M:%S')
     
# Create the LunarLander environment and fix its random seed (gym < 0.26 API)
env = gym.make('LunarLander-v2')
env.seed(0)

# Log the environment's attributes and spec for reference
for key in vars(env):
    logging.info('%s: %s', key, vars(env)[key])
for key in vars(env.spec):
    logging.info('%s: %s', key, vars(env.spec)[key])

class DQNReplayer:
    """Experience replay buffer backed by a pandas DataFrame."""

    def __init__(self, capacity):
        self.memory = pd.DataFrame(index=range(capacity),
                columns=['state', 'action', 'reward', 'next_state', 'done'])
        self.i = 0          # index of the next slot to overwrite
        self.count = 0      # number of transitions currently stored
        self.capacity = capacity

    def store(self, *args):
        # Store one (state, action, reward, next_state, done) transition, overwriting the oldest
        self.memory.loc[self.i] = args
        self.i = (self.i + 1) % self.capacity
        self.count = min(self.count + 1, self.capacity)

    def sample(self, size):
        # Sample a random minibatch (with replacement) as column-wise arrays
        indices = np.random.choice(self.count, size=size)
        return (np.stack(self.memory.loc[indices, field]) for field in
                self.memory.columns)

class SQLAgent:
    def __init__(self, env):
        self.action_n = env.action_space.n
        self.gamma = 0.99   # discount factor

        self.replayer = DQNReplayer(10000)

        self.alpha = 0.02   # temperature of the soft (maximum-entropy) policy

        self.evaluate_net = self.build_net(
                input_size=env.observation_space.shape[0],
                hidden_sizes=[256, 256], output_size=self.action_n)
        self.optimizer = optim.Adam(self.evaluate_net.parameters(), lr=3e-4)
        self.loss = nn.MSELoss()

    def build_net(self, input_size, hidden_sizes, output_size):
        # Plain MLP: Linear + ReLU per layer, with a linear output layer
        layers = []
        sizes = [input_size] + hidden_sizes + [output_size]
        for in_size, out_size in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(in_size, out_size))
            layers.append(nn.ReLU())
        layers = layers[:-1]  # drop the trailing ReLU after the output layer
        model = nn.Sequential(*layers)
        return model

    def reset(self, mode=None):
        self.mode = mode
        if self.mode == 'train':
            self.trajectory = []
            # Refresh the target network from the evaluation network at the start of every episode
            self.target_net = copy.deepcopy(self.evaluate_net)

    def step(self, observation, reward, done):
        # Sample an action from the soft policy pi(a|s) = softmax(Q(s, a) / alpha)
        state_tensor = torch.as_tensor(observation,
                dtype=torch.float).squeeze(0)
        q_div_alpha_tensor = self.evaluate_net(state_tensor) / self.alpha
        v_div_alpha_tensor = torch.logsumexp(q_div_alpha_tensor, dim=-1,
                keepdim=True)
        prob_tensor = (q_div_alpha_tensor - v_div_alpha_tensor).exp()
        action_tensor = distributions.Categorical(prob_tensor).sample()
        action = action_tensor.item()
        if self.mode == 'train':
            self.trajectory += [observation, reward, done, action]
            if len(self.trajectory) >= 8:
                # The last 8 entries span two consecutive steps, i.e. one full transition
                state, _, _, act, next_state, reward, done, _ = \
                        self.trajectory[-8:]
                self.replayer.store(state, act, reward, next_state, done)
            if self.replayer.count >= 500:
                self.learn()
        return action

    def close(self):
        pass

    def learn(self):
        # replay
        states, actions, rewards, next_states, dones = \
                self.replayer.sample(128) # replay transitions
        state_tensor = torch.as_tensor(states, dtype=torch.float)
        action_tensor = torch.as_tensor(actions, dtype=torch.long)
        reward_tensor = torch.as_tensor(rewards, dtype=torch.float)
        next_state_tensor = torch.as_tensor(next_states, dtype=torch.float)
        done_tensor = torch.as_tensor(dones, dtype=torch.float)

        # train: bootstrap with the soft state value v(s') = alpha * logsumexp(q(s', .) / alpha)
        next_q_tensor = self.target_net(next_state_tensor)
        next_v_tensor = self.alpha * torch.logsumexp(next_q_tensor / self.alpha, dim=-1)
        target_tensor = reward_tensor + self.gamma * (1. - done_tensor) * next_v_tensor
        pred_tensor = self.evaluate_net(state_tensor)
        q_tensor = pred_tensor.gather(1, action_tensor.unsqueeze(1)).squeeze(1)
        loss_tensor = self.loss(q_tensor, target_tensor.detach())
        self.optimizer.zero_grad()
        loss_tensor.backward()
        self.optimizer.step()


agent = SQLAgent(env)
# Run one episode; in 'train' mode the agent stores transitions and learns online
def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):
    observation, reward, done = env.reset(), 0., False
    agent.reset(mode=mode)
    episode_reward, elapsed_steps = 0., 0
    while True:
        action = agent.step(observation, reward, done)
        if render:
            env.render()
        if done:
            break
        observation, reward, done, _ = env.step(action)
        episode_reward += reward
        elapsed_steps += 1
        if max_episode_steps and elapsed_steps >= max_episode_steps:
            break
    agent.close()
    return episode_reward, elapsed_steps


logging.info('==== train ====')
episode_rewards = []
for episode in itertools.count():
    episode_reward, elapsed_steps = play_episode(env.unwrapped, agent,
            max_episode_steps=env._max_episode_steps, render=1, mode='train')
    episode_rewards.append(episode_reward)
    logging.debug('train episode %d: reward = %.2f, steps = %d',
            episode, episode_reward, elapsed_steps)
    if np.mean(episode_rewards[-10:]) > 250:
        break
plt.plot(episode_rewards)
plt.show()  # display the learning curve (blocks until the window is closed when run as a script)


logging.info('==== test ====')
episode_rewards = []
for episode in range(100):
    episode_reward, elapsed_steps = play_episode(env, agent, render=1)
    episode_rewards.append(episode_reward)
    logging.debug('test episode %d: reward = %.2f, steps = %d',
            episode, episode_reward, elapsed_steps)
logging.info('average episode reward = %.2f ± %.2f',
        np.mean(episode_rewards), np.std(episode_rewards))
  • All of the code above has been run by me personally, you know! ╰(°▽°)╯
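
A detail worth flagging: although the title says Q-learning, the SQLAgent above is soft Q-learning. Actions are sampled from softmax(Q(s, ·) / alpha), and the bootstrap target in learn() replaces max Q with the soft state value v(s) = alpha * logsumexp(Q(s, ·) / alpha). A tiny standalone sketch of that value computation (the Q-values below are made-up numbers, purely for illustration):

import torch

alpha = 0.02                              # same temperature as SQLAgent.alpha
q = torch.tensor([1.0, 2.0, 0.5, 1.5])    # hypothetical Q-values for the 4 actions
soft_v = alpha * torch.logsumexp(q / alpha, dim=-1)
hard_v = q.max()
# soft_v is slightly larger than max(q) and approaches it as alpha -> 0
print(soft_v.item(), hard_v.item())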

Results


Final words

Ten years to sharpen one sword; let's keep encouraging each other!
More code:
Gitee home page: https://gitee.com/GZHzzz
Blog home page: https://blog.csdn.net/gzhzzaa

  • Fighting!

Typical agent models based on PyTorch
Classic papers on reinforcement learning

while True:
	Go life


Thanks for sharing! (❁´◡`❁)
