RNN support: DRQN agent + recurrent buffer #258
@@ -0,0 +1,281 @@
import copy
import os

import gym
import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.dqn import DQNAgent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DRQNNetwork
from hive.agents.qnets.utils import (
    InitializationFn,
    calculate_output_dim,
    create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
    LinearSchedule,
    PeriodicSchedule,
    Schedule,
    SwitchSchedule,
)
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder


class DRQNAgent(DQNAgent):
    """An agent implementing the DRQN algorithm. Uses an epsilon greedy
    exploration policy.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Box,
        action_space: gym.spaces.Discrete,
        representation_net: FunctionApproximator,
        id=0,
        optimizer_fn: OptimizerFn = None,
        loss_fn: LossFn = None,
        init_fn: InitializationFn = None,
        replay_buffer: BaseReplayBuffer = None,
        max_seq_len: int = 1,
        discount_rate: float = 0.99,
        n_step: int = 1,
        grad_clip: float = None,
        reward_clip: float = None,
        update_period_schedule: Schedule = None,
        target_net_soft_update: bool = False,
        target_net_update_fraction: float = 0.05,
        target_net_update_schedule: Schedule = None,
        epsilon_schedule: Schedule = None,
        test_epsilon: float = 0.001,
        min_replay_history: int = 5000,
        batch_size: int = 32,
        device="cpu",
        logger: Logger = None,
        log_frequency: int = 100,
        **kwargs,
    ):
        """
        Args:
            observation_space (gym.spaces.Box): Observation space for the agent.
            action_space (gym.spaces.Discrete): Action space for the agent.
            representation_net (FunctionApproximator): A network that outputs the
                representations that will be used to compute Q-values (e.g.
                everything except the final layer of the DRQN), as well as the
                hidden states of the recurrent component.

[Review thread]
- I am assuming there are restrictions on the representation_net? E.g., it needs to be one of your recurrent ones? Please mention this in the documentation.
- done
- Can you more explicitly mention the restrictions on representation_net? For example, which methods it should have, or that it should follow the structure of ConvRNNNetwork or something?
- done

            id: Agent identifier.
            optimizer_fn (OptimizerFn): A function that takes in a list of parameters
                to optimize and returns the optimizer. If None, defaults to
                :py:class:`~torch.optim.Adam`.
            loss_fn (LossFn): Loss function used by the agent. If None, defaults to
                :py:class:`~torch.nn.SmoothL1Loss`.
            init_fn (InitializationFn): Initializes the weights of qnet using
                create_init_weights_fn.
            replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
                push observations to and sample from during learning. If None,
                defaults to
                :py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.

[Review thread]
- This is incorrect. Please go through all the documentation and make sure it is correct.
- done

            max_seq_len (int): The number of consecutive transitions in a sequence.
            discount_rate (float): A number between 0 and 1 specifying how much
                future rewards are discounted by the agent.
            n_step (int): The horizon used in n-step returns to compute TD(n) targets.
            grad_clip (float): Gradients will be clipped to between
                [-grad_clip, grad_clip].
            reward_clip (float): Rewards will be clipped to between
                [-reward_clip, reward_clip].
            update_period_schedule (Schedule): Schedule determining how frequently
                the agent's Q-network is updated.
            target_net_soft_update (bool): Whether the target net parameters are
                replaced by the qnet parameters completely or using a weighted
                average of the target net parameters and the qnet parameters.
            target_net_update_fraction (float): The weight given to the target
                net parameters in a soft update.
            target_net_update_schedule (Schedule): Schedule determining how frequently
                the target net is updated.
            epsilon_schedule (Schedule): Schedule determining the value of epsilon
                through the course of training.
            test_epsilon (float): epsilon (probability of choosing a random action)
                to be used during the testing phase.
            min_replay_history (int): How many observations to fill the replay buffer
                with before starting to learn.
            batch_size (int): The size of the batch sampled from the replay buffer
                during learning.
            device: Device on which all computations should be run.
            logger (ScheduledLogger): Logger used to log agent's metrics.
            log_frequency (int): How often to log the agent's metrics.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            representation_net=representation_net,
            id=id,
            optimizer_fn=optimizer_fn,
            loss_fn=loss_fn,
            init_fn=init_fn,
            replay_buffer=replay_buffer,
            discount_rate=discount_rate,
            n_step=n_step,
            grad_clip=grad_clip,
            reward_clip=reward_clip,
            update_period_schedule=update_period_schedule,
            target_net_soft_update=target_net_soft_update,
            target_net_update_fraction=target_net_update_fraction,
            target_net_update_schedule=target_net_update_schedule,
            epsilon_schedule=epsilon_schedule,
            test_epsilon=test_epsilon,
            min_replay_history=min_replay_history,
            batch_size=batch_size,
            device=device,
            logger=logger,
            log_frequency=log_frequency,
        )
        if replay_buffer is None:
            replay_buffer = RecurrentReplayBuffer
        self._replay_buffer = replay_buffer(
            max_seq_len=max_seq_len,
            observation_shape=self._observation_space.shape,
            observation_dtype=self._observation_space.dtype,
            action_shape=self._action_space.shape,
            action_dtype=self._action_space.dtype,
        )
        self._max_seq_len = max_seq_len
[Review thread]
- This can be moved above the super constructor call:
- done

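The truncated suggestion above appears to be about hoisting the sequence-length and buffer defaults above the call to the parent constructor. A minimal sketch of one possible reading follows; the functools.partial binding and the assumption that DQNAgent's constructor instantiates the buffer from whatever class or factory it is given are mine, not something shown in this diff.

from functools import partial

from hive.agents.dqn import DQNAgent
from hive.replays.recurrent_replay import RecurrentReplayBuffer


class DRQNAgent(DQNAgent):
    def __init__(self, observation_space, action_space, representation_net,
                 replay_buffer=None, max_seq_len=1, **kwargs):
        # Record the sequence length and bind the recurrent-buffer default before
        # delegating, so the parent constructor can build the buffer once instead
        # of the child rebuilding it afterwards.
        self._max_seq_len = max_seq_len
        if replay_buffer is None:
            replay_buffer = partial(RecurrentReplayBuffer, max_seq_len=max_seq_len)
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            representation_net=representation_net,
            replay_buffer=replay_buffer,
            **kwargs,
        )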
    def create_q_networks(self, representation_net):
        """Creates the Q-network and target Q-network.

        Args:
            representation_net: A network that outputs the representations that will
                be used to compute Q-values (e.g. everything except the final layer
                of the DRQN).
        """
[Review thread on lines +152 to +155]
- You should add a comment about the expected output of this network, i.e., that it outputs an output and a hidden state.
- done

        # The representation network is expected to return both the representation
        # used to compute Q-values and the recurrent hidden state, and to provide
        # an init_hidden method that creates the initial hidden state.
        network = representation_net(self._state_size)
        network_output_dim = np.prod(calculate_output_dim(network, self._state_size)[0])
        self._qnet = DRQNNetwork(network, network_output_dim, self._action_space.n).to(
            self._device
        )

        self._qnet.apply(self._init_fn)
        self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)
        self._hidden_state = network.init_hidden(batch_size=1, device=self._device)

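Pulling together the review threads above about representation_net: the agent appears to assume a recurrent backbone that is constructed from the state size, exposes init_hidden(batch_size, device), and whose forward takes (x, hidden_state) and returns an (output, hidden_state) pair. Below is a minimal sketch of such a module. SimpleRNNRepresentation is hypothetical and only illustrates the interface (RLHive's ConvRNNNetwork mentioned in the thread is the actual reference structure); the flat-observation handling and the merged batch/time output shape are assumptions.

import numpy as np
import torch
import torch.nn as nn


class SimpleRNNRepresentation(nn.Module):
    """Hypothetical recurrent backbone illustrating the expected interface."""

    def __init__(self, in_dim, hidden_size=128):
        super().__init__()
        self._hidden_size = hidden_size
        self._encoder = nn.Linear(int(np.prod(in_dim)), hidden_size)
        self._lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)

    def init_hidden(self, batch_size, device="cpu"):
        # LSTM state is an (h, c) pair of shape (num_layers, batch, hidden).
        shape = (1, batch_size, self._hidden_size)
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
        )

    def forward(self, x, hidden_state=None):
        # Accepts (batch, time, *obs_shape); adds a time axis if it is missing,
        # as in act(), where a single flat observation is passed.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        x = torch.relu(self._encoder(x.flatten(start_dim=2)))
        x, hidden_state = self._lstm(x, hidden_state)
        # Merge batch and time so the DRQN head sees (batch * time, hidden);
        # the agent later reshapes Q-values back to (batch, seq_len, actions).
        return x.reshape(-1, self._hidden_size), hidden_state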
    @torch.no_grad()
    def act(self, observation):
        """Returns the action for the agent. If in training mode, follows an epsilon
        greedy policy. Otherwise, returns the action with the highest Q-value.

        Args:
            observation: The current observation.
        """

        # Reset hidden state if it is episode beginning.
        if self._state["episode_start"]:
            self._hidden_state = self._qnet.base_network.init_hidden(
                batch_size=1, device=self._device
            )

        # Determine and log the value of epsilon
        if self._training:
            if not self._learn_schedule.get_value():
                epsilon = 1.0
            else:
                epsilon = self._epsilon_schedule.update()
            if self._logger.update_step(self._timescale):
                self._logger.log_scalar("epsilon", epsilon, self._timescale)
        else:
            epsilon = self._test_epsilon

        # Sample action. With epsilon probability choose random action,
        # otherwise select the action with the highest q-value.
        observation = torch.tensor(
            np.expand_dims(observation, axis=0), device=self._device
        ).float()
        qvals, self._hidden_state = self._qnet(observation, self._hidden_state)
        if self._rng.random() < epsilon:
            action = self._rng.integers(self._action_space.n)
        else:
            # Note: not explicitly handling the ties
            action = torch.argmax(qvals).item()

        if (
            self._training
            and self._logger.should_log(self._timescale)
            and self._state["episode_start"]
        ):
            self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
        self._state["episode_start"] = False
        return action

    def update(self, update_info):
        """
        Updates the DRQN agent.

        Args:
            update_info: dictionary containing all the necessary information to
                update the agent. Should contain a full transition, with keys for
                "observation", "action", "reward", and "done".
        """
        if update_info["done"]:
            self._state["episode_start"] = True
[Review thread]
- Is episode_start used anywhere except the buffer? Because the buffer takes care of it and it is redundant.
- This is there for stuff in act(). There's probably a better way to do it. It doesn't really make sense for the agent to do it. It might make sense to add it as part of the observation, but need to think about this a bit.
- We probably want to fix it in a separate PR as this is what DQN does too.
- We created an issue for fixing it in both DQN and DRQN.
- I think this part is fine. I don't think it needs fixing.

        if not self._training:
            return

        # Add the most recent transition to the replay buffer.
        self._replay_buffer.add(**self.preprocess_update_info(update_info))

        # Update the q network based on a sample batch from the replay buffer.
        # If the replay buffer doesn't have enough samples, catch the exception
        # and move on.
        if (
            self._learn_schedule.update()
            and self._replay_buffer.size() > 0
            and self._update_period_schedule.update()
        ):
            batch = self._replay_buffer.sample(batch_size=self._batch_size)
            (
                current_state_inputs,
                next_state_inputs,
                batch,
            ) = self.preprocess_update_batch(batch)

            hidden_state = self._qnet.base_network.init_hidden(
                batch_size=self._batch_size, device=self._device
            )
            target_hidden_state = self._target_qnet.base_network.init_hidden(
                batch_size=self._batch_size, device=self._device
            )
[Review thread]
- It feels weird that you are accessing an internal module of the qnet. I think instead of
- done

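One way to act on the comment above, sketched here rather than taken from the final commits: expose init_hidden on DRQNNetwork itself (defined in the second file of this diff) so the agent never reaches into self._qnet.base_network directly.

import torch.nn as nn


class DRQNNetwork(nn.Module):
    # ... __init__ and forward as in the qnet_heads diff below ...

    def init_hidden(self, batch_size, device="cpu"):
        # Delegate to the recurrent backbone, which owns the hidden-state shape.
        return self.base_network.init_hidden(batch_size=batch_size, device=device)


# The agent-side calls in update() would then read:
#     hidden_state = self._qnet.init_hidden(
#         batch_size=self._batch_size, device=self._device
#     )
#     target_hidden_state = self._target_qnet.init_hidden(
#         batch_size=self._batch_size, device=self._device
#     )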
            # Compute predicted Q values
            self._optimizer.zero_grad()
            pred_qvals, _ = self._qnet(*current_state_inputs, hidden_state)
[Review thread]
- Just to make sure: the qnet takes in a window of past observations, and you take as output the last hidden state and pass it through an MLP to get the Q-values? So it involves some redundant computation when calculating Q-values for s_t and s_t+1.
- We have another PR #270 that handles hidden state saving & burn-in frames. In this PR the hidden states are initialized from 0's. Could you also provide some reference if you have seen more efficient ways of reusing the hidden state and computing Q?
- We can have a look at the cleanrl/sb3 implementations of recurrent networks. Their implementation with jax might have some principles or tricks which we can use for our code base?
- I think the current implementation is good enough and is working. We can create new PRs to improve its efficiency maybe.

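For reference, the stored-hidden-state alternative discussed in this thread (and said to be handled in PR #270) typically looks like an R2D2-style burn-in: replay the first few steps of the sampled sequence only to warm up the recurrent state, then compute the loss on the remainder. A rough sketch with hypothetical names (qnet, burn_in_inputs, train_inputs, stored_hidden); this is not what the current PR does, which zero-initializes instead.

import torch


def qvals_with_burn_in(qnet, burn_in_inputs, train_inputs, stored_hidden):
    """Refresh a stored hidden state over a burn-in prefix (no gradients),
    then compute Q-values over the training portion of the sequence."""
    hidden_state = stored_hidden
    with torch.no_grad():
        # Burn-in steps only update the recurrent state; they add no loss terms.
        _, hidden_state = qnet(*burn_in_inputs, hidden_state)
    q_values, _ = qnet(*train_inputs, hidden_state)
    return q_values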
            pred_qvals = pred_qvals.view(self._batch_size, self._max_seq_len, -1)
            actions = batch["action"].long()
            pred_qvals = torch.gather(pred_qvals, -1, actions.unsqueeze(-1)).squeeze(-1)

            # Compute 1-step Q targets
            next_qvals, _ = self._target_qnet(*next_state_inputs, target_hidden_state)
            next_qvals = next_qvals.view(self._batch_size, self._max_seq_len, -1)
            next_qvals, _ = torch.max(next_qvals, dim=-1)

            q_targets = batch["reward"] + self._discount_rate * next_qvals * (
                1 - batch["done"]
            )

            loss = self._loss_fn(pred_qvals, q_targets).mean()

            if self._logger.should_log(self._timescale):
                self._logger.log_scalar("train_loss", loss, self._timescale)

            loss.backward()
            if self._grad_clip is not None:
                torch.nn.utils.clip_grad_value_(
                    self._qnet.parameters(), self._grad_clip
                )
            self._optimizer.step()

            # Update target network
            if self._target_net_update_schedule.update():
                self._update_target()
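To show how the pieces above fit together, here is a rough usage sketch of the agent's act()/update() cycle. The import path is hypothetical (the diff does not show file names), the step signature is the pre-0.26 gym API matching the gym import in this file, agent.train() assumes the base Agent's usual training/eval toggle for self._training, and SimpleRNNRepresentation refers to the interface sketch earlier in this review.

import gym

from hive.agents.drqn import DRQNAgent  # hypothetical module path

env = gym.make("CartPole-v1")
agent = DRQNAgent(
    observation_space=env.observation_space,
    action_space=env.action_space,
    # Any recurrent representation net with init_hidden and an
    # (output, hidden_state) forward should work here; SimpleRNNRepresentation
    # is the hypothetical module sketched earlier.
    representation_net=SimpleRNNRepresentation,
    max_seq_len=8,
)
agent.train()

observation = env.reset()
done = False
while not done:
    action = agent.act(observation)
    next_observation, reward, done, _ = env.step(action)
    agent.update(
        {
            "observation": observation,
            "action": action,
            "reward": reward,
            "done": done,
        }
    )
    observation = next_observation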
@@ -39,6 +39,48 @@ def forward(self, x):
        return self.output_layer(x)

class DRQNNetwork(nn.Module):
    """Implements the standard DRQN value computation. This module returns two outputs,
    which correspond to the two outputs from :obj:`base_network`. In particular, it
    transforms the first output from :obj:`base_network` with output dimension
    :obj:`hidden_dim` to dimension :obj:`out_dim`, which should be equal to the
    number of actions. The second output of this module is the second output from
    :obj:`base_network`, which is the hidden state that will be used as the initial
    hidden state when computing the next action in the trajectory.
    """

    def __init__(
        self,
        base_network: nn.Module,
        hidden_dim: int,
        out_dim: int,
        linear_fn: nn.Module = None,
    ):
        """
        Args:
            base_network (torch.nn.Module): Backbone network that returns two outputs,
                one is the representation used to compute action values, and the
                other one is the hidden state used as the input hidden state later.
            hidden_dim (int): Dimension of the output of the :obj:`network`.
            out_dim (int): Output dimension of the DRQN. Should be equal to the
                number of actions that you are computing values for.
            linear_fn (torch.nn.Module): Function that will create the
                :py:class:`torch.nn.Module` that will take the output of
                :obj:`network` and produce the final action values. If
                :obj:`None`, a :py:class:`torch.nn.Linear` layer will be used.
        """
        super().__init__()
        self.base_network = base_network
        self._linear_fn = linear_fn if linear_fn is not None else nn.Linear
        self.output_layer = self._linear_fn(hidden_dim, out_dim)
[Review thread on lines +73 to +75]
- should probably make all of these internal variables
- maybe in a separate PR? All network modules defined here have

    def forward(self, x, hidden_state=None):
        # The base network returns both the representation and the updated
        # recurrent hidden state.
        x, hidden_state = self.base_network(x, hidden_state)

        x = x.flatten(start_dim=1)
        return self.output_layer(x), hidden_state
[Review thread]
- Is the act() function of the agent the only place where this returned hidden_state is used?
- Right.
- Maybe it can have an internal function to call during act and update separately?

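A sketch of the split floated in the thread above; the method names are hypothetical, not part of this PR. The idea is a step-wise path for act(), which needs the carried hidden state, and a sequence path for update(), which initializes the hidden state internally and only returns Q-values.

import torch.nn as nn


class DRQNNetwork(nn.Module):
    # ... __init__ and forward as above ...

    def forward_step(self, x, hidden_state):
        # Acting path: keep and return the updated hidden state.
        return self.forward(x, hidden_state)

    def forward_sequence(self, x, batch_size, device="cpu"):
        # Update path: start from a fresh hidden state and discard it,
        # since only the Q-values are needed for the loss.
        hidden_state = self.base_network.init_hidden(
            batch_size=batch_size, device=device
        )
        q_values, _ = self.forward(x, hidden_state)
        return q_values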
class DuelingNetwork(nn.Module):
    """Computes action values using Dueling Networks (https://arxiv.org/abs/1511.06581).
    In dueling, we have two heads---one for estimating advantage function and one for
[Review thread]
- Please go through the arguments and documentation and make sure the ones you are exposing are all actually being used by your agent.
- done