Building a Lightweight Reinforcement Learning DQN with PyTorch Lightning

Time: 2020-12-06

This article explores how PyTorch Lightning can be applied to the exciting field of reinforcement learning (RL). Here, we will build a standard Deep Q Network (DQN) for the classic CartPole Gym environment to illustrate how to get started with Lightning when building RL models.

In this article, we will cover:

- What Lightning is and why it should be applied to RL
- An introduction to the standard DQN model
- The steps to build a DQN with Lightning
- Results and conclusions

## What is Lightning?

Lightning is a recently released Python library that cleanly abstracts and automates all of the day-to-day boilerplate code that comes with ML models, allowing you to focus on the actual ML parts, which are often the most interesting ones.

In addition to automating boilerplate code, Lightning can also serve as a style guide for building clean and reproducible ML systems.
This is very attractive for the following reasons:

1. By abstracting away the boilerplate engineering code, it is easier to identify and understand the ML code itself.
2. Lightning's unified structure makes it very easy to build on top of and understand existing projects.
3. The code Lightning automates for you is high-quality code that is fully tested, regularly maintained and follows ML best practices.

## DQN

Before we get into the code, let's take a quick look at how a DQN works. A DQN learns the optimal policy for a given environment by learning the value of each action in a given state. These values are called Q-values.
Initially, the agent has a very poor understanding of its environment because it has little experience, so its Q-values will be very inaccurate. Over time, however, as the agent explores its environment, it learns increasingly accurate Q-values and can then make good decisions. This allows it to improve further until it eventually (ideally) converges to an optimal policy.
Most of the environments we are interested in, such as modern video games and simulation environments, are far too complex and large to store the value of every state/action pair. That is why we use a deep neural network to approximate these values.
The general life cycle of an agent is as follows:

1. The agent takes the current state of the environment and passes it through the network. The network then outputs the Q-value of each action for that state.
2. Next, we decide whether to take the action the network considers best or a random action for further exploration.
3. The action is passed to the environment, which returns feedback telling the agent what the next state is, the reward for taking the previous action in the previous state, and whether the episode finished on that step.
4. We take the experience gained in that step as a tuple (state, action, reward, next state, done) and store it in the agent's memory.
5. Finally, we sample a mini-batch of past experiences from the agent's memory and use these to calculate the agent's loss.

This is a high-level overview of how a DQN works; the short sketch below shows the same loop in code.
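
To make the loop concrete, here is a minimal, self-contained sketch of that life cycle using stand-in objects (a plain nn.Sequential network and a deque as memory); the dedicated DQN, ReplayBuffer and Agent classes built in the following sections replace these stand-ins.

```python
import collections
import random

import gym
import torch
from torch import nn

Experience = collections.namedtuple('Experience', ['state', 'action', 'reward', 'done', 'new_state'])

env = gym.make('CartPole-v0')
net = nn.Sequential(nn.Linear(4, 128), nn.ReLU(), nn.Linear(128, 2))  # stand-in Q-network
buffer = collections.deque(maxlen=1000)                               # stand-in replay memory
epsilon = 0.1

state = env.reset()
for step in range(100):
    # 1. the network maps the current state to one Q-value per action
    q_values = net(torch.tensor([state], dtype=torch.float32))
    # 2. epsilon-greedy: explore with probability epsilon, otherwise exploit
    if random.random() < epsilon:
        action = env.action_space.sample()
    else:
        action = int(torch.argmax(q_values, dim=1).item())
    # 3. act in the environment and observe the feedback
    new_state, reward, done, _ = env.step(action)
    # 4. store the experience tuple in the agent's memory
    buffer.append(Experience(state, action, reward, done, new_state))
    # 5. a mini-batch would now be sampled from the memory to compute the loss
    state = env.reset() if done else new_state
```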

## Lightning DQN

Let's look at the components of our DQN:

**Model**: the neural network used to approximate the Q-values

**Replay buffer**: the agent's memory, used to store previous experiences

**Agent**: the agent itself, which interacts with the environment and the replay buffer

**Lightning module**: handles all of the training of the agent

## Model

For this example we can use a very simple MLP, meaning we don't use anything fancy like convolutions or recurrence, just plain linear layers. The reason is that, given the simplicity of the CartPole environment, anything more complex than this would be overkill.

```python
import torch
from torch import nn


class DQN(nn.Module):
    """
    Simple MLP network

    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

    def forward(self, x):
        return self.net(x.float())
```

## Replay buffer

Building the replay buffer is quite straightforward: we just need some kind of data structure to store tuples, with the ability to sample those tuples and to append new ones. The buffer in this example is based on Lapan's replay buffer, as it is the simplest and fastest implementation I have found so far. The code is as follows:

```python
import collections
from typing import Tuple

import numpy as np

# Named tuple for storing experience steps gathered in training
Experience = collections.namedtuple(
    'Experience', field_names=['state', 'action', 'reward', 'done', 'new_state'])


class ReplayBuffer:
    """
    Replay Buffer for storing past experiences, allowing the agent to learn from them

    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """
        Add experience to the buffer

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])

        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=bool), np.array(next_states))
```

But we are not done yet. If you have used Lightning before, you know it is structured around the idea of creating a data loader and using it to pass mini-batches to each training step. That works for most ML systems (such as supervised models), but how does it work when we generate our own data as we go?
We need to create our own iterable dataset that samples previous experiences from the constantly updated replay buffer. We then have a mini-batch of experiences that is passed to the training step to calculate our loss, just like any other model. Instead of containing inputs and labels, our mini-batch contains (state, action, reward, done, next state).

```python
from torch.utils.data.dataset import IterableDataset


class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer
    which will be updated with new experiences during training

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Tuple:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]
```

As you can see, when we create the dataset we pass in the replay buffer, which we sample from so that the DataLoader can pass batches to the Lightning module.

## Agent

The Agent class handles the interaction with the environment. It has three main methods:

get_action: using the given epsilon value, the agent decides whether to take a random action or the action with the highest Q-value from the network output.

play_step: here the agent carries out a step in the environment using the action selected in get_action. After getting the feedback from the environment, the experience is stored in the replay buffer. If the environment finished the episode on this step, the environment is reset. Finally, the current reward and the done flag are returned.

reset: resets the environment and updates the current state stored in the agent.

```python
from typing import Tuple

import gym
import numpy as np
import torch
from torch import nn


class Agent:
    """
    Base Agent class handling the interaction with the environment

    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        self.reset()
        self.state = self.env.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state"""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])

            if device not in ['cpu']:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the environment

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done
```

## Lightning module

Now that we have built the core classes of our DQN, we can start thinking about training the DQN agent. This is where Lightning comes in. We will lay out all of our training logic in a clean and structured way by building a Lightning Module.

Lightning provides many interfaces and overridable functions for maximum flexibility, but there are four key methods we have to implement for the project to run:

1. forward()
2. configure_optimizers()
3. train_dataloader()
4. training_step()

With these four methods filled in, we can train almost any ML model we come across. Anything that needs to go beyond them fits in well with the rest of Lightning's interfaces and callbacks. For a complete list of the available interfaces, check the Lightning documentation. Now, let's look at our Lightning model.

### Initialization

First, we need to initialize our environment, networks, agent and replay buffer. We also call a populate function, which fills the replay buffer with random experiences (the populate function is shown in the full code example below).
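
As a rough sketch of what this constructor can look like, relying on the DQN, ReplayBuffer and Agent classes defined above (the class name DQNLightning, the hyperparameter names and their default values are illustrative, not the article's exact code):

```python
import gym
import torch
import pytorch_lightning as pl


class DQNLightning(pl.LightningModule):
    """Illustrative Lightning module tying the pieces together."""

    def __init__(self, env: str = 'CartPole-v0', replay_size: int = 1000,
                 warm_start_steps: int = 1000, gamma: float = 0.99,
                 sync_rate: int = 10, eps_start: float = 1.0,
                 eps_end: float = 0.01, eps_last_frame: int = 1000) -> None:
        super().__init__()
        self.gamma = gamma
        self.sync_rate = sync_rate
        self.eps_start, self.eps_end, self.eps_last_frame = eps_start, eps_end, eps_last_frame

        # environment, networks, replay buffer and agent
        self.env = gym.make(env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n
        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)
        self.buffer = ReplayBuffer(replay_size)
        self.agent = Agent(self.env, self.buffer)

        self.total_reward = 0
        self.episode_reward = 0

        # fill the buffer with an initial set of random experiences
        self.populate(warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """Plays random steps in the environment to fill the replay buffer."""
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)
```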

### Forward pass

All we do here is wrap the forward pass of our DQN network.
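
Continuing the illustrative DQNLightning class from the sketch above, this can be as simple as:

```python
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Passes a state (or batch of states) through the primary network
        and returns the Q-values of each action."""
        return self.net(x)
```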

### Loss function

Before we start training the agent, we need to define our loss function. The loss used here is based on Lapan's implementation.

It is a simple mean squared error (MSE) loss between the current state-action values of our DQN network and the expected state-action values of the next states. In RL we don't have perfect labels to learn from; instead, the agent learns from a target derived from its own estimate of the next state's value.

However, using the same network to predict both the value of the current state and the value of the next state produces an unstable, moving target. To counter this, we use a target network: a copy of the primary network that is synced with it at regular intervals. This provides a temporarily fixed target, allowing the agent to compute a more stable loss.

As you can see below, the state-action values are computed with the primary network, while the next-state values (equivalent to our target/label) are computed with the target network.
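
A sketch of such a loss method on the illustrative DQNLightning class (the method name dqn_mse_loss and the exact tensor handling are assumptions based on the description above):

```python
    def dqn_mse_loss(self, batch) -> torch.Tensor:
        """MSE between Q(s, a) from the primary network and the target
        r + gamma * max_a' Q_target(s', a') from the target network."""
        states, actions, rewards, dones, next_states = batch

        # Q-values of the actions that were actually taken
        state_action_values = self.net(states).gather(
            1, actions.long().unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            # value of the best next action, estimated by the frozen target network
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0  # terminal states have no future value

        expected_state_action_values = next_state_values * self.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)
```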

### Optimizer

This is another simple piece: we just tell Lightning which optimizer to use during the backward pass. We will use the standard Adam optimizer.
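
A minimal sketch of this method (the learning rate of 1e-3 is an assumed value):

```python
    def configure_optimizers(self):
        """Use a plain Adam optimizer on the primary network's parameters."""
        return torch.optim.Adam(self.net.parameters(), lr=1e-3)
```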

### Training data loader

Next, we need to provide Lightning with our training DataLoader. As you might expect, we initialize the iterable dataset we created earlier and pass it to a DataLoader as usual. Lightning handles the provided batches during training, converting them to PyTorch tensors and moving them to the correct device.
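
A sketch of what this might look like on the illustrative class (the sample size and batch size are assumed values):

```python
    def train_dataloader(self) -> 'DataLoader':
        """Wrap the RLDataset, backed by the replay buffer, in a standard DataLoader."""
        from torch.utils.data import DataLoader  # imported here only to keep the sketch self-contained
        dataset = RLDataset(self.buffer, sample_size=200)
        return DataLoader(dataset=dataset, batch_size=16)
```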

### Training step

Finally, we have the training step. Here we put all the logic to be executed on each training iteration.
During each training iteration we want the agent to take a step in the environment by calling the previously defined agent.play_step(), passing in the current device and epsilon value. This returns the reward for that step and whether the episode finished on that step. We add the step reward to the running episode reward in order to track how successful the agent was during the episode.
Next, we calculate our loss using the current mini-batch provided by Lightning.
If we have reached the end of the episode, indicated by the done flag, we update the current total reward with the accumulated episode reward.
At the end of the step, we check whether it is time to sync the primary and target networks. Usually a soft update is used, where only a fraction of the weights is updated, but for this simple example a full update is sufficient.
Finally, we return a dict containing the loss that Lightning will use for backpropagation, a dict containing the values we want to log (note: these values must be tensors), and another dict containing any values we want to display on the progress bar.
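
Putting this together, a sketch of the training step on the illustrative class might look as follows (the epsilon schedule and the dict-based return follow the description above and the older Lightning API of that time; the exact names are assumptions):

```python
    def training_step(self, batch, batch_idx):
        """One iteration: play a step in the environment, compute the DQN loss
        on the sampled batch and periodically sync the target network."""
        # device the mini-batch lives on (GPU index, or 'cpu'), as expected by play_step
        device = batch[0].device.index if batch[0].is_cuda else 'cpu'
        epsilon = max(self.eps_end,
                      self.eps_start - self.global_step / self.eps_last_frame)

        # step through the environment with the agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculate the training loss on the sampled batch
        loss = self.dqn_mse_loss(batch)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # sync the target network with the primary network every sync_rate steps
        if self.global_step % self.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {'total_reward': torch.tensor(self.total_reward, dtype=torch.float32),
               'reward': torch.tensor(reward, dtype=torch.float32)}

        # loss for backprop, values to log, and values for the progress bar
        return {'loss': loss, 'log': log, 'progress_bar': log}
```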

With that, we now have everything we need to run the DQN agent.

### Running the agent

Now all we have to do is initialize and fit our Lightning model. In our main Python file, we set the seed and provide an argument parser containing any necessary hyperparameters we want to pass to the model.

We then initialize our DQN Lightning module with those arguments, followed by setting up the Lightning Trainer.
Here, the Trainer is set up to use a GPU. If you do not have access to a GPU, remove the gpus and distributed_backend arguments from the Trainer. Training this model is very fast, even on CPU, so in order to watch Lightning in action we turn off early stopping.
Finally, because we are using an IterableDataset, we need to specify val_check_interval. Normally this interval is set automatically based on the length of the dataset, but an IterableDataset has no length function, so we have to set this value ourselves, even though we are not running a validation step.

The last step is to call trainer.fit() and watch it train.
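
A sketch of such a main script, using the illustrative DQNLightning class and the older Trainer argument names the text refers to (gpus, distributed_backend, early_stop_callback); newer Lightning versions configure these differently:

```python
import argparse

import numpy as np
import torch
import pytorch_lightning as pl


def main(args) -> None:
    model = DQNLightning()  # pass parsed hyperparameters here if you expose any

    trainer = pl.Trainer(
        gpus=1,                     # remove on CPU-only machines
        distributed_backend='dp',   # remove on CPU-only machines as well
        max_epochs=500,             # assumed value
        early_stop_callback=False,  # turn off early stopping so we can watch it train
        val_check_interval=100,     # required because an IterableDataset has no length
    )
    trainer.fit(model)


if __name__ == '__main__':
    torch.manual_seed(0)
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    # add any hyperparameters you want to expose, e.g. --lr, --sync_rate, ...
    args = parser.parse_args()
    main(args)
```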

## Results

After about 1,200 training iterations, you will see the agent's total reward reach the maximum score of 200. To see the reward metrics being plotted, run:

```
tensorboard --logdir lightning_logs
```

(Figure: TensorBoard plots of the per-step reward, left, and the total reward, right.)
In the plot on the left, you can see the reward for each step. Because of the nature of the environment, this is always 1, since the agent receives a +1 reward for every step the pole stays up (and that is the only reward). In the plot on the right, we can see the total reward per episode. The agent quickly reaches the maximum reward and then fluctuates between good and bad episodes.
## Conclusion
Now you have seen how easy and practical it is to leverage the power of PyTorch Lightning in a reinforcement learning project.
This is a very simple example intended only to illustrate the use of Lightning in RL, so there is plenty of room for improvement. If you want to use this code as a template and try implementing your own agent, here are some things I would try:

1. Reduce the learning rate over time by initializing a learning rate scheduler in the configure_optimizers method.
2. Adjust the sync rate of the target network, or use a soft update instead of a full update.
3. Use a more gradual epsilon decay over more steps.
4. Increase the number of training epochs by setting max_epochs in the Trainer.
5. Track the average total reward as well as the total reward in the TensorBoard logs.
6. Add test and validation steps using Lightning's test/validation hooks.
7. Finally, try some more complex models and environments.
I hope this article helps you get your project started. Happy coding!

Author: Donal Byrne
Deep hub translation group: tensor Zhang

Follow the WeChat official account 'deephub-imba' to get the complete code for this article.