diff --git a/hive/agents/__init__.py b/hive/agents/__init__.py index 5c5e1d06..cab8903b 100644 --- a/hive/agents/__init__.py +++ b/hive/agents/__init__.py @@ -4,6 +4,7 @@ from hive.agents.dqn import DQNAgent from hive.agents.drqn import DRQNAgent from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent +from hive.agents.ppo import PPOAgent from hive.agents.rainbow import RainbowDQNAgent from hive.agents.random import RandomAgent from hive.agents.td3 import TD3 @@ -16,6 +17,7 @@ "DQNAgent": DQNAgent, "DRQNAgent": DRQNAgent, "LegalMovesRainbowAgent": LegalMovesRainbowAgent, + "PPOAgent": PPOAgent, "RainbowDQNAgent": RainbowDQNAgent, "RandomAgent": RandomAgent, "TD3": TD3, diff --git a/hive/agents/ppo.py b/hive/agents/ppo.py new file mode 100644 index 00000000..17e144c4 --- /dev/null +++ b/hive/agents/ppo.py @@ -0,0 +1,449 @@ +import os +from typing import Union + +import gymnasium as gym +import numpy as np +import torch + +from hive.agents.agent import Agent +from hive.agents.qnets.base import FunctionApproximator +from hive.agents.qnets.normalizer import ( + MovingAvgNormalizer, + RewardNormalizer, +) +from hive.agents.qnets.ac_nets import ActorCriticNetwork +from hive.agents.qnets.utils import ( + InitializationFn, + calculate_output_dim, + create_init_weights_fn, +) +from hive.replays.on_policy_replay import OnPolicyReplayBuffer +from hive.utils.loggers import Logger, NullLogger +from hive.utils.schedule import PeriodicSchedule, Schedule, ConstantSchedule +from hive.utils.utils import LossFn, OptimizerFn, create_folder + + +class PPOAgent(Agent): + """An agent implementing the PPO algorithm.""" + + def __init__( + self, + observation_space: gym.spaces.Box, + action_space: Union[gym.spaces.Discrete, gym.spaces.Box], + representation_net: FunctionApproximator = None, + actor_net: FunctionApproximator = None, + critic_net: FunctionApproximator = None, + init_fn: InitializationFn = None, + optimizer_fn: OptimizerFn = None, + anneal_lr_schedule: Schedule = None, + critic_loss_fn: LossFn = None, + observation_normalizer: MovingAvgNormalizer = None, + reward_normalizer: RewardNormalizer = None, + stack_size: int = 1, + replay_buffer: OnPolicyReplayBuffer = None, + discount_rate: float = 0.99, + n_step: int = 1, + grad_clip: float = None, + batch_size: int = 64, + logger: Logger = None, + log_frequency: int = 1, + clip_coefficient: float = 0.2, + entropy_coefficient: float = 0.01, + clip_value_loss: bool = True, + value_fn_coefficient: float = 0.5, + transitions_per_update: int = 1024, + num_epochs_per_update: int = 4, + normalize_advantages: bool = True, + target_kl: float = None, + device="cpu", + id=0, + ): + """ + Args: + observation_space (gym.spaces.Box): Observation space for the agent. + action_space (Union[gym.spaces.Discrete, gym.spaces.Box]): Action space for + the agent. + representation_net (FunctionApproximator): The network that encodes the + observations that are then fed into the actor_net and critic_net. If + None, defaults to :py:class:`~torch.nn.Identity`. + actor_net (FunctionApproximator): The network that takes the encoded + observations from representation_net and outputs the representations + used to compute the actions (i.e., everything except the last layer). + critic_net (FunctionApproximator): The network that takes the encoded + observations from representation_net and outputs the representations + used to compute the state values (i.e., everything except the last + layer).
+ init_fn (InitializationFn): Initializes the weights of agent networks + using create_init_weights_fn. + optimizer_fn (OptimizerFn): A function that takes in the list of + parameters of the actor and critic returns the optimizer for the actor. + If None, defaults to :py:class:`~torch.optim.Adam`. + critic_loss_fn (LossFn): The loss function used to optimize the critic. If + None, defaults to :py:class:`~torch.nn.MSELoss`. + observation_normalizer (MovingAvgNormalizer): The function for + normalizing observations + reward_normalizer (RewardNormalizer): The function for normalizing + rewards + stack_size (int): Number of observations stacked to create the state fed + to the agent. + replay_buffer (OnPolicyReplayBuffer): The replay buffer that the agent will + push observations to and sample from during learning. If None, + defaults to + :py:class:`~hive.replays.circular_replay.OnPolicyReplayBuffer`. + discount_rate (float): A number between 0 and 1 specifying how much + future rewards are discounted by the agent. + n_step (int): The horizon used in n-step returns to compute TD(n) targets. + grad_clip (float): Gradients will be clipped to between + [-grad_clip, grad_clip]. + batch_size (int): The size of the batch sampled from the replay buffer + during learning. + logger (Logger): Logger used to log agent's metrics. + log_frequency (int): How often to log the agent's metrics. + clip_coefficient (float): A number between 0 and 1 specifying the clip ratio + for the surrogate objective function to penalise large changes in + the policy and/or critic. + entropy_coefficient (float): Coefficient for the entropy loss. + clip_value_loss (bool): Flag to use the clipped objective for the value + function. + value_fn_coefficient (float): Coefficient for the value function loss. + transitions_per_update (int): Total number of observations that are + stored before the update. + num_epochs_per_update (int): Number of iterations over the entire + buffer during an update step. + normalize_advantages (bool): Flag to normalise advantages before + calculating policy loss. + target_kl (float): Terminates the update if kl-divergence between old and + updated policy exceeds target_kl. + device: Device on which all computations should be run. + id: Agent identifier. 
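            Note:
                During updates, the actor is trained with the clipped surrogate
                objective: the probability ratio r = exp(new_logprob - old_logprob)
                is clipped to [1 - clip_coefficient, 1 + clip_coefficient], and the
                policy loss is the batch mean of
                max(-r * advantage, -clipped_r * advantage).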
+ """ + super().__init__(observation_space, action_space, id) + self._device = torch.device("cpu" if not torch.cuda.is_available() else device) + self._state_size = ( + stack_size * self._observation_space.shape[0], + *self._observation_space.shape[1:], + ) + self._init_fn = create_init_weights_fn(init_fn) + self.create_networks( + representation_net, + actor_net, + critic_net, + ) + if observation_normalizer is not None: + self._observation_normalizer = observation_normalizer(self._state_size) + else: + self._observation_normalizer = None + + if reward_normalizer is not None: + self._reward_normalizer = reward_normalizer(discount_rate) + else: + self._reward_normalizer = None + + if optimizer_fn is None: + optimizer_fn = torch.optim.Adam + self._optimizer = optimizer_fn(self._actor_critic.parameters()) + if anneal_lr_schedule is None: + anneal_lr_schedule = ConstantSchedule(1.0) + else: + anneal_lr_schedule = anneal_lr_schedule() + + self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR( + self._optimizer, lambda x: anneal_lr_schedule.update() + ) + if replay_buffer is None: + replay_buffer = OnPolicyReplayBuffer + self._replay_buffer = replay_buffer( + capacity=transitions_per_update, + observation_shape=self._observation_space.shape, + observation_dtype=self._observation_space.dtype, + action_shape=self._action_space.shape, + action_dtype=self._action_space.dtype, + gamma=discount_rate, + ) + self._discount_rate = discount_rate**n_step + self._grad_clip = grad_clip + if critic_loss_fn is None: + critic_loss_fn = torch.nn.MSELoss + self._critic_loss_fn = critic_loss_fn(reduction="none") + self._batch_size = batch_size + self._logger = logger + if self._logger is None: + self._logger = NullLogger([]) + self._timescale = self.id + self._logger.register_timescale( + self._timescale, PeriodicSchedule(False, True, log_frequency) + ) + self._clip_coefficient = clip_coefficient + self._entropy_coefficient = entropy_coefficient + self._clip_value_loss = clip_value_loss + self._value_fn_coefficient = value_fn_coefficient + self._transitions_per_update = transitions_per_update + self._num_epochs_per_update = num_epochs_per_update + self._normalize_advantages = normalize_advantages + self._target_kl = target_kl + + self._training = False + + def create_networks(self, representation_net, actor_net, critic_net): + """Creates the actor and critic networks. + + Args: + representation_net: A network that outputs the shared representations that + will be used by the actor and critic networks to process observations. + actor_net: The network that will be used to compute actions. + critic_net: The network that will be used to compute values of states. + """ + if representation_net is None: + network = torch.nn.Identity() + else: + network = representation_net(self._state_size) + + network_output_shape = calculate_output_dim(network, self._state_size) + self._actor_critic = ActorCriticNetwork( + network, + actor_net, + critic_net, + network_output_shape, + self._action_space, + isinstance(self._action_space, gym.spaces.Box), + ).to(self._device) + + self._actor_critic.apply(self._init_fn) + + def train(self): + """Changes the agent to training mode.""" + super().train() + self._actor_critic.train() + + def eval(self): + """Changes the agent to evaluation mode.""" + super().eval() + self._actor_critic.eval() + + def preprocess_update_info(self, update_info, agent_traj_state): + """Preprocesses the :obj:`update_info` before it goes into the replay buffer. 
+ + Args: + update_info: Contains the information from the current timestep that the + agent should use to update itself. + """ + if self._observation_normalizer: + update_info["observation"] = self._observation_normalizer( + update_info["observation"] + ) + + done = update_info["terminated"] or update_info["truncated"] + if self._reward_normalizer: + self._reward_normalizer.update(update_info["reward"], done) + update_info["reward"] = self._reward_normalizer(update_info["reward"]) + + preprocessed_update_info = { + "observation": update_info["observation"], + "action": update_info["action"], + "reward": update_info["reward"], + "done": done, + "logprob": agent_traj_state["logprob"], + "values": agent_traj_state["value"], + "returns": np.empty(agent_traj_state["value"].shape), + "advantages": np.empty(agent_traj_state["value"].shape), + } + if "agent_id" in update_info: + preprocessed_update_info["agent_id"] = int(update_info["agent_id"]) + + return preprocessed_update_info + + def preprocess_update_batch(self, batch): + """Returns preprocesed batch sampled from the replay buffer. + + Args: + batch: Batch sampled from the replay buffer for the current update. + """ + for key in batch: + batch[key] = torch.tensor(batch[key], device=self._device) + + return batch + + @torch.no_grad() + def get_action_logprob_value(self, observation): + """Returns the action, logprob, and value for the agent + + Args: + observation: The current observation. + """ + observation = torch.tensor( + np.expand_dims(observation, axis=0), device=self._device + ).float() + action, logprob, _, value = self._actor_critic(observation) + action = action.cpu().detach().numpy() + logprob = logprob.cpu().numpy() + value = value.cpu().numpy() + action = action[0] + + return action, logprob, value + + @torch.no_grad() + def act(self, observation, agent_traj_state=None): + """Returns the action for the agent. + + Args: + observation: The current observation. + """ + if agent_traj_state is None: + agent_traj_state = {} + if self._observation_normalizer: + self._observation_normalizer.update(observation) + observation = self._observation_normalizer(observation) + action, logprob, value = self.get_action_logprob_value(observation) + agent_traj_state["logprob"] = logprob + agent_traj_state["value"] = value + return action, agent_traj_state + + def update(self, update_info, agent_traj_state=None): + """ + Updates the PPO agent. + + Args: + update_info: dictionary containing all the necessary information to + update the agent. Should contain a full transition, with keys for + "observation", "next_observation", "action", "reward", "terminated", + and "truncated". + """ + if not self._training: + return + + # Add the most recent transition to the replay buffer. 
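        # Once the buffer holds transitions_per_update transitions, the critic's
        # value estimate for the latest next_observation is used to bootstrap the
        # advantage computation, and the actor-critic is then optimized for
        # num_epochs_per_update passes over shuffled minibatches below.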
+ self._replay_buffer.add( + **self.preprocess_update_info(update_info, agent_traj_state) + ) + + if self._replay_buffer.size() >= self._transitions_per_update - 1: + if self._observation_normalizer: + update_info["next_observation"] = self._observation_normalizer( + update_info["next_observation"] + ) + _, _, values = self.get_action_logprob_value( + update_info["next_observation"] + ) + self._replay_buffer.compute_advantages(values) + clip_fraction = 0 + num_updates = 0 + for _ in range(self._num_epochs_per_update): + for batch in self._replay_buffer.sample(batch_size=self._batch_size): + batch = self.preprocess_update_batch(batch) + self._optimizer.zero_grad() + + _, logprob, entropy, values = self._actor_critic( + batch["observation"], batch["action"] + ) + logratios = logprob - batch["logprob"] + ratios = torch.exp(logratios) + advantages = batch["advantages"] + if self._normalize_advantages: + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) + # Actor loss + loss_unclipped = -advantages * ratios + loss_clipped = -advantages * torch.clamp( + ratios, 1 - self._clip_coefficient, 1 + self._clip_coefficient + ) + actor_loss = torch.max(loss_unclipped, loss_clipped).mean() + entropy_loss = entropy.mean() + + # Critic loss + values = values.view(-1) + if self._clip_value_loss: + v_loss_unclipped = self._critic_loss_fn( + values, batch["returns"] + ) + v_clipped = batch["values"] + torch.clamp( + values - batch["values"], + -self._clip_coefficient, + self._clip_coefficient, + ) + v_loss_clipped = self._critic_loss_fn( + v_clipped, batch["returns"] + ) + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + critic_loss = 0.5 * v_loss_max.mean() + else: + critic_loss = ( + 0.5 * self._critic_loss_fn(values, batch["returns"]).mean() + ) + + loss = ( + actor_loss + - self._entropy_coefficient * entropy_loss + + self._value_fn_coefficient * critic_loss + ) + loss.backward() + + if self._grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + self._actor_critic.parameters(), self._grad_clip + ) + + self._optimizer.step() + num_updates += 1 + + with torch.no_grad(): + # calculate approx_kl + # http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratios).mean() + approx_kl = ((ratios - 1) - logratios).mean() + clip_fraction += ( + ((ratios - 1.0).abs() > self._clip_coefficient) + .float() + .mean() + .item() + ) + + if self._target_kl is not None and self._target_kl < approx_kl: + break + self._replay_buffer.reset() + if self._logger.update_step(self._timescale): + self._logger.log_metrics( + { + "loss": loss, + "actor_loss": actor_loss, + "critic_loss": critic_loss, + "entropy_loss": entropy_loss, + "approx_kl": approx_kl, + "old_approx_kl": old_approx_kl, + "clip_fraction": clip_fraction / num_updates, + "lr": self._lr_scheduler.get_last_lr()[0], + }, + prefix=self._timescale, + ) + self._lr_scheduler.step() + return agent_traj_state + + def save(self, dname): + state_dict = { + "actor_critic": self._actor_critic.state_dict(), + "optimizer": self._optimizer.state_dict(), + } + if self._observation_normalizer: + state_dict[ + "observation_normalizer" + ] = self._observation_normalizer.state_dict() + if self._reward_normalizer: + state_dict["reward_normalizer"] = self._reward_normalizer.state_dict() + torch.save( + state_dict, + os.path.join(dname, "agent.pt"), + ) + replay_dir = os.path.join(dname, "replay") + create_folder(replay_dir) + self._replay_buffer.save(replay_dir) + + def load(self, dname): + checkpoint = torch.load(os.path.join(dname, 
"agent.pt")) + self._actor_critic.load_state_dict(checkpoint["actor_critic"]) + self._optimizer.load_state_dict(checkpoint["optimizer"]) + self._replay_buffer.load(os.path.join(dname, "replay")) + if self._observation_normalizer: + self._observation_normalizer.load_state_dict( + checkpoint["observation_normalizer"] + ) + if self._reward_normalizer: + self._reward_normalizer.load_state_dict(checkpoint["reward_normalizer"]) diff --git a/hive/agents/qnets/ac_nets.py b/hive/agents/qnets/ac_nets.py new file mode 100644 index 00000000..17775ef0 --- /dev/null +++ b/hive/agents/qnets/ac_nets.py @@ -0,0 +1,119 @@ +from typing import Tuple, Union + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces import Box, Discrete + +from hive.agents.qnets.base import FunctionApproximator +from hive.agents.qnets.utils import calculate_output_dim + + +class CategoricalHead(torch.nn.Module): + """A module that implements a discrete actor head. It uses the ouput from + the :obj:`actor_net`, and adds creates a + :py:class:`~torch.distributions.categorical.Categorical` object to compute + the action distribution.""" + + def __init__( + self, feature_dim: Tuple[int], action_space: gym.spaces.Discrete + ) -> None: + """ + Args: + feature dim: Expected output shape of the actor network. + action_shape: Expected shape of actions. + """ + super().__init__() + self.network = torch.nn.Linear(feature_dim, action_space.n) + self.distribution = torch.distributions.categorical.Categorical + + def forward(self, x): + logits = self.network(x) + return self.distribution(logits=logits) + + +class GaussianPolicyHead(torch.nn.Module): + """A module that implements a continuous actor head. It uses the output from the + :obj:`actor_net` and state independent learnable parameter :obj:`policy_logstd` to + create a :py:class:`~torch.distributions.normal.Normal` object to compute + the action distribution.""" + + def __init__(self, feature_dim: Tuple[int], action_space: gym.spaces.Box) -> None: + """ + Args: + feature dim: Expected output shape of the actor network. + action_shape: Expected shape of actions. + """ + super().__init__() + self._action_shape = action_space.shape + self.policy_mean = torch.nn.Sequential( + torch.nn.Linear(feature_dim, np.prod(self._action_shape)) + ) + self.policy_logstd = torch.nn.Parameter( + torch.zeros(1, np.prod(action_space.shape)) + ) + self.distribution = torch.distributions.normal.Normal + + def forward(self, x): + _mean = self.policy_mean(x) + _std = self.policy_logstd.repeat(x.shape[0], 1).exp() + distribution = self.distribution( + torch.reshape(_mean, (x.size(0), *self._action_shape)), + torch.reshape(_std, (x.size(0), *self._action_shape)), + ) + return distribution + + +class ActorCriticNetwork(torch.nn.Module): + """A module that implements the actor and critic computation. 
It puts together + the :obj:`representation_network`, :obj:`actor_net` and :obj:`critic_net`, then + adds two final :py:class:`~torch.nn.Linear` layers to compute the action and state + value.""" + + def __init__( + self, + representation_network: torch.nn.Module, + actor_net: FunctionApproximator, + critic_net: FunctionApproximator, + network_output_dim: Union[int, Tuple[int]], + action_space: Union[Box, Discrete], + continuous_action: bool, + ) -> None: + super().__init__() + self._network = representation_network + self._continuous_action = continuous_action + if actor_net is None: + actor_network = torch.nn.Identity() + else: + actor_network = actor_net(network_output_dim) + feature_dim = np.prod(calculate_output_dim(actor_network, network_output_dim)) + actor_head = GaussianPolicyHead if self._continuous_action else CategoricalHead + + self.actor = torch.nn.Sequential( + actor_network, + torch.nn.Flatten(), + actor_head(feature_dim, action_space), + ) + + if critic_net is None: + critic_network = torch.nn.Identity() + else: + critic_network = critic_net(network_output_dim) + feature_dim = np.prod(calculate_output_dim(critic_network, network_output_dim)) + self.critic = torch.nn.Sequential( + critic_network, + torch.nn.Flatten(), + torch.nn.Linear(feature_dim, 1), + ) + + def forward(self, x, action=None): + hidden_state = self._network(x) + distribution = self.actor(hidden_state) + value = self.critic(hidden_state) + if action is None: + action = distribution.sample() + + logprob, entropy = distribution.log_prob(action), distribution.entropy() + if self._continuous_action: + logprob, entropy = logprob.sum(dim=-1), entropy.sum(dim=-1) + return action, logprob, entropy, value diff --git a/hive/agents/qnets/normalizer.py b/hive/agents/qnets/normalizer.py new file mode 100644 index 00000000..7297085d --- /dev/null +++ b/hive/agents/qnets/normalizer.py @@ -0,0 +1,177 @@ +import abc +from typing import Tuple + +import numpy as np + +from hive.utils.registry import Registrable, registry + + +# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py +class MeanStd: + """Tracks the mean, variance and count of values.""" + + # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + def __init__(self, epsilon=1e-4, shape=()): + """Tracks the mean, variance and count of values.""" + self.mean = np.zeros(shape, "float64") + self.var = np.ones(shape, "float64") + self.count = epsilon + + def update(self, x): + """Updates the mean, var and count from a batch of samples.""" + batch_mean = np.mean(x, axis=0) + batch_var = np.var(x, axis=0) + batch_count = x.shape[0] + self.update_from_moments(batch_mean, batch_var, batch_count) + + def update_from_moments(self, batch_mean, batch_var, batch_count): + """Updates from batch mean, variance and count moments.""" + self.mean, self.var, self.count = self.update_mean_var_count_from_moments( + self.mean, self.var, self.count, batch_mean, batch_var, batch_count + ) + + def update_mean_var_count_from_moments( + self, mean, var, count, batch_mean, batch_var, batch_count + ): + """Updates the mean, var and count using the previous mean, var, count + and batch values.""" + delta = batch_mean - mean + tot_count = count + batch_count + + new_mean = mean + delta * batch_count / tot_count + m_a = var * count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + return new_mean, new_var, 
new_count + + def state_dict(self): + """Returns the state as a dictionary.""" + return {"mean": self.mean, "var": self.var, "count": self.count} + + def load_state_dict(self, state_dict): + """Loads the state from a dictionary.""" + self.mean = state_dict["mean"] + self.var = state_dict["var"] + self.count = state_dict["count"] + + +class Normalizer(Registrable): + """A wrapper for callables that produce normalization functions. + + These wrapped callables can be partially initialized through configuration + files or command line arguments. + """ + + @classmethod + def type_name(cls): + """ + Returns: + "norm_fn" + """ + return "norm_fn" + + @abc.abstractmethod + def state_dict(self): + """Returns the state of the normalizer as a dictionary.""" + + @abc.abstractmethod + def load_state_dict(self, state_dict): + """Loads the normalizer state from a dictionary.""" + + +class MovingAvgNormalizer(Normalizer): + """Implements a moving average normalization and clipping function. Normalizes + input data with the running mean and std. The normalized data is then clipped + within the specified range. + """ + + def __init__( + self, shape: Tuple[int, ...], epsilon: float = 1e-4, clip: np.float32 = np.inf + ): + """ + Args: + epsilon (float): minimum value of variance to avoid division by 0. + shape (tuple[int]): The shape of input data. + clip (np.float32): The clip value for the normalised data. + """ + super().__init__() + self._rms = MeanStd(epsilon, shape) + self._shape = shape + self._epsilon = epsilon + self._clip = clip + + def __call__(self, input_data): + input_data = np.array([input_data]) + input_data = ( + (input_data - self._rms.mean) / np.sqrt(self._rms.var + self._epsilon) + )[0] + if self._clip is not None: + input_data = np.clip(input_data, -self._clip, self._clip) + return input_data + + def update(self, input_data): + self._rms.update(input_data) + + def state_dict(self): + return self._rms.state_dict() + + def load_state_dict(self, state_dict): + self._rms.load_state_dict(state_dict) + + +class RewardNormalizer(Normalizer): + """Normalizes and clips rewards from the environment. Applies a discount-based + scaling scheme, where the rewards are divided by the standard deviation of a + rolling discounted sum of the rewards. The scaled rewards are then clipped within + specified range. + """ + + def __init__(self, gamma: float, epsilon: float = 1e-4, clip: np.float32 = np.inf): + """ + Args: + gamma (float): discount factor for the agent. + epsilon (float): minimum value of variance to avoid division by 0. + clip (np.float32): The clip value for the normalised data. 
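            Note:
                update() maintains a running discounted return
                R_t = gamma * R_{t-1} + reward_t and tracks its variance;
                __call__ then scales each reward by 1 / sqrt(var + epsilon)
                before clipping it to [-clip, clip].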
+ """ + super().__init__() + self._return_rms = MeanStd(epsilon, ()) + self._epsilon = epsilon + self._clip = clip + self._gamma = gamma + self._returns = np.zeros(1) + + def __call__(self, rew): + rew = np.array([rew]) + rew = (rew / np.sqrt(self._return_rms.var + self._epsilon))[0] + if self._clip is not None: + rew = np.clip(rew, -self._clip, self._clip) + return rew + + def update(self, rew, done): + self._returns = self._returns * self._gamma + rew + self._return_rms.update(self._returns) + self._returns *= 1 - done + + def state_dict(self): + state_dict = self._return_rms.state_dict() + state_dict["returns"] = self._returns + return state_dict + + def load_state_dict(self, state_dict): + self._returns = state_dict["returns"] + state_dict.pop("returns") + self._return_rms.load_state_dict(state_dict) + + +registry.register_all( + Normalizer, + { + "RewardNormalizer": RewardNormalizer, + "MovingAvgNormalizer": MovingAvgNormalizer, + }, +) + +get_norm_fn = getattr(registry, f"get_{Normalizer.type_name()}") diff --git a/hive/configs/atari/ppo.yml b/hive/configs/atari/ppo.yml new file mode 100644 index 00000000..328644f3 --- /dev/null +++ b/hive/configs/atari/ppo.yml @@ -0,0 +1,73 @@ +name: 'SingleAgentRunner' +kwargs: + experiment_manager: + name: 'Experiment' + kwargs: + name: &run_name 'atari-ppo' + save_dir: 'experiment' + saving_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 1000000 + + train_steps: 10000000 + test_frequency: 250000 + test_episodes: 10 + max_steps_per_episode: 27000 + stack_size: &stack_size 4 + environment: + name: 'AtariEnv' + kwargs: + env_name: 'Breakout' + + agent: + name: 'PPOAgent' + kwargs: + representation_net: + name: 'ConvNetwork' + kwargs: + channels: [32, 64, 64] + kernel_sizes: [8, 4, 3] + strides: [4, 2, 1] + paddings: [2, 2, 1] + mlp_layers: [512] + optimizer_fn: + name: 'Adam' + kwargs: + lr: .00025 + init_fn: + name: 'orthogonal' + replay_buffer: + name: 'OnPolicyReplayBuffer' + kwargs: + stack_size: *stack_size + compute_advantage_fn: + name: "gae_advantages" + kwargs: + gae_lambda: 0.95 + + discount_rate: .99 + grad_clip: .5 + clip_coefficient: .1 + entropy_coefficient: .0 + clip_value_loss: True + value_fn_coefficient: .5 + transitions_per_update: 4096 + num_epochs_per_update: 4 + normalize_advantages: True + batch_size: 256 + device: 'cuda' + id: 'agent' + # List of logger configs used. 
+ loggers: + - + name: ChompLogger + - + name: WandbLogger + kwargs: + project: Hive + name: *run_name + resume: "allow" + start_method: "fork" diff --git a/hive/configs/gym/ppo.yml b/hive/configs/gym/ppo.yml new file mode 100644 index 00000000..b49af1e6 --- /dev/null +++ b/hive/configs/gym/ppo.yml @@ -0,0 +1,72 @@ +name: 'SingleAgentRunner' +kwargs: + experiment_manager: + name: 'Experiment' + kwargs: + name: &run_name 'gym-ppo' + save_dir: 'experiment' + saving_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 10000 + + train_steps: 500000 + test_frequency: 1000 + test_episodes: 10 + max_steps_per_episode: 500 + stack_size: &stack_size 1 + environment: + name: 'GymEnv' + kwargs: + env_name: 'CartPole-v1' + + agent: + name: 'PPOAgent' + kwargs: + actor_net: + name: 'MLPNetwork' + kwargs: + hidden_units: [256, 256] + critic_net: + name: 'MLPNetwork' + kwargs: + hidden_units: [256, 256] + replay_buffer: + name: 'OnPolicyReplayBuffer' + kwargs: + compute_advantage_fn: + name: "gae_advantages" + kwargs: + gae_lambda: 0.95 + + optimizer_fn: + name: 'Adam' + kwargs: + lr: .00025 + discount_rate: .99 + grad_clip: .5 + clip_coefficient: .2 + entropy_coefficient: .01 + clip_value_loss: True + value_fn_coefficient: .5 + transitions_per_update: 2048 + num_epochs_per_update: 4 + normalize_advantages: True + batch_size: 128 + device: 'cuda' + id: 'agent' + init_fn: + name: 'orthogonal' + # List of logger configs used. + loggers: + - + name: ChompLogger + - + name: WandbLogger + kwargs: + project: Hive + name: *run_name + resume: "allow" + start_method: "fork" diff --git a/hive/configs/mujoco/ppo.yml b/hive/configs/mujoco/ppo.yml new file mode 100644 index 00000000..475e6ebc --- /dev/null +++ b/hive/configs/mujoco/ppo.yml @@ -0,0 +1,93 @@ +name: 'SingleAgentRunner' +kwargs: + experiment_manager: + name: 'Experiment' + kwargs: + name: &run_name 'mujoco-ppo' + save_dir: 'experiment' + saving_schedule: + name: 'PeriodicSchedule' + kwargs: + off_value: False + on_value: True + period: 500000 + + train_steps: 1000000 + test_frequency: 10000 + test_episodes: 5 + max_steps_per_episode: 100000 + stack_size: &stack_size 1 + environment: + name: 'GymEnv' + kwargs: + env_name: 'HalfCheetah-v4' + env_wrappers: + - + name: 'ClipAction' + agent: + name: 'PPOAgent' + kwargs: + actor_net: + name: 'MLPNetwork' + kwargs: + hidden_units: [64, 64] + activation_fn: + name: 'Tanh' + critic_net: + name: 'MLPNetwork' + kwargs: + hidden_units: [64, 64] + observation_normalizer: + name: 'MovingAvgNormalizer' + kwargs: + clip: 10 + reward_normalizer: + name: 'RewardNormalizer' + kwargs: + clip: 10 + replay_buffer: + name: 'OnPolicyReplayBuffer' + kwargs: + compute_advantage_fn: + name: "gae_advantages" + kwargs: + gae_lambda: 0.95 + optimizer_fn: + name: 'Adam' + kwargs: + lr: .0003 + eps: 1.e-5 + anneal_lr_schedule: + name: LinearSchedule + kwargs: + init_value: 1.0 + end_value: 0.0 + steps: 488 # 1,000,000 // 2048 + discount_rate: .99 + grad_clip: .5 + clip_coefficient: .2 + entropy_coefficient: .0 + clip_value_loss: True + value_fn_coefficient: .5 + transitions_per_update: 2048 + num_epochs_per_update: 10 + normalize_advantages: True + batch_size: 64 + device: 'cuda' + id: 'agent' + init_fn: + name: 'orthogonal' + kwargs: + gain: 1.414 + + # List of logger configs used. 
+ loggers: + - + name: ChompLogger + - + name: WandbLogger + kwargs: + project: Hive + name: *run_name + resume: "allow" + start_method: "fork" diff --git a/hive/envs/env_wrapper.py b/hive/envs/env_wrapper.py new file mode 100644 index 00000000..3b3e5a80 --- /dev/null +++ b/hive/envs/env_wrapper.py @@ -0,0 +1,23 @@ +from hive.utils.registry import Registrable + + +class EnvWrapper(Registrable): + """A wrapper for callables that produce environment wrappers. + + These wrapped callables can be partially initialized through configuration + files or command line arguments. + """ + + @classmethod + def type_name(cls): + """ + Returns: + "env_wrapper" + """ + return "env_wrapper" + + +def apply_wrappers(env, env_wrappers): + for wrapper in env_wrappers: + env = wrapper(env) + return env diff --git a/hive/envs/gym_env.py b/hive/envs/gym_env.py index be4ad7f0..b704a428 100644 --- a/hive/envs/gym_env.py +++ b/hive/envs/gym_env.py @@ -1,6 +1,13 @@ +import inspect +from typing import List + import gymnasium as gym +from gymnasium import wrappers + from hive.envs.base import BaseEnv from hive.envs.env_spec import EnvSpec +from hive.envs.env_wrapper import EnvWrapper, apply_wrappers +from hive.utils.registry import registry class GymEnv(BaseEnv): @@ -8,11 +15,19 @@ class GymEnv(BaseEnv): Class for loading gym environments. """ - def __init__(self, env_name, num_players=1, render_mode=None, **kwargs): + def __init__( + self, + env_name: str, + env_wrappers: List[EnvWrapper] = None, + num_players: int = 1, + render_mode: str = None, + **kwargs + ): """ Args: env_name (str): Name of the environment (NOTE: make sure it is available in gym.envs.registry.all()) + env_wrappers (List[EnvWrapper]): List of environment wrappers to apply. num_players (int): Number of players for the environment. render_mode (str): One of None, "human", "rgb_array", "ansi", or "rgb_array_list". See gym documentation for details. @@ -20,11 +35,11 @@ def __init__(self, env_name, num_players=1, render_mode=None, **kwargs): :py:meth:`create_env_spec` can be passed as keyword arguments to this constructor. """ - self.create_env(env_name, render_mode=render_mode, **kwargs) + self.create_env(env_name, env_wrappers, render_mode=render_mode, **kwargs) super().__init__(self.create_env_spec(env_name, **kwargs), num_players) self._seed = None - def create_env(self, env_name, **kwargs): + def create_env(self, env_name, env_wrappers, **kwargs): """Function used to create the environment. Subclasses can override this method if they are using a gym style environment that needs special logic. @@ -33,6 +48,9 @@ def create_env(self, env_name, **kwargs): """ self._env = gym.make(env_name, **kwargs) + if env_wrappers is not None: + self._env = apply_wrappers(self._env, env_wrappers) + def create_env_spec(self, env_name, **kwargs): """Function used to create the specification. Subclasses can override this method if they are using a gym style environment that needs special logic. 
@@ -73,3 +91,16 @@ def seed(self, seed=None): def close(self): self._env.close() + + +wrappers = [ + getattr(wrappers, x) + for x in dir(wrappers) + if inspect.isclass(getattr(wrappers, x)) + and issubclass(getattr(wrappers, x), gym.Wrapper) +] + +registry.register_all( + EnvWrapper, + {wrapper.__name__: wrapper for wrapper in wrappers}, +) diff --git a/hive/replays/__init__.py b/hive/replays/__init__.py index f4fe8211..5a87ff83 100644 --- a/hive/replays/__init__.py +++ b/hive/replays/__init__.py @@ -1,5 +1,6 @@ from hive.replays.circular_replay import CircularReplayBuffer, SimpleReplayBuffer from hive.replays.legal_moves_replay import LegalMovesBuffer +from hive.replays.on_policy_replay import OnPolicyReplayBuffer from hive.replays.prioritized_replay import PrioritizedReplayBuffer from hive.replays.recurrent_replay import RecurrentReplayBuffer from hive.replays.replay_buffer import BaseReplayBuffer @@ -10,6 +11,7 @@ { "CircularReplayBuffer": CircularReplayBuffer, "LegalMovesBuffer": LegalMovesBuffer, + "OnPolicyReplayBuffer": OnPolicyReplayBuffer, "PrioritizedReplayBuffer": PrioritizedReplayBuffer, "RecurrentReplayBuffer": RecurrentReplayBuffer, "SimpleReplayBuffer": SimpleReplayBuffer, diff --git a/hive/replays/on_policy_replay.py b/hive/replays/on_policy_replay.py new file mode 100644 index 00000000..5baaebe1 --- /dev/null +++ b/hive/replays/on_policy_replay.py @@ -0,0 +1,127 @@ +import numpy as np + +from hive.replays.circular_replay import CircularReplayBuffer +from hive.utils.advantage import AdvantageComputationFn + + +class OnPolicyReplayBuffer(CircularReplayBuffer): + """An extension of the CircularReplayBuffer for on-policy learning algorithms""" + + def __init__( + self, + capacity: int = 10000, + stack_size: int = 1, + n_step: int = 1, + gamma: float = 0.99, + compute_advantage_fn: AdvantageComputationFn = None, + observation_shape=(), + observation_dtype=np.uint8, + action_shape=(), + action_dtype=np.int8, + reward_shape=(), + reward_dtype=np.float32, + extra_storage_types=None, + ): + """Constructor for OnPolicyReplayBuffer. + + Args: + capacity (int): Total number of observations that can be stored in the + buffer + stack_size (int): The number of frames to stack to create an observation. + n_step (int): Horizon used to compute n-step return reward + gamma (float): Discounting factor used to compute n-step return reward + compute_advantage_fn (AdvantageComputationFn): Function used to compute the + advantages. + observation_shape: Shape of observations that will be stored in the buffer. + observation_dtype: Type of observations that will be stored in the buffer. + This can either be the type itself or string representation of the + type. The type can be either a native python type or a numpy type. If + a numpy type, a string of the form np.uint8 or numpy.uint8 is + acceptable. + action_shape: Shape of actions that will be stored in the buffer. + action_dtype: Type of actions that will be stored in the buffer. Format is + described in the description of observation_dtype. + action_shape: Shape of actions that will be stored in the buffer. + action_dtype: Type of actions that will be stored in the buffer. Format is + described in the description of observation_dtype. + reward_shape: Shape of rewards that will be stored in the buffer. + reward_dtype: Type of rewards that will be stored in the buffer. Format is + described in the description of observation_dtype. + extra_storage_types (dict): A dictionary describing extra items to store + in the buffer. 
The mapping should be from the name of the item to a + (type, shape) tuple. + """ + if extra_storage_types is None: + extra_storage_types = dict() + extra_storage_types.update( + { + "values": (np.float32, ()), + "returns": (np.float32, ()), + "advantages": (np.float32, ()), + "logprob": (np.float32, ()), + } + ) + super().__init__( + capacity + stack_size - 1, + stack_size, + n_step, + gamma, + observation_shape, + observation_dtype, + action_shape, + action_dtype, + reward_shape, + reward_dtype, + extra_storage_types, + ) + self._compute_advantage_fn = compute_advantage_fn + + # Taken from https://github.com/vwxyzjn/ppo-implementation-details/blob/main/ppo_shared.py + def compute_advantages(self, last_values): + """Compute advantages using rewards and value estimates.""" + ( + self._storage["advantages"], + self._storage["returns"], + ) = self._compute_advantage_fn( + self._storage["values"], + last_values, + self._storage["done"], + self._storage["reward"], + self._gamma, + ) + + def reset(self): + """Resets the storage.""" + if self._stack_size > 1: + saved_transitions = { + k: self._storage[k][-(self._stack_size - 1) :] + for k in self._storage.keys() + } + self._create_storage(self._capacity, self._specs) + for k in self._storage.keys(): + self._storage[k][: (self._stack_size - 1)] = saved_transitions[k] + else: + self._create_storage(self._capacity, self._specs) + self._cursor = self._stack_size - 1 + self._num_added = self._stack_size - 1 + + def _find_valid_indices(self): + """Filters invalid indices.""" + self._sample_cursor = 0 + self._valid_indices = self._filter_transitions(np.arange(self._capacity)) + self._valid_indices = self._rng.permutation(self._valid_indices) + return len(self._valid_indices) + + def _sample_indices(self, batch_size): + """Samples valid indices that can be used by the replay.""" + start = self._sample_cursor + end = min(len(self._valid_indices), (self._sample_cursor + batch_size)) + indices = self._valid_indices[start:end] + self._sample_cursor += batch_size + return indices + self._stack_size - 1 + + def sample(self, batch_size): + valid_ind_size = self._find_valid_indices() + num_batches = int(np.ceil(valid_ind_size / batch_size)) + for _ in range(num_batches): + yield super().sample(batch_size) diff --git a/hive/runners/multi_agent_loop.py b/hive/runners/multi_agent_loop.py index acdec227..26f48e42 100644 --- a/hive/runners/multi_agent_loop.py +++ b/hive/runners/multi_agent_loop.py @@ -158,6 +158,7 @@ def run_one_step( agent, { "observation": observation, + "next_observation": next_observation, "action": action, "info": other_info, }, diff --git a/hive/runners/single_agent_loop.py b/hive/runners/single_agent_loop.py index 95643c7f..667b4416 100644 --- a/hive/runners/single_agent_loop.py +++ b/hive/runners/single_agent_loop.py @@ -115,6 +115,7 @@ def run_one_step( info = { "observation": observation, + "next_observation": next_observation, "reward": reward, "action": action, "terminated": terminated, @@ -167,6 +168,7 @@ def run_end_step( info = { "observation": observation, + "next_observation": next_observation, "reward": reward, "action": action, "terminated": terminated, diff --git a/hive/utils/advantage.py b/hive/utils/advantage.py new file mode 100644 index 00000000..0b0b3902 --- /dev/null +++ b/hive/utils/advantage.py @@ -0,0 +1,84 @@ +import numpy as np + +from hive.utils.registry import registry +from hive.utils.utils import Registrable + + +class AdvantageComputationFn(Registrable): + """A wrapper for callables that produce Advantage 
Computation Functions.""" + + @classmethod + def type_name(cls): + """ + Returns: + "advantage_computation_fn" + """ + return "advantage_computation_fn" + + +def compute_gae_advantages( + values: np.ndarray, + last_values: np.ndarray, + dones: np.ndarray, + rewards: np.ndarray, + gamma: float, + gae_lambda: float, +): + """Helper function that computes advantages and returns using Generalized Advantage Estimation. + + Args: + values (np.ndarray): Value estimates for each step. + last_values (np.ndarray): Value estimate for the last step. + dones (np.ndarray): Done flags for each step. + rewards (np.ndarray): Rewards for each step. + gamma (float): Discount factor. + gae_lambda (float): GAE lambda parameter. + """ + last_gae_lambda = 0 + num_steps = len(values) + advantages = np.zeros_like(rewards) + for t in reversed(range(num_steps)): + next_values = last_values if t == num_steps - 1 else values[t + 1] + next_non_terminal = 1.0 - dones[t] + delta = rewards[t] + gamma * next_values * next_non_terminal - values[t] + advantages[t] = last_gae_lambda = ( + delta + gamma * gae_lambda * next_non_terminal * last_gae_lambda + ) + returns = advantages + values + return advantages, returns + + +def compute_standard_advantages( + values: np.ndarray, + last_values: np.ndarray, + dones: np.ndarray, + rewards: np.ndarray, + gamma: float, +): + """Helper function that computes advantages and returns using standard advantage estimation. + + Args: + values (np.ndarray): Value estimates for each step. + last_values (np.ndarray): Value estimate for the last step. + dones (np.ndarray): Done flags for each step. + rewards (np.ndarray): Rewards for each step. + gamma (float): Discount factor. + """ + num_steps = len(values) + advantages = np.zeros_like(rewards) + returns = np.zeros_like(rewards) + for t in reversed(range(num_steps)): + next_return = last_values if t == num_steps - 1 else returns[t + 1] + next_non_terminal = 1.0 - dones[t] + returns[t] = rewards[t] + gamma * next_non_terminal * next_return + advantages = returns - values + return advantages, returns + + +registry.register_all( + AdvantageComputationFn, + { + "gae_advantages": compute_gae_advantages, + "standard_advantages": compute_standard_advantages, + }, +)
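For reference, a minimal sketch of how the GAE helper registered above can be exercised on its own. The rollout length, value estimates, rewards, and discount settings below are made up for illustration; only the function and module come from this patch.

import numpy as np

from hive.utils.advantage import compute_gae_advantages

# Toy 4-step rollout for a single environment; the episode terminates at t=3.
values = np.array([0.5, 0.6, 0.7, 0.8], dtype=np.float32)  # V(s_t) estimates
rewards = np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32)
dones = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32)
last_value = np.float32(0.0)  # bootstrap value for the state after the rollout

advantages, returns = compute_gae_advantages(
    values, last_value, dones, rewards, gamma=0.99, gae_lambda=0.95
)
# The recursion matches the loop above:
#   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
#   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
# and returns = advantages + values serve as the critic's regression targets.
print(advantages, returns)

With done_t = 1 at the last step, the bootstrap value is ignored for that transition, so episodes that end inside a rollout do not leak value estimates across episode boundaries.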