RNN support: DRQN agent + recurrent buffer #258
Conversation
initial commit recurrent buffer implementation
initial commit DRQN implementation
self.size() + self._max_seq_len + self._n_step - 1
)
elements = array[full_indices]
elements = elements.reshape(indices.shape[0], -1, *elements.shape[2:])
We want to ensure the dimension of observation is (batch_size, seq_length, C, H, W). C=1 according to https://github.com/chandar-lab/RLHive/blame/main/hive/envs/atari/atari.py#L62.
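For intuition, a toy shape check of the layout being described (sizes are illustrative, not taken from the buffer):

import numpy as np

# Illustrative sizes: 4 sampled windows of 10 Atari frames each.
batch, seq_len, C, H, W = 4, 10, 1, 84, 84

# Indexing storage with a (batch, seq_len) grid of indices yields one
# frame per position, i.e. shape (batch, seq_len, C, H, W).
elements = np.zeros((batch, seq_len, C, H, W), dtype=np.uint8)

# The reshape in the diff keeps the batch axis and folds the remaining
# leading axes into a single sequence axis.
elements = elements.reshape(batch, -1, *elements.shape[2:])
assert elements.shape == (batch, seq_len, C, H, W)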
Functions and code reused from CircularReplayBuffer? Maybe just use the inherited functions? Or are there changes across all the functions?
We are using the inherited functions defined in CircularReplayBuffer except for the ones where max_seq_len and stack_size differ.
Can something like super calls be done? E.g., as in the sample function in PPOReplayBuffer? That is, the input is updated so that it can be passed to the function of the parent class.
There will be a new PR that implements a recurrent replay that saves trajectories rather than sequences of transitions. It will be more suitable for recurrent DQN. For now I think it's fine to keep those functions where max_seq_len and stack_size are different.
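For reference, a minimal sketch of the super-call pattern being suggested (the _to_sequences helper is hypothetical, not actual RLHive code):

from hive.replays import CircularReplayBuffer

class RecurrentReplayBuffer(CircularReplayBuffer):
    def sample(self, batch_size):
        # Sketch: reuse the parent's flat-transition sampling, then
        # regroup the result into (batch, max_seq_len, ...) sequences.
        batch = super().sample(batch_size)
        return self._to_sequences(batch)  # hypothetical post-processing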
hive/replays/recurrent_replay.py
Outdated
@@ -200,7 +200,7 @@ def sample(self, batch_size):
trajectory_lengths = (
    np.argmax(terminals.astype(bool), axis=1) + 1
) * is_terminal + self._n_step * (1 - is_terminal)
is_terminal = terminals[:, 1 : self._n_step - 1]
Because of the sequence length, is_terminal has to be of shape (B * seq_length). Hence this particular change after calculating the trajectory lengths.
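A toy run of the trajectory-length computation above (values illustrative):

import numpy as np

n_step = 3
# Two sampled windows of n_step terminal flags each.
terminals = np.array([[0, 1, 0],
                      [0, 0, 0]])
is_terminal = terminals.any(axis=1)  # does the window contain a terminal?

# First-terminal index + 1 where a terminal exists, else the full n_step.
trajectory_lengths = (
    np.argmax(terminals.astype(bool), axis=1) + 1
) * is_terminal + n_step * (1 - is_terminal)
print(trajectory_lengths)  # [2 3]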
] + np.arange(self._n_step)
disc_rewards = np.einsum(
    "ijk,k->ij", rewards[:, idx], self._discount
)
rewards = disc_rewards
I went through and gave a quick review. I didn't thoroughly check for correctness yet; ideally I'd want to see experiments that show it's working before doing that. Also, please write updated docstrings.
hive/agents/qnets/qnet_heads.py
Outdated
"""Implements the standard DQN value computation. Transforms output from | ||
:obj:`base_network` with output dimension :obj:`hidden_dim` to dimension | ||
:obj:`out_dim`, which should be equal to the number of actions. | ||
""" |
Update docstring
done
hive/agents/qnets/qnet_heads.py
Outdated
base_network (torch.nn.Module): Backbone network that computes the
    representations that are used to compute action values.
Update docstring
You should say something about the expected output of this base_network.
Please update documentation according to previous comment. Specifically that base_network returns two things, and what those two things are.
done
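For example, the docstring could read something like this (a suggestion, not the merged wording):

import torch.nn as nn

class DRQNNetwork(nn.Module):
    """Implements the DRQN value computation.

    ``base_network`` is expected to return two things: the representation
    used to compute action values, and the updated hidden state of the
    recurrent layers.
    """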
hive/agents/drqn.py
Outdated
"""An agent implementing the DQN algorithm. Uses an epsilon greedy | ||
exploration policy |
Docstring
done
hive/agents/qnets/rnn.py
Outdated
|
if mlp_layers is not None:
    # MLP Layers
    # conv_output_size = calculate_output_dim(self.conv, in_dim)
delete
Done
hive/agents/qnets/utils.py
Outdated
@@ -20,9 +20,12 @@ def calculate_output_dim(net, input_shape):
"""
if isinstance(input_shape, int):
    input_shape = (input_shape,)
-placeholder = torch.zeros((0,) + tuple(input_shape))
+placeholder = torch.zeros((1,) + tuple(input_shape)).to(device)
The device shouldn't need to be passed in, right? At this point everything should still be on the CPU? I am fine with the change, I'm just not sure why it's necessary.
@dapatil211 The problem is that here we either need to create a dummy hidden_state on CPU and pass it to the network, or add an if condition in the forward function of ConvRNNNetwork to handle hidden_state = None. We did the latter, but there might be another case other than calculate_output_dim where hidden_state is None, and in that case the hidden_state created inside the forward function should be on the _device. (We also had to change the empty placeholder because the LSTM was not able to process it.)
Well, don't you only need to do that because you are doing .to(device) here when you don't really need to?: https://github.com/chandar-lab/RLHive/pull/258/files/4874d7a579fbcaf3c26dce7d3fecc01c5c57fac6#diff-324043212c24b1c8e9c8c139a4defa49947886ee8fc9f4f9ab8138d4347f4860R139-R141
fixed
hive/agents/qnets/utils.py
Outdated
output = net(placeholder)
return output.size()[1:]
if isinstance(output, tuple):
If it's a tuple, you should return the size of each output. It doesn't make sense to only do the first output; that's just specific to your current use case.
@dapatil211 This one needs a bit more discussion. Since hidden_state (the second item in the tuple) is a tuple itself, should we check the items inside that and return their sizes as well?
Probably should be recursive? Keep going until you hit scalars or tensors.
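A hedged sketch of that recursive version (the _get_sizes helper is hypothetical; assumes outputs are tensors or nested tuples of tensors):

import torch

def _get_sizes(output):
    # Recurse until we hit tensors, mirroring the nesting of the output.
    if isinstance(output, tuple):
        return tuple(_get_sizes(o) for o in output)
    return output.size()[1:]

def calculate_output_dim(net, input_shape):
    if isinstance(input_shape, int):
        input_shape = (input_shape,)
    placeholder = torch.zeros((1,) + tuple(input_shape))
    return _get_sizes(net(placeholder))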
hive/replays/recurrent_replay.py
Outdated
# mask = np.expand_dims(trajectory_lengths, 1) > np.arange(self._n_step)
# rewards = np.sum(rewards * mask, axis=1)
delete
done
hive/replays/recurrent_replay.py
Outdated
idx = np.arange(rewards.shape[1] - self._n_step + 1)[
    :, None
] + np.arange(self._n_step)
disc_rewards = np.einsum(
Could you add a comment explaining what this is doing? einsum is notoriously hard to interpret.
Sure will do!
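For reference, the einsum takes, for each sliding-window position, the dot product of the n-step reward slice with the discount vector. A toy equivalence check (values illustrative):

import numpy as np

n_step, gamma = 3, 0.99
discount = gamma ** np.arange(n_step)  # [1, g, g^2]
rewards = np.random.rand(2, 7)         # B x S

# Sliding windows of length n_step over the sequence axis.
idx = np.arange(rewards.shape[1] - n_step + 1)[:, None] + np.arange(n_step)
windows = rewards[:, idx]              # B x (S-N+1) x N

disc_rewards = np.einsum("ijk,k->ij", windows, discount)
# The same computation written with plain broadcasting:
assert np.allclose(
    disc_rewards, np.sum(windows * discount[None, None, :], axis=2)
)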
hive/agents/drqn.py
Outdated
representation_net=representation_net,
obs_dim=obs_dim,
act_dim=act_dim,
id=0,
I think we missed id=id to pass the id to the base DQN class. Because of this, in marlgrid all the agents have the same id, and that is causing problems.
hive/agents/qnets/rnn.py
Outdated
conv_output_size = calculate_output_dim(self.conv, in_dim)
if self._rnn_type == "lstm":
    self.rnn = nn.LSTM(
        np.prod(conv_output_size),
Consider changing the different types of networks into a sequencer class.
Done.
"observation", "action", "reward", and "done". | ||
""" | ||
if update_info["done"]: | ||
self._state["episode_start"] = True |
Is episode_start used anywhere except the buffer? The buffer already takes care of it, so it seems redundant.
This is there for stuff in act(). There's probably a better way to do it; it doesn't really make sense for the agent to handle it. It might make sense to add it as part of the observation, but I need to think about this a bit.
We probably want to fix it in a separate PR as this is what DQN does too.
We created an issue for fixing it in both DQN and DRQN.
I think this part is fine. I don't think it needs fixing.
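For concreteness, a sketch of the act-side reset being discussed (structure assumed from the surrounding comments, not the merged code):

def act(self, observation):
    # Sketch: reset the recurrent state at episode boundaries so the
    # hidden state never leaks across episodes.
    if self._state["episode_start"]:
        self._hidden_state = self._qnet.init_hidden(batch_size=1)
        self._state["episode_start"] = False
    qvals, self._hidden_state = self._qnet(observation, self._hidden_state)
    return qvals.argmax().item()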
hive/agents/qnets/sequence_models.py
Outdated
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).
"""
super().__init__(
    rnn_input_size=rnn_input_size,
You can remove rnn_input_size if the base function does not use it.
Done. Similarly, batch_first is also removed.
hive/agents/drqn.py
Outdated
self._optimizer.zero_grad()
pred_qvals, hidden_state = self._qnet(*current_state_inputs, hidden_state)
pred_qvals = pred_qvals.view(
    self._batch_size, self._replay_buffer._max_seq_len, -1
We should not be accessing an internal variable of the buffer.
done
hive/agents/__init__.py
Outdated
@@ -16,6 +17,7 @@
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
"DRQNAgent": DRQNAgent,
Alphabetical order please.
done
hive/agents/drqn.py
Outdated
import copy
import os

import gym
import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DRQNNetwork
from hive.agents.qnets.utils import (
    InitializationFn,
    calculate_output_dim,
    create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.replays.recurrent_replay import RecurrentReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
    LinearSchedule,
    PeriodicSchedule,
    Schedule,
    SwitchSchedule,
)
from hive.agents.dqn import DQNAgent
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder
Make sure you run import sorting (isort).
done
Args:
    observation_space (gym.spaces.Box): Observation space for the agent.
    action_space (gym.spaces.Discrete): Action space for the agent.
    representation_net (FunctionApproximator): A network that outputs the
I am assuming there are restrictions on the representation_net? E.g., it needs to be one of your recurrent ones? Please mention this in the documentation.
done
Can you more explicitly mention the restrictions on representation_net? For example, which methods it should have or that it should follow the structure of ConvRNNNetwork or something?
done
hive/agents/drqn.py
Outdated
stack_size: Number of observations stacked to create the state fed to the
    DRQN.
What does this mean in the context of DRQN? You aren't using this, are you? If not, remove it. You may need to keep a dummy or add kwargs to make it work with the runner.
done
hive/agents/qnets/rnn.py
Outdated
def __init__(
    self,
    in_dim,
    sequence_fn: SequenceModule,
See my comment in the sequence_models.py file about this.
done
capacity: int = 10000,
max_seq_len: int = 1,
n_step: int = 1,
gamma: float = 0.99,
observation_shape=(),
observation_dtype=np.uint8,
action_shape=(),
action_dtype=np.int8,
reward_shape=(),
reward_dtype=np.float32,
extra_storage_types=None,
num_players_sharing_buffer: int = None,
Why is the spacing inconsistent?
If we specify the type annotation, black will add whitespace around the equals sign; otherwise it will remove it.
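For example (how black formats the two kinds of defaults):

def __init__(
    self,
    capacity: int = 10000,  # annotated default: black keeps spaces around "="
    observation_shape=(),   # unannotated default: black strips the spaces
):
    ...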
hive/replays/recurrent_replay.py
Outdated
capacity (int): Total number of observations that can be stored in the
    buffer. Note, this is not the same as the number of transitions that
    can be stored in the buffer.
max_seq_len (int): The number of consecutive transitions in a sequence.
Please clarify the description of this parameter.
done
super().__init__(
    capacity=capacity,
    stack_size=1,
    n_step=n_step,
    gamma=gamma,
    observation_shape=observation_shape,
    observation_dtype=observation_dtype,
    action_shape=action_shape,
    action_dtype=action_dtype,
    reward_shape=reward_shape,
    reward_dtype=reward_dtype,
    extra_storage_types=extra_storage_types,
    num_players_sharing_buffer=num_players_sharing_buffer,
)
self._max_seq_len = max_seq_len
Would this class be equivalent to just creating CircularReplayBuffer(stack_size=max_seq_len)? If so, why not just do that? If not, what/where is the change in logic?
There are differences in sample().
hive/replays/recurrent_replay.py
Outdated
    self._n_step
)  # (S-N+1) x N
rewards = rewards[:, idx]  # B x (S-N+1) x N
# Creating a vectorized sliding window to calculate discounted returns for every element in the sequence
Please make sure commented lines and lines part of docstrings are still under 88 characters.
hive/runners/single_agent_loop.py
Outdated
@@ -153,7 +153,6 @@ def set_up_experiment(config):
agent = agent_fn(
    observation_space=env_spec.observation_space[0],
    action_space=env_spec.action_space[0],
-   stack_size=config.get("stack_size", 1),
This shouldn't be removed. It will break other agents.
done
"observation", "action", "reward", and "done". | ||
""" | ||
if update_info["done"]: | ||
self._state["episode_start"] = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this part is fine. I don't think it needs fixing.
hive/agents/drqn.py
Outdated
self._hidden_state = self._qnet.base_network.init_hidden(
    batch_size=1, device=self._device
)
I don't think this is the same thing. episode_start can only be set in update, because that's when the agent knows that the episode ended. My comment is about how the resetting of the hidden state should be done in act().
hive/agents/qnets/sequence_models.py
Outdated
rnn_input_size (int): The number of expected features in the input x.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).
line length
Fixed.
hive/agents/qnets/sequence_models.py
Outdated
rnn_input_size (int): The number of expected features in the input x.
rnn_hidden_size (int): The number of features in the hidden state h.
num_rnn_layers (int): Number of recurrent layers.
batch_first (bool): If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature).
line length
Fixed.
hive/agents/qnets/sequence_models.py
Outdated
    batch_first=batch_first,
)

def init_hidden(self, batch_size, device="cpu"):
Can this device not just be passed once in the initializer?
Fixed.
hive/agents/drqn.py
Outdated
if replay_buffer is None:
    replay_buffer = RecurrentReplayBuffer
self._replay_buffer = replay_buffer(
    max_seq_len=max_seq_len,
    observation_shape=self._observation_space.shape,
    observation_dtype=self._observation_space.dtype,
    action_shape=self._action_space.shape,
    action_dtype=self._action_space.dtype,
)
self._max_seq_len = max_seq_len
This can be moved above the super constructor call (note this needs functools.partial imported):

from functools import partial

if replay_buffer is None:
    replay_buffer = RecurrentReplayBuffer
replay_buffer = partial(replay_buffer, max_seq_len=max_seq_len)
done
hive/agents/drqn.py
Outdated
replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
    push observations to and sample from during learning. If None,
    defaults to
    :py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.
This is incorrect. Please go through all the documentation and make sure it is correct.
done
hive/agents/drqn.py
Outdated
hidden_state = self._qnet.base_network.init_hidden(
    batch_size=self._batch_size, device=self._device
)
target_hidden_state = self._target_qnet.base_network.init_hidden(
    batch_size=self._batch_size, device=self._device
)
It feels weird that you are accessing an internal module of the qnet. I think instead of self._qnet.base_network.init_hidden, it should be self._qnet.init_hidden.
done
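A minimal sketch of that delegation (implementation assumed, not the merged code):

import torch.nn as nn

class DRQNNetwork(nn.Module):
    def init_hidden(self, batch_size):
        # Delegate to the wrapped module so callers never have to reach
        # into base_network directly.
        return self.base_network.init_hidden(batch_size=batch_size)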
self.base_network = base_network
self._linear_fn = linear_fn if linear_fn is not None else nn.Linear
self.output_layer = self._linear_fn(hidden_dim, out_dim)
We should probably make all of these internal variables.
Maybe in a separate PR? All network modules defined here have base_network and output_layer.
hive/replays/recurrent_replay.py
Outdated
rewards = rewards[:, idx]  # B x (S-N+1) x N
# Creating a vectorized sliding window to calculate
# discounted returns for every element in the sequence.
# equivalent to np.sum(rewards * self._discount[None, None, :], axis=2)
This line is longer than 88 characters.
done
)
# Compute predicted Q values
self._optimizer.zero_grad()
pred_qvals, _ = self._qnet(*current_state_inputs, hidden_state)
Just to make sure: the qnet takes in a window of past observations, and you take the last hidden state as output and pass it through an MLP to get the Q-values? So it involves some redundant computation when calculating Q-values for s_t and s_{t+1}.
We have another PR, #270, that handles hidden-state saving and burn-in frames. In this PR the hidden states are initialized from zeros. Could you also provide some references if you have seen more efficient ways of reusing hidden states when computing Q?
We can have a look at the CleanRL/SB3 implementations of recurrent networks. Their JAX implementations might have some principles or tricks we can use in our code base.
I think the current implementation is good enough and is working. We can create new PRs to improve its efficiency later.
x, hidden_state = self.base_network(x, hidden_state)

x = x.flatten(start_dim=1)
return self.output_layer(x), hidden_state
Maybe it can have an internal function to call during act and update separately?
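A hedged sketch of that split (method names hypothetical, not part of the PR):

import torch.nn as nn

class DRQNNetwork(nn.Module):
    # Hypothetical split: one path for a single step (act), one for a
    # full sequence (update); both share the same recurrent core.
    def forward_step(self, x, hidden_state):
        x, hidden_state = self.base_network(x, hidden_state)
        return self.output_layer(x.flatten(start_dim=1)), hidden_state

    def forward_sequence(self, x_seq, hidden_state):
        x, hidden_state = self.base_network(x_seq, hidden_state)
        # Apply the head per timestep rather than flattening the sequence.
        return self.output_layer(x), hidden_state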
Assuming everything runs and matches existing benchmarks, this looks good. Please add the wandb runs to this conversation and also to the benchmarking report. Also, when merging, do a squash and merge.
We should rerun the experiments for benchmarking, but here are some of the results on Atari: [Atari results plots omitted]. And results on Hanabi DRQN: https://wandb.ai/chandar-rl/Hive/reports/DRQN-for-Hanabi--VmlldzoyMzU1NTUy?accessToken=efz63017kwgubdp8oyloepsfpjn3pmf0qka070h6rh90jjrcxcp8jd3eo20ko9sj