Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rnn support: DRQN agent + recurrent buffer #258

Merged
merged 51 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
981be53
initial commit DRQN implementation
hnekoeiq Mar 11, 2022
0aabcff
initial commit recurrent buffer implementation
hnekoeiq Mar 11, 2022
9d666a8
Merge pull request #256 from chandar-lab/recurrent_buffer
hnekoeiq Mar 12, 2022
f557787
Merge branch 'rnn_support' into recurrent_dqn
hnekoeiq Mar 12, 2022
38af33c
Merge pull request #257 from chandar-lab/recurrent_dqn
hnekoeiq Mar 12, 2022
41b6041
drqn agent working with recurrent buffer
hnekoeiq Mar 12, 2022
d2f205f
fix device
hnekoeiq Mar 15, 2022
1ae9cf0
nstep update for RNN
karthiks1701 Mar 16, 2022
346a46f
handled input obs dimension depending on whether agent is acting or u…
TongTongX Mar 16, 2022
921d336
nstep update for RNN after format
karthiks1701 Mar 16, 2022
d66e8ad
Merge branch 'rnn_support' of https://github.com/chandar-lab/RLHive i…
karthiks1701 Mar 16, 2022
c0f976f
reformatted code with black
TongTongX Mar 16, 2022
cc28dfa
Merge branch 'rnn_support' of github.com:chandar-lab/RLHive into rnn_…
TongTongX Mar 16, 2022
4874d7a
fix bugs with buffer size and device
Mar 18, 2022
cf6cfe4
update device configuration
hnekoeiq Mar 22, 2022
8e935b9
DRQN agent tested + updated doctrings and some cleanups
hnekoeiq Mar 22, 2022
b3c810c
fixing id issue for MARL
Mar 31, 2022
7264b91
formatted with black
Mar 31, 2022
de35f5c
initial commit DRQN implementation
hnekoeiq Mar 11, 2022
4ce0d65
initial commit recurrent buffer implementation
hnekoeiq Mar 11, 2022
d985d7e
drqn agent working with recurrent buffer
hnekoeiq Mar 12, 2022
48135a7
fix device
hnekoeiq Mar 15, 2022
29d7607
nstep update for RNN
karthiks1701 Mar 16, 2022
9b25182
handled input obs dimension depending on whether agent is acting or u…
TongTongX Mar 16, 2022
28178cb
nstep update for RNN after format
karthiks1701 Mar 16, 2022
abe510b
reformatted code with black
TongTongX Mar 16, 2022
8f58643
fix bugs with buffer size and device
Mar 18, 2022
a682dcc
update device configuration
hnekoeiq Mar 22, 2022
89e0c05
DRQN agent tested + updated doctrings and some cleanups
hnekoeiq Mar 22, 2022
67f7dfa
fixing id issue for MARL
Mar 31, 2022
9315f23
formatted with black
Mar 31, 2022
f1a9dd8
updating to new registeration
hnekoeiq Apr 11, 2022
a4e9989
suppor both lstm and gru
hnekoeiq Apr 11, 2022
f6d2976
docstrings, cleanups and adding some utils functions
hnekoeiq Apr 12, 2022
1eb931b
Merge branch 'rnn_support' of https://github.com/chandar-lab/RLHive i…
Apr 25, 2022
058f979
merged changes related to callable objects
TongTongX Jul 15, 2022
1df5a69
added sequence model class
TongTongX Jul 16, 2022
a9b73f1
removed unused parameter in base sequence module class; docstring min…
TongTongX Jul 18, 2022
9eb4d52
Merge branch 'dev' into rnn_support
hnekoeiq Jul 22, 2022
abaa7a8
docstring and other minor changes
TongTongX Aug 2, 2022
104823e
reverted stack_size for other agents
TongTongX Aug 2, 2022
088549b
sequence registrable not inherit torch nn.Module
TongTongX Aug 9, 2022
6c6a87b
format
TongTongX Aug 9, 2022
f441706
DRQN reset hidden state in act(), set device of sequence model in yml…
TongTongX Aug 30, 2022
fcc8e6e
clean up
hnekoeiq Sep 19, 2022
89376b2
Merge branch 'dev' into rnn_support
hnekoeiq Sep 20, 2022
2ffb18f
minors fixes
TongTongX Oct 3, 2022
64a88ac
alphabetical ordering
TongTongX Oct 3, 2022
eedf125
fixed device mismatch between rnn hidden state and placeholder image
Oct 13, 2022
39082aa
add both batch and sequence dim to observation during acting
Oct 13, 2022
9ddf9d2
Merge branch 'dev' into rnn_support
sriyash421 Oct 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hive/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from hive.agents.agent import Agent
from hive.agents.ddpg import DDPG
from hive.agents.dqn import DQNAgent
from hive.agents.drqn import DRQNAgent
from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent
from hive.agents.rainbow import RainbowDQNAgent
from hive.agents.random import RandomAgent
Expand All @@ -13,6 +14,7 @@
{
"DDPG": DDPG,
"DQNAgent": DQNAgent,
"DRQNAgent": DRQNAgent,
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
Expand Down
281 changes: 281 additions & 0 deletions hive/agents/drqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
import copy
import os

import gym
import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.dqn import DQNAgent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DRQNNetwork
from hive.agents.qnets.utils import (
InitializationFn,
calculate_output_dim,
create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
LinearSchedule,
PeriodicSchedule,
Schedule,
SwitchSchedule,
)
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder


class DRQNAgent(DQNAgent):
"""An agent implementing the DRQN algorithm. Uses an epsilon greedy
exploration policy
"""

def __init__(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please go through the arguments and documentation and make sure the ones you are exposing are all actually being used by your agent.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self,
observation_space: gym.spaces.Box,
action_space: gym.spaces.Discrete,
representation_net: FunctionApproximator,
id=0,
optimizer_fn: OptimizerFn = None,
loss_fn: LossFn = None,
init_fn: InitializationFn = None,
replay_buffer: BaseReplayBuffer = None,
max_seq_len: int = 1,
discount_rate: float = 0.99,
n_step: int = 1,
grad_clip: float = None,
reward_clip: float = None,
update_period_schedule: Schedule = None,
target_net_soft_update: bool = False,
target_net_update_fraction: float = 0.05,
target_net_update_schedule: Schedule = None,
epsilon_schedule: Schedule = None,
test_epsilon: float = 0.001,
min_replay_history: int = 5000,
batch_size: int = 32,
device="cpu",
logger: Logger = None,
log_frequency: int = 100,
**kwargs,
):
"""
Args:
observation_space (gym.spaces.Box): Observation space for the agent.
action_space (gym.spaces.Discrete): Action space for the agent.
representation_net (FunctionApproximator): A network that outputs the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am assuming there are restrictions on the representation_net? eg it needs to be one of your recurrent ones? Please mention this in the documentation.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you more explicitly mention the restrictions on representation_net? For example, which methods it should have or that it should follow the structure of ConvRNNNetwork or something?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

representations that will be used to compute Q-values (e.g.
everything except the final layer of the DRQN), as well as the
hidden states of the recurrent component.
id: Agent identifier.
optimizer_fn (OptimizerFn): A function that takes in a list of parameters
to optimize and returns the optimizer. If None, defaults to
:py:class:`~torch.optim.Adam`.
loss_fn (LossFn): Loss function used by the agent. If None, defaults to
:py:class:`~torch.nn.SmoothL1Loss`.
init_fn (InitializationFn): Initializes the weights of qnet using
create_init_weights_fn.
replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
push observations to and sample from during learning. If None,
defaults to
:py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is incorrect. Please go through all the documentation and make sure it is correct.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

max_seq_len (int): The number of consecutive transitions in a sequence.
discount_rate (float): A number between 0 and 1 specifying how much
future rewards are discounted by the agent.
n_step (int): The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float): Gradients will be clipped to between
[-grad_clip, grad_clip].
reward_clip (float): Rewards will be clipped to between
[-reward_clip, reward_clip].
update_period_schedule (Schedule): Schedule determining how frequently
the agent's Q-network is updated.
target_net_soft_update (bool): Whether the target net parameters are
replaced by the qnet parameters completely or using a weighted
average of the target net parameters and the qnet parameters.
target_net_update_fraction (float): The weight given to the target
net parameters in a soft update.
target_net_update_schedule (Schedule): Schedule determining how frequently
the target net is updated.
epsilon_schedule (Schedule): Schedule determining the value of epsilon
through the course of training.
test_epsilon (float): epsilon (probability of choosing a random action)
to be used during testing phase.
min_replay_history (int): How many observations to fill the replay buffer
with before starting to learn.
batch_size (int): The size of the batch sampled from the replay buffer
during learning.
device: Device on which all computations should be run.
logger (ScheduledLogger): Logger used to log agent's metrics.
log_frequency (int): How often to log the agent's metrics.
"""
super().__init__(
observation_space=observation_space,
action_space=action_space,
representation_net=representation_net,
id=id,
optimizer_fn=optimizer_fn,
loss_fn=loss_fn,
init_fn=init_fn,
replay_buffer=replay_buffer,
discount_rate=discount_rate,
n_step=n_step,
grad_clip=grad_clip,
reward_clip=reward_clip,
update_period_schedule=update_period_schedule,
target_net_soft_update=target_net_soft_update,
target_net_update_fraction=target_net_update_fraction,
target_net_update_schedule=target_net_update_schedule,
epsilon_schedule=epsilon_schedule,
test_epsilon=test_epsilon,
min_replay_history=min_replay_history,
batch_size=batch_size,
device=device,
logger=logger,
log_frequency=log_frequency,
)
if replay_buffer is None:
replay_buffer = RecurrentReplayBuffer
self._replay_buffer = replay_buffer(
max_seq_len=max_seq_len,
observation_shape=self._observation_space.shape,
observation_dtype=self._observation_space.dtype,
action_shape=self._action_space.shape,
action_dtype=self._action_space.dtype,
)
self._max_seq_len = max_seq_len
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be moved above the super constructor call:

        if replay_buffer is None:
            replay_buffer = RecurrentReplayBuffer
        replay_buffer = partial(replay_buffer, max_seq_len=max_seq_len)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def create_q_networks(self, representation_net):
"""Creates the Q-network and target Q-network.

Args:
representation_net: A network that outputs the representations that will
be used to compute Q-values (e.g. everything except the final layer
of the DRQN).
"""
Comment on lines +152 to +155
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should add a comment about the expected output of this network, ie that it outputs a output and hidden state

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

network = representation_net(self._state_size)
network_output_dim = np.prod(calculate_output_dim(network, self._state_size)[0])
self._qnet = DRQNNetwork(network, network_output_dim, self._action_space.n).to(
self._device
)

self._qnet.apply(self._init_fn)
self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)
self._hidden_state = network.init_hidden(batch_size=1, device=self._device)

@torch.no_grad()
def act(self, observation):
"""Returns the action for the agent. If in training mode, follows an epsilon
greedy policy. Otherwise, returns the action with the highest Q-value.

Args:
observation: The current observation.
"""

# Reset hidden state if it is episode beginning.
if self._state["episode_start"]:
self._hidden_state = self._qnet.base_network.init_hidden(
batch_size=1, device=self._device
)

# Determine and log the value of epsilon
if self._training:
if not self._learn_schedule.get_value():
epsilon = 1.0
else:
epsilon = self._epsilon_schedule.update()
if self._logger.update_step(self._timescale):
self._logger.log_scalar("epsilon", epsilon, self._timescale)
else:
epsilon = self._test_epsilon

# Sample action. With epsilon probability choose random action,
# otherwise select the action with the highest q-value.
observation = torch.tensor(
np.expand_dims(observation, axis=0), device=self._device
).float()
qvals, self._hidden_state = self._qnet(observation, self._hidden_state)
if self._rng.random() < epsilon:
action = self._rng.integers(self._action_space.n)
else:
# Note: not explicitly handling the ties
action = torch.argmax(qvals).item()

if (
self._training
and self._logger.should_log(self._timescale)
and self._state["episode_start"]
):
self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
self._state["episode_start"] = False
return action

def update(self, update_info):
"""
Updates the DRQN agent.

Args:
update_info: dictionary containing all the necessary information to
update the agent. Should contain a full transition, with keys for
"observation", "action", "reward", and "done".
"""
if update_info["done"]:
self._state["episode_start"] = True

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is episode start used anywhere except the buffer? because the buffer takes care of it and it is redundant.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is there for stuff in act(). There's probably a better way to do it. It doesn't really make sense for the agent to do it. It might make sense to add it as part of the observation, but need to think about this a bit.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want to fix it in a separate PR as this is what DQN does too.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We created an issue for fixing it in both DQN and DRQN.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this part is fine. I don't think it needs fixing.


if not self._training:
return

# Add the most recent transition to the replay buffer.
self._replay_buffer.add(**self.preprocess_update_info(update_info))

# Update the q network based on a sample batch from the replay buffer.
# If the replay buffer doesn't have enough samples, catch the exception
# and move on.
if (
self._learn_schedule.update()
and self._replay_buffer.size() > 0
and self._update_period_schedule.update()
):
batch = self._replay_buffer.sample(batch_size=self._batch_size)
(
current_state_inputs,
next_state_inputs,
batch,
) = self.preprocess_update_batch(batch)

hidden_state = self._qnet.base_network.init_hidden(
batch_size=self._batch_size, device=self._device
)
target_hidden_state = self._target_qnet.base_network.init_hidden(
batch_size=self._batch_size, device=self._device
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels weird that you are accessing an internal module of the qnet. I think instead of self._qnet.base_network.init_hidden, it should be self._qnet.init_hidden

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# Compute predicted Q values
self._optimizer.zero_grad()
pred_qvals, _ = self._qnet(*current_state_inputs, hidden_state)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure, the qnet takes in a window of past observations and you take as output the last hidden state and pass it through an MLP to get the Q-values? So, it involves some redundant computation when calculating Q-values for s_t and s_t+1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have another PR #270 that handles hidden states saving & burn-in frames. In this PR the hidden states are initialized from 0's.

Could you also provide some reference if you have seen more efficient ways of reusing hidden state and computing Q?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can have a look at the cleanrl/sb3 implementation of recurrent networks. there implementation with jax might have some principles or tricks which we can use for our code base?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the current implementation is good enough and is working. We can create new prs to improve its efficiency maybe.

pred_qvals = pred_qvals.view(self._batch_size, self._max_seq_len, -1)
actions = batch["action"].long()
pred_qvals = torch.gather(pred_qvals, -1, actions.unsqueeze(-1)).squeeze(-1)

# Compute 1-step Q targets
next_qvals, _ = self._target_qnet(*next_state_inputs, target_hidden_state)
next_qvals = next_qvals.view(self._batch_size, self._max_seq_len, -1)
next_qvals, _ = torch.max(next_qvals, dim=-1)

q_targets = batch["reward"] + self._discount_rate * next_qvals * (
1 - batch["done"]
)

loss = self._loss_fn(pred_qvals, q_targets).mean()

if self._logger.should_log(self._timescale):
self._logger.log_scalar("train_loss", loss, self._timescale)

loss.backward()
sriyash421 marked this conversation as resolved.
Show resolved Hide resolved
if self._grad_clip is not None:
torch.nn.utils.clip_grad_value_(
self._qnet.parameters(), self._grad_clip
)
self._optimizer.step()

# Update target network
if self._target_net_update_schedule.update():
self._update_target()
2 changes: 2 additions & 0 deletions hive/agents/qnets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.conv import ConvNetwork
from hive.agents.qnets.mlp import MLPNetwork
from hive.agents.qnets.rnn import ConvRNNNetwork

registry.register_all(
FunctionApproximator,
{
"MLPNetwork": MLPNetwork,
"ConvNetwork": ConvNetwork,
"ConvRNNNetwork": ConvRNNNetwork,
sriyash421 marked this conversation as resolved.
Show resolved Hide resolved
"NatureAtariDQNModel": NatureAtariDQNModel,
},
)
Expand Down
42 changes: 42 additions & 0 deletions hive/agents/qnets/qnet_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,48 @@ def forward(self, x):
return self.output_layer(x)


class DRQNNetwork(nn.Module):
"""Implements the standard DRQN value computation. This module returns two outputs,
which correspond to the two outputs from :obj:`base_network`. In particular, it
transforms the first output from :obj:`base_network` with output dimension
:obj:`hidden_dim` to dimension :obj:`out_dim`, which should be equal to the
number of actions. The second output of this module is the second output from
:obj:`base_network`, which is the hidden state that will be used as the initial
hidden state when computing the next action in the trajectory.
"""

def __init__(
self,
base_network: nn.Module,
sriyash421 marked this conversation as resolved.
Show resolved Hide resolved
hidden_dim: int,
out_dim: int,
linear_fn: nn.Module = None,
):
"""
Args:
base_network (torch.nn.Module): Backbone network that returns two outputs,
one is the representation used to compute action values, and the
other one is the hidden state used as input hidden state later.
hidden_dim (int): Dimension of the output of the :obj:`network`.
out_dim (int): Output dimension of the DRQN. Should be equal to the
number of actions that you are computing values for.
linear_fn (torch.nn.Module): Function that will create the
:py:class:`torch.nn.Module` that will take the output of
:obj:`network` and produce the final action values. If
:obj:`None`, a :py:class:`torch.nn.Linear` layer will be used.
"""
super().__init__()
self.base_network = base_network
self._linear_fn = linear_fn if linear_fn is not None else nn.Linear
self.output_layer = self._linear_fn(hidden_dim, out_dim)
Comment on lines +73 to +75
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably make all of these internal variables

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe in a separate PR? All network modules defined here have base_network and output_layer.


def forward(self, x, hidden_state=None):
x, hidden_state = self.base_network(x, hidden_state)

x = x.flatten(start_dim=1)
return self.output_layer(x), hidden_state
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the only place that this returned hidden_state used in the act() function of the agent?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. forward() is also called in update() of the agent but the output hidden_state is not used.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it can have an internal function to call during act and update separately?



class DuelingNetwork(nn.Module):
"""Computes action values using Dueling Networks (https://arxiv.org/abs/1511.06581).
In dueling, we have two heads---one for estimating advantage function and one for
Expand Down
Loading