diff --git a/hive/agents/__init__.py b/hive/agents/__init__.py index 98b958db..5c5e1d06 100644 --- a/hive/agents/__init__.py +++ b/hive/agents/__init__.py @@ -2,6 +2,7 @@ from hive.agents.agent import Agent from hive.agents.ddpg import DDPG from hive.agents.dqn import DQNAgent +from hive.agents.drqn import DRQNAgent from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent from hive.agents.rainbow import RainbowDQNAgent from hive.agents.random import RandomAgent @@ -13,6 +14,7 @@ { "DDPG": DDPG, "DQNAgent": DQNAgent, + "DRQNAgent": DRQNAgent, "LegalMovesRainbowAgent": LegalMovesRainbowAgent, "RainbowDQNAgent": RainbowDQNAgent, "RandomAgent": RandomAgent, diff --git a/hive/agents/drqn.py b/hive/agents/drqn.py new file mode 100644 index 00000000..28cc318f --- /dev/null +++ b/hive/agents/drqn.py @@ -0,0 +1,282 @@ +import copy +import os +from functools import partial + +import gym +import numpy as np +import torch + +from hive.agents.agent import Agent +from hive.agents.dqn import DQNAgent +from hive.agents.qnets.base import FunctionApproximator +from hive.agents.qnets.qnet_heads import DRQNNetwork +from hive.agents.qnets.utils import ( + InitializationFn, + calculate_output_dim, + create_init_weights_fn, +) +from hive.replays import BaseReplayBuffer, CircularReplayBuffer +from hive.replays.recurrent_replay import RecurrentReplayBuffer +from hive.utils.loggers import Logger, NullLogger +from hive.utils.schedule import ( + LinearSchedule, + PeriodicSchedule, + Schedule, + SwitchSchedule, +) +from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder + + +class DRQNAgent(DQNAgent): + """An agent implementing the DRQN algorithm. Uses an epsilon greedy + exploration policy + """ + + def __init__( + self, + observation_space: gym.spaces.Box, + action_space: gym.spaces.Discrete, + representation_net: FunctionApproximator, + id=0, + optimizer_fn: OptimizerFn = None, + loss_fn: LossFn = None, + init_fn: InitializationFn = None, + replay_buffer: BaseReplayBuffer = None, + max_seq_len: int = 1, + discount_rate: float = 0.99, + n_step: int = 1, + grad_clip: float = None, + reward_clip: float = None, + update_period_schedule: Schedule = None, + target_net_soft_update: bool = False, + target_net_update_fraction: float = 0.05, + target_net_update_schedule: Schedule = None, + epsilon_schedule: Schedule = None, + test_epsilon: float = 0.001, + min_replay_history: int = 5000, + batch_size: int = 32, + device="cpu", + logger: Logger = None, + log_frequency: int = 100, + **kwargs, + ): + """ + Args: + observation_space (gym.spaces.Box): Observation space for the agent. + action_space (gym.spaces.Discrete): Action space for the agent. + representation_net (FunctionApproximator): A network that outputs the + representations that will be used to compute Q-values (e.g. + everything except the final layer of the DRQN), as well as the + hidden states of the recurrent component. The structure should be + similar to ConvRNNNetwork, i.e., it should have a current module + component placed between the convolutional layers and MLP layers. + It should also define a method that initializes the hidden state + of the recurrent module if the computation requires hidden states + as input/output. + id: Agent identifier. + optimizer_fn (OptimizerFn): A function that takes in a list of parameters + to optimize and returns the optimizer. If None, defaults to + :py:class:`~torch.optim.Adam`. + loss_fn (LossFn): Loss function used by the agent. If None, defaults to + :py:class:`~torch.nn.SmoothL1Loss`. 
+ init_fn (InitializationFn): Initializes the weights of qnet using + create_init_weights_fn. + replay_buffer (BaseReplayBuffer): The replay buffer that the agent will + push observations to and sample from during learning. If None, + defaults to + :py:class:`~hive.replays.recurrent_replay.RecurrentReplayBuffer`. + max_seq_len (int): The number of consecutive transitions in a sequence. + discount_rate (float): A number between 0 and 1 specifying how much + future rewards are discounted by the agent. + n_step (int): The horizon used in n-step returns to compute TD(n) targets. + grad_clip (float): Gradients will be clipped to between + [-grad_clip, grad_clip]. + reward_clip (float): Rewards will be clipped to between + [-reward_clip, reward_clip]. + update_period_schedule (Schedule): Schedule determining how frequently + the agent's Q-network is updated. + target_net_soft_update (bool): Whether the target net parameters are + replaced by the qnet parameters completely or using a weighted + average of the target net parameters and the qnet parameters. + target_net_update_fraction (float): The weight given to the target + net parameters in a soft update. + target_net_update_schedule (Schedule): Schedule determining how frequently + the target net is updated. + epsilon_schedule (Schedule): Schedule determining the value of epsilon + through the course of training. + test_epsilon (float): epsilon (probability of choosing a random action) + to be used during testing phase. + min_replay_history (int): How many observations to fill the replay buffer + with before starting to learn. + batch_size (int): The size of the batch sampled from the replay buffer + during learning. + device: Device on which all computations should be run. + logger (ScheduledLogger): Logger used to log agent's metrics. + log_frequency (int): How often to log the agent's metrics. + """ + if replay_buffer is None: + replay_buffer = RecurrentReplayBuffer + replay_buffer = partial(replay_buffer, max_seq_len=max_seq_len) + self._max_seq_len = max_seq_len + + super().__init__( + observation_space=observation_space, + action_space=action_space, + representation_net=representation_net, + id=id, + optimizer_fn=optimizer_fn, + loss_fn=loss_fn, + init_fn=init_fn, + replay_buffer=replay_buffer, + discount_rate=discount_rate, + n_step=n_step, + grad_clip=grad_clip, + reward_clip=reward_clip, + update_period_schedule=update_period_schedule, + target_net_soft_update=target_net_soft_update, + target_net_update_fraction=target_net_update_fraction, + target_net_update_schedule=target_net_update_schedule, + epsilon_schedule=epsilon_schedule, + test_epsilon=test_epsilon, + min_replay_history=min_replay_history, + batch_size=batch_size, + device=device, + logger=logger, + log_frequency=log_frequency, + ) + + def create_q_networks(self, representation_net): + """Creates the Q-network and target Q-network. + + Args: + representation_net: A network that outputs the representations that will + be used to compute Q-values (e.g. everything except the final layer + of the DRQN). 
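+                It should also accept and return the hidden state of its
+                recurrent component, as in :obj:`ConvRNNNetwork`.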
+ """ + network = representation_net(self._state_size) + network_output_dim = np.prod(calculate_output_dim(network, self._state_size)[0]) + self._qnet = DRQNNetwork(network, network_output_dim, self._action_space.n).to( + self._device + ) + self._qnet.update_rnn_device() + + self._qnet.apply(self._init_fn) + self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False) + self._hidden_state = self._qnet.init_hidden(batch_size=1) + + @torch.no_grad() + def act(self, observation): + """Returns the action for the agent. If in training mode, follows an epsilon + greedy policy. Otherwise, returns the action with the highest Q-value. + + Args: + observation: The current observation. + """ + + # Reset hidden state if it is episode beginning. + if self._state["episode_start"]: + self._hidden_state = self._qnet.init_hidden(batch_size=1) + + # Determine and log the value of epsilon + if self._training: + if not self._learn_schedule.get_value(): + epsilon = 1.0 + else: + epsilon = self._epsilon_schedule.update() + if self._logger.update_step(self._timescale): + self._logger.log_scalar("epsilon", epsilon, self._timescale) + else: + epsilon = self._test_epsilon + + # Sample action. With epsilon probability choose random action, + # otherwise select the action with the highest q-value. + # Insert batch_size and sequence_len dimensions to observation + observation = torch.tensor( + np.expand_dims(observation, axis=(0, 1)), device=self._device + ).float() + qvals, self._hidden_state = self._qnet(observation, self._hidden_state) + if self._rng.random() < epsilon: + action = self._rng.integers(self._action_space.n) + else: + # Note: not explicitly handling the ties + action = torch.argmax(qvals).item() + + if ( + self._training + and self._logger.should_log(self._timescale) + and self._state["episode_start"] + ): + self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale) + self._state["episode_start"] = False + return action + + def update(self, update_info): + """ + Updates the DRQN agent. + + Args: + update_info: dictionary containing all the necessary information to + update the agent. Should contain a full transition, with keys for + "observation", "action", "reward", and "done". + """ + if update_info["done"]: + self._state["episode_start"] = True + + if not self._training: + return + + # Add the most recent transition to the replay buffer. + self._replay_buffer.add(**self.preprocess_update_info(update_info)) + + # Update the q network based on a sample batch from the replay buffer. + # If the replay buffer doesn't have enough samples, catch the exception + # and move on. 
+ if ( + self._learn_schedule.update() + and self._replay_buffer.size() > 0 + and self._update_period_schedule.update() + ): + batch = self._replay_buffer.sample(batch_size=self._batch_size) + ( + current_state_inputs, + next_state_inputs, + batch, + ) = self.preprocess_update_batch(batch) + + hidden_state = self._qnet.init_hidden( + batch_size=self._batch_size, + ) + target_hidden_state = self._target_qnet.init_hidden( + batch_size=self._batch_size, + ) + # Compute predicted Q values + self._optimizer.zero_grad() + pred_qvals, _ = self._qnet(*current_state_inputs, hidden_state) + pred_qvals = pred_qvals.view(self._batch_size, self._max_seq_len, -1) + actions = batch["action"].long() + pred_qvals = torch.gather(pred_qvals, -1, actions.unsqueeze(-1)).squeeze(-1) + + # Compute 1-step Q targets + next_qvals, _ = self._target_qnet(*next_state_inputs, target_hidden_state) + next_qvals = next_qvals.view(self._batch_size, self._max_seq_len, -1) + next_qvals, _ = torch.max(next_qvals, dim=-1) + + q_targets = batch["reward"] + self._discount_rate * next_qvals * ( + 1 - batch["done"] + ) + + loss = self._loss_fn(pred_qvals, q_targets).mean() + + if self._logger.should_log(self._timescale): + self._logger.log_scalar("train_loss", loss, self._timescale) + + loss.backward() + if self._grad_clip is not None: + torch.nn.utils.clip_grad_value_( + self._qnet.parameters(), self._grad_clip + ) + self._optimizer.step() + + # Update target network + if self._target_net_update_schedule.update(): + self._update_target() diff --git a/hive/agents/qnets/__init__.py b/hive/agents/qnets/__init__.py index a9a00575..ed580678 100644 --- a/hive/agents/qnets/__init__.py +++ b/hive/agents/qnets/__init__.py @@ -3,12 +3,14 @@ from hive.agents.qnets.base import FunctionApproximator from hive.agents.qnets.conv import ConvNetwork from hive.agents.qnets.mlp import MLPNetwork +from hive.agents.qnets.rnn import ConvRNNNetwork registry.register_all( FunctionApproximator, { - "MLPNetwork": MLPNetwork, "ConvNetwork": ConvNetwork, + "ConvRNNNetwork": ConvRNNNetwork, + "MLPNetwork": MLPNetwork, "NatureAtariDQNModel": NatureAtariDQNModel, }, ) diff --git a/hive/agents/qnets/qnet_heads.py b/hive/agents/qnets/qnet_heads.py index c70161c8..768e7582 100644 --- a/hive/agents/qnets/qnet_heads.py +++ b/hive/agents/qnets/qnet_heads.py @@ -39,6 +39,54 @@ def forward(self, x): return self.output_layer(x) +class DRQNNetwork(nn.Module): + """Implements the standard DRQN value computation. This module returns two outputs, + which correspond to the two outputs from :obj:`base_network`. In particular, it + transforms the first output from :obj:`base_network` with output dimension + :obj:`hidden_dim` to dimension :obj:`out_dim`, which should be equal to the + number of actions. The second output of this module is the second output from + :obj:`base_network`, which is the hidden state that will be used as the initial + hidden state when computing the next action in the trajectory. + """ + + def __init__( + self, + base_network: nn.Module, + hidden_dim: int, + out_dim: int, + linear_fn: nn.Module = None, + ): + """ + Args: + base_network (torch.nn.Module): Backbone network that returns two outputs, + one is the representation used to compute action values, and the + other one is the hidden state used as input hidden state later. + hidden_dim (int): Dimension of the output of the :obj:`network`. + out_dim (int): Output dimension of the DRQN. Should be equal to the + number of actions that you are computing values for. 
+ linear_fn (torch.nn.Module): Function that will create the + :py:class:`torch.nn.Module` that will take the output of + :obj:`network` and produce the final action values. If + :obj:`None`, a :py:class:`torch.nn.Linear` layer will be used. + """ + super().__init__() + self.base_network = base_network + self._linear_fn = linear_fn if linear_fn is not None else nn.Linear + self.output_layer = self._linear_fn(hidden_dim, out_dim) + + def forward(self, x, hidden_state=None): + x, hidden_state = self.base_network(x, hidden_state) + + x = x.flatten(start_dim=1) + return self.output_layer(x), hidden_state + + def init_hidden(self, batch_size): + return self.base_network.init_hidden(batch_size) + + def update_rnn_device(self): + self.base_network.update_rnn_device() + + class DuelingNetwork(nn.Module): """Computes action values using Dueling Networks (https://arxiv.org/abs/1511.06581). In dueling, we have two heads---one for estimating advantage function and one for diff --git a/hive/agents/qnets/rnn.py b/hive/agents/qnets/rnn.py new file mode 100644 index 00000000..e90e9c25 --- /dev/null +++ b/hive/agents/qnets/rnn.py @@ -0,0 +1,136 @@ +import numpy as np +import torch +from torch import nn + +from hive.agents.qnets.mlp import MLPNetwork +from hive.agents.qnets.conv import ConvNetwork +from hive.agents.qnets.utils import calculate_output_dim +from hive.agents.qnets.sequence_models import SequenceFn + + +class ConvRNNNetwork(nn.Module): + """ + Basic convolutional recurrent neural network architecture. Applies a number of + convolutional layers (each followed by a ReLU activation), recurrent layers, and then + feeds the output into an :py:class:`hive.agents.qnets.mlp.MLPNetwork`. + + Note, if :obj:`channels` is :const:`None`, the network created for the + convolution portion of the architecture is simply an + :py:class:`torch.nn.Identity` module. If :obj:`mlp_layers` is + :const:`None`, the mlp portion of the architecture is an + :py:class:`torch.nn.Identity` module. + """ + + def __init__( + self, + in_dim, + sequence_fn: SequenceFn, + channels=None, + mlp_layers=None, + kernel_sizes=1, + strides=1, + paddings=0, + normalization_factor=255, + noisy=False, + std_init=0.5, + ): + """ + Args: + in_dim (tuple): The tuple of observations dimension (channels, width, + height). + sequence_fn (SequenceFn): A sequence neural network that learns + recurrent representation. Usually placed between the convolutional + layers and mlp layers. + channels (list): The size of output channel for each convolutional layer. + mlp_layers (list): The number of neurons for each mlp layer after the + convolutional layers. + kernel_sizes (list | int): The kernel size for each convolutional layer + strides (list | int): The stride used for each convolutional layer. + paddings (list | int): The size of the padding used for each convolutional + layer. + normalization_factor (float | int): What the input is divided by before + the forward pass of the network. + noisy (bool): Whether the MLP part of the network will use + :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear` layers or + :py:class:`torch.nn.Linear` layers. + std_init (float): The range for the initialization of the standard + deviation of the weights in + :py:class:`~hive.agents.qnets.noisy_linear.NoisyLinear`. 
+ """ + super().__init__() + self._normalization_factor = normalization_factor + if channels is not None: + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * len(channels) + if isinstance(strides, int): + strides = [strides] * len(channels) + if isinstance(paddings, int): + paddings = [paddings] * len(channels) + + if not all( + len(x) == len(channels) for x in [kernel_sizes, strides, paddings] + ): + raise ValueError("The lengths of the parameter lists must be the same") + + # Convolutional Layers + channels.insert(0, in_dim[0]) + conv_seq = [] + for i in range(0, len(channels) - 1): + conv_seq.append( + torch.nn.Conv2d( + in_channels=channels[i], + out_channels=channels[i + 1], + kernel_size=kernel_sizes[i], + stride=strides[i], + padding=paddings[i], + ) + ) + conv_seq.append(torch.nn.ReLU()) + self.conv = torch.nn.Sequential(*conv_seq) + else: + self.conv = torch.nn.Identity() + + # RNN Layers + conv_output_size = calculate_output_dim(self.conv, in_dim) + self.rnn = sequence_fn( + rnn_input_size=np.prod(conv_output_size), + ) + + if mlp_layers is not None: + # MLP Layers + self.mlp = MLPNetwork( + sequence_fn.keywords["rnn_hidden_size"], + mlp_layers, + noisy=noisy, + std_init=std_init, + ) + else: + self.mlp = nn.Identity() + + def forward(self, x, hidden_state=None): + # Act: sequence length is 1; Update: sequence length pre-defined. + B, L, C, H, W = x.size() + x = x.reshape(B * L, C, H, W) + + x = x.float() + x = x / self._normalization_factor + x = self.conv(x) + + _, C, H, W = x.size() + x = x.view(B, L, C, H, W) + if hidden_state is None: + hidden_state = self.init_hidden(B) + x = torch.flatten(x, start_dim=2, end_dim=-1) # (B, L, -1) + x, hidden_state = self.rnn(x, hidden_state) + x = self.mlp(x.reshape((B * L, -1))) + return x, hidden_state + + def init_hidden(self, batch_size): + hidden_state = self.rnn.init_hidden( + batch_size=batch_size, + ) + + return hidden_state + + def update_rnn_device(self): + self.rnn.update_device() diff --git a/hive/agents/qnets/sequence_models.py b/hive/agents/qnets/sequence_models.py new file mode 100644 index 00000000..19159052 --- /dev/null +++ b/hive/agents/qnets/sequence_models.py @@ -0,0 +1,148 @@ +import torch +from torch import nn + +from hive.utils.registry import registry, Registrable +from hive.agents.qnets.base import FunctionApproximator + + +class SequenceFn(Registrable): + """A wrapper for callables that produce sequence functions.""" + + @classmethod + def type_name(cls): + return "sequence_fn" + + +class SequenceModel(nn.Module): + """ + Base sequence neural network architecture. + """ + + def __init__( + self, + rnn_hidden_size=128, + num_rnn_layers=1, + device="cpu", + ): + """ + Args: + rnn_hidden_size (int): The number of features in the hidden state h. + num_rnn_layers (int): Number of recurrent layers. + device: Device on which all computations should be run. + """ + super().__init__() + self._rnn_hidden_size = rnn_hidden_size + self._num_rnn_layers = num_rnn_layers + self.core = None + self._device = device + + def forward(self, x, hidden_state=None): + x, hidden_state = self.core(x, hidden_state) + return x, hidden_state + + def update_device(self): + self._device = next(self.core.parameters()).device + + +class LSTMModel(SequenceModel): + """ + A multi-layer long short-term memory (LSTM) RNN. 
+ """ + + def __init__( + self, + rnn_input_size=256, + rnn_hidden_size=128, + num_rnn_layers=1, + batch_first=True, + device="cpu", + ): + """ + Args: + rnn_input_size (int): The number of expected features in the input x. + rnn_hidden_size (int): The number of features in the hidden state h. + num_rnn_layers (int): Number of recurrent layers. + batch_first (bool): If True, then the input and output tensors are + provided as (batch, seq, feature) instead of (seq, batch, feature). + """ + super().__init__( + rnn_hidden_size=rnn_hidden_size, + num_rnn_layers=num_rnn_layers, + device=device, + ) + self.core = nn.LSTM( + input_size=rnn_input_size, + hidden_size=self._rnn_hidden_size, + num_layers=self._num_rnn_layers, + batch_first=batch_first, + ) + + def init_hidden(self, batch_size): + hidden_state = ( + torch.zeros( + (self._num_rnn_layers, batch_size, self._rnn_hidden_size), + dtype=torch.float32, + device=self._device, + ), + torch.zeros( + (self._num_rnn_layers, batch_size, self._rnn_hidden_size), + dtype=torch.float32, + device=self._device, + ), + ) + + return hidden_state + + +class GRUModel(SequenceModel): + """ + A multi-layer gated recurrent unit (GRU) RNN. + """ + + def __init__( + self, + rnn_input_size=256, + rnn_hidden_size=128, + num_rnn_layers=1, + batch_first=True, + device="cpu", + ): + """ + Args: + rnn_input_size (int): The number of expected features in the input x. + rnn_hidden_size (int): The number of features in the hidden state h. + num_rnn_layers (int): Number of recurrent layers. + batch_first (bool): If True, then the input and output tensors are + provided as (batch, seq, feature) instead of (seq, batch, feature). + """ + super().__init__( + rnn_hidden_size=rnn_hidden_size, + num_rnn_layers=num_rnn_layers, + device=device, + ) + self.core = nn.GRU( + input_size=rnn_input_size, + hidden_size=self._rnn_hidden_size, + num_layers=self._num_rnn_layers, + batch_first=batch_first, + ) + + def init_hidden(self, batch_size): + hidden_state = torch.zeros( + (self._num_rnn_layers, batch_size, self._rnn_hidden_size), + dtype=torch.float32, + device=self._device, + ) + + return hidden_state + + +registry.register_all( + SequenceFn, + { + "LSTM": LSTMModel, + "GRU": GRUModel, + }, +) + +get_sequence_fn = getattr(registry, f"get_{SequenceFn.type_name()}") diff --git a/hive/agents/qnets/utils.py b/hive/agents/qnets/utils.py index f1fb976f..16bbcfb7 100644 --- a/hive/agents/qnets/utils.py +++ b/hive/agents/qnets/utils.py @@ -20,9 +20,18 @@ def calculate_output_dim(net, input_shape): """ if isinstance(input_shape, int): input_shape = (input_shape,) - placeholder = torch.zeros((0,) + tuple(input_shape)) + placeholder = torch.zeros((1,) + tuple(input_shape)) output = net(placeholder) - return output.size()[1:] + return extract_shape(output) + + +def extract_shape(x): + if isinstance(x, torch.Tensor): + return x.size()[1:] + elif isinstance(x, tuple) or isinstance(x, list): + return tuple(extract_shape(y) for y in x) + else: + raise ValueError("Invalid argument shape") def create_init_weights_fn(initialization_fn): diff --git a/hive/configs/atari/drqn.yml b/hive/configs/atari/drqn.yml new file mode 100644 index 00000000..77f67f9e --- /dev/null +++ b/hive/configs/atari/drqn.yml @@ -0,0 +1,86 @@ +run_name: &run_name 'atari-drqn' +train_steps: 50000000 +test_frequency: 250000 +test_episodes: 10 +max_steps_per_episode: 27000 +max_seq_len: &max_seq_len 10 +save_dir: 'experiment' +saving_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 1000000 
+environment: + name: 'AtariEnv' + kwargs: + env_name: 'Asterix' + +agent: + name: 'DRQNAgent' + kwargs: + representation_net: + name: 'ConvRNNNetwork' + kwargs: + channels: [32, 64, 64] + kernel_sizes: [8, 4, 3] + strides: [4, 2, 1] + paddings: [2, 2, 1] + mlp_layers: [512] + sequence_fn: + name: 'LSTM' + kwargs: + rnn_hidden_size: 128 + num_rnn_layers: 1 + optimizer_fn: + name: 'RMSpropTF' + kwargs: + lr: 0.00025 + alpha: .95 + eps: 0.00001 + centered: True + init_fn: + name: 'xavier_uniform' + loss_fn: + name: 'SmoothL1Loss' + replay_buffer: + name: 'RecurrentReplayBuffer' + kwargs: + capacity: 1000000 + gamma: &gamma .99 + max_seq_len: *max_seq_len + discount_rate: *gamma + reward_clip: 1 + update_period_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 4 + target_net_update_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 8000 + epsilon_schedule: + name: 'LinearSchedule' + kwargs: + init_value: 1.0 + end_value: .01 + steps: 250000 + test_epsilon: .001 + min_replay_history: 200 + device: 'cuda' + log_frequency: 1000 +# List of logger configs used. +loggers: + - + name: ChompLogger + - + name: WandbLogger + kwargs: + project: Hive + name: *run_name + resume: "allow" + start_method: "fork" diff --git a/hive/replays/__init__.py b/hive/replays/__init__.py index 4302d9c9..f4fe8211 100644 --- a/hive/replays/__init__.py +++ b/hive/replays/__init__.py @@ -1,6 +1,7 @@ from hive.replays.circular_replay import CircularReplayBuffer, SimpleReplayBuffer from hive.replays.legal_moves_replay import LegalMovesBuffer from hive.replays.prioritized_replay import PrioritizedReplayBuffer +from hive.replays.recurrent_replay import RecurrentReplayBuffer from hive.replays.replay_buffer import BaseReplayBuffer from hive.utils.registry import registry @@ -8,9 +9,10 @@ BaseReplayBuffer, { "CircularReplayBuffer": CircularReplayBuffer, - "SimpleReplayBuffer": SimpleReplayBuffer, - "PrioritizedReplayBuffer": PrioritizedReplayBuffer, "LegalMovesBuffer": LegalMovesBuffer, + "PrioritizedReplayBuffer": PrioritizedReplayBuffer, + "RecurrentReplayBuffer": RecurrentReplayBuffer, + "SimpleReplayBuffer": SimpleReplayBuffer, }, ) diff --git a/hive/replays/recurrent_replay.py b/hive/replays/recurrent_replay.py new file mode 100644 index 00000000..e7beb20c --- /dev/null +++ b/hive/replays/recurrent_replay.py @@ -0,0 +1,262 @@ +import os +import pickle + +import numpy as np +from hive.replays.circular_replay import CircularReplayBuffer + + +class RecurrentReplayBuffer(CircularReplayBuffer): + """ + First implementation of recurrent buffer without storing hidden states + """ + + def __init__( + self, + capacity: int = 10000, + max_seq_len: int = 1, + n_step: int = 1, + gamma: float = 0.99, + observation_shape=(), + observation_dtype=np.uint8, + action_shape=(), + action_dtype=np.int8, + reward_shape=(), + reward_dtype=np.float32, + extra_storage_types=None, + num_players_sharing_buffer: int = None, + ): + """Constructor for RecurrentReplayBuffer. + + Args: + capacity (int): Total number of observations that can be stored in the + buffer. Note, this is not the same as the number of transitions that + can be stored in the buffer. + max_seq_len (int): The number of consecutive transitions in a sequence + sampled from an episode. + n_step (int): Horizon used to compute n-step return reward + gamma (float): Discounting factor used to compute n-step return reward + observation_shape: Shape of observations that will be stored in the buffer. 
+ observation_dtype: Type of observations that will be stored in the buffer. + This can either be the type itself or string representation of the + type. The type can be either a native python type or a numpy type. If + a numpy type, a string of the form np.uint8 or numpy.uint8 is + acceptable. + action_shape: Shape of actions that will be stored in the buffer. + action_dtype: Type of actions that will be stored in the buffer. Format is + described in the description of observation_dtype. + reward_shape: Shape of rewards that will be stored in the buffer. + reward_dtype: Type of rewards that will be stored in the buffer. Format is + described in the description of observation_dtype. + extra_storage_types (dict): A dictionary describing extra items to store + in the buffer. The mapping should be from the name of the item to a + (type, shape) tuple. + num_players_sharing_buffer (int): Number of agents that share their + buffers. It is used for self-play. + """ + super().__init__( + capacity=capacity, + stack_size=1, + n_step=n_step, + gamma=gamma, + observation_shape=observation_shape, + observation_dtype=observation_dtype, + action_shape=action_shape, + action_dtype=action_dtype, + reward_shape=reward_shape, + reward_dtype=reward_dtype, + extra_storage_types=extra_storage_types, + num_players_sharing_buffer=num_players_sharing_buffer, + ) + self._max_seq_len = max_seq_len + + def size(self): + """Returns the number of transitions stored in the buffer.""" + return max( + min(self._num_added, self._capacity) - self._max_seq_len - self._n_step + 1, + 0, + ) + + def add(self, observation, action, reward, done, **kwargs): + """Adds a transition to the buffer. + The required components of a transition are given as positional arguments. The + user can pass additional components to store in the buffer as kwargs as long as + they were defined in the specification in the constructor. + """ + + if self._episode_start: + self._pad_buffer(self._max_seq_len - 1) + self._episode_start = False + transition = { + "observation": observation, + "action": action, + "reward": reward, + "done": done, + } + transition.update(kwargs) + for key in self._specs: + obj_type = ( + transition[key].dtype + if hasattr(transition[key], "dtype") + else type(transition[key]) + ) + if not np.can_cast(obj_type, self._specs[key][0], casting="same_kind"): + raise ValueError( + f"Key {key} has wrong dtype. Expected {self._specs[key][0]}," + f"received {type(transition[key])}." + ) + if self._num_players_sharing_buffer is None: + self._add_transition(**transition) + else: + self._episode_storage[kwargs["agent_id"]].append(transition) + if done: + for transition in self._episode_storage[kwargs["agent_id"]]: + self._add_transition(**transition) + self._episode_storage[kwargs["agent_id"]] = [] + + if done: + self._episode_start = True + + def _get_from_array(self, array, indices, num_to_access=1): + """Retrieves consecutive elements in the array, wrapping around if necessary. + If more than 1 element is being accessed, the elements are concatenated along + the first dimension. 
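+        Wrap-around is handled by taking indices modulo the number of elements
+        currently stored in the buffer.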
+ Args: + array: array to access from + indices: starts of ranges to access from + num_to_access: how many consecutive elements to access + """ + full_indices = np.indices((indices.shape[0], num_to_access))[1] + full_indices = (full_indices + np.expand_dims(indices, axis=1)) % ( + self.size() + self._max_seq_len + self._n_step - 1 + ) + elements = array[full_indices] + elements = elements.reshape(indices.shape[0], -1, *elements.shape[2:]) + return elements + + def _get_from_storage(self, key, indices, num_to_access=1): + """Gets values from storage. + Args: + key: The name of the component to retrieve. + indices: This can be a single int or a 1D numpyp array. The indices are + adjusted to fall within the current bounds of the buffer. + num_to_access: how many consecutive elements to access + """ + if not isinstance(indices, np.ndarray): + indices = np.array([indices]) + if num_to_access == 0: + return np.array([]) + elif num_to_access == 1: + return self._storage[key][ + indices % (self.size() + self._max_seq_len + self._n_step - 1) + ] + else: + return self._get_from_array( + self._storage[key], indices, num_to_access=num_to_access + ) + + def _sample_indices(self, batch_size): + """Samples valid indices that can be used by the replay.""" + indices = np.array([], dtype=np.int32) + while len(indices) < batch_size: + start_index = ( + self._rng.integers(self.size(), size=batch_size - len(indices)) + + self._cursor + ) + start_index = self._filter_transitions(start_index) + indices = np.concatenate([indices, start_index]) + return indices + self._max_seq_len - 1 + + def _filter_transitions(self, indices): + """Filters invalid indices.""" + if self._max_seq_len == 1: + return indices + done = self._get_from_storage("done", indices, self._max_seq_len - 1) + done = done.astype(bool) + if self._max_seq_len == 2: + indices = indices[~done] + else: + indices = indices[~done.any(axis=1)] + return indices + + def sample(self, batch_size): + """Sample transitions from the buffer. For a given transition, if it's + done is True, the next_observation value should not be taken to have any + meaning. + + Args: + batch_size (int): Number of transitions to sample. 
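+
+        Returns:
+            A dictionary of sequence batches: the "observation" and "action"
+            entries carry an extra sequence dimension of length max_seq_len.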
+ """ + if self._num_added < self._max_seq_len + self._n_step: + raise ValueError("Not enough transitions added to the buffer to sample") + indices = self._sample_indices(batch_size) + batch = {} + batch["indices"] = indices + terminals = self._get_from_storage( + "done", + indices - self._max_seq_len + 1, + num_to_access=self._max_seq_len + self._n_step - 1, + ) + + if self._n_step == 1: + is_terminal = terminals + trajectory_lengths = np.ones(batch_size) + else: + is_terminal = terminals.any(axis=1).astype(int) + trajectory_lengths = ( + np.argmax(terminals.astype(bool), axis=1) + 1 + ) * is_terminal + self._n_step * (1 - is_terminal) + is_terminal = terminals[:, 1 : self._n_step - 1] + trajectory_lengths = trajectory_lengths.astype(np.int64) + + for key in self._specs: + if key == "observation": + batch[key] = self._get_from_storage( + "observation", + indices - self._max_seq_len + 1, + num_to_access=self._max_seq_len, + ) + elif key == "action": + batch[key] = self._get_from_storage( + "action", + indices - self._max_seq_len + 1, + num_to_access=self._max_seq_len, + ) + elif key == "done": + batch["done"] = is_terminal + elif key == "reward": + rewards = self._get_from_storage( + "reward", + indices - self._max_seq_len + 1, + num_to_access=self._max_seq_len + self._n_step - 1, + ) + if self._max_seq_len + self._n_step - 1 == 1: + rewards = np.expand_dims(rewards, 1) + + if self._n_step == 1: + rewards = rewards * np.expand_dims(self._discount, axis=0) + + elif self._n_step > 1: + idx = np.arange(rewards.shape[1] - self._n_step + 1)[ + :, None + ] + np.arange( + self._n_step + ) # (S-N+1) x N + rewards = rewards[:, idx] # B x (S-N+1) x N + # Creating a vectorized sliding window to calculate + # discounted returns for every element in the sequence. + # Equivalent to + # np.sum(rewards * self._discount[None, None, :], axis=2) + disc_rewards = np.einsum("ijk,k->ij", rewards, self._discount) + rewards = disc_rewards + + batch["reward"] = rewards + else: + batch[key] = self._get_from_storage(key, indices) + + batch["trajectory_lengths"] = trajectory_lengths + batch["next_observation"] = self._get_from_storage( + "observation", + indices + trajectory_lengths - self._max_seq_len + 1, + num_to_access=self._max_seq_len, + ) + return batch