Feature/pixel a2c (#313)
* Add pixel obs to A2C

* A2C atari

* Add ppo_mujoco cfg

* Add a2c_mujoco cfg

* Update rollout_steps

* Normalize adv a2c

* Update ppo_mujoco

* Add ortho init

* Fix output dim when MLP is not shared

* Removed ortho_init

* Update checkpoint saving

* Update ppo_mujoco cfg
belerico authored Jul 12, 2024
1 parent 06225b2 commit b9e57ed
Showing 10 changed files with 446 additions and 268 deletions.
147 changes: 102 additions & 45 deletions sheeprl/algos/a2c/a2c.py
@@ -10,21 +10,22 @@
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler
from torchmetrics import SumMetric

from sheeprl.algos.a2c.agent import A2CAgent, build_agent
from sheeprl.algos.a2c.loss import policy_loss, value_loss
from sheeprl.algos.a2c.utils import prepare_obs, test
from sheeprl.algos.a2c.loss import policy_loss
from sheeprl.algos.ppo.agent import PPOAgent, build_agent
from sheeprl.algos.ppo.loss import entropy_loss, value_loss
from sheeprl.algos.ppo.utils import normalize_obs, prepare_obs, test
from sheeprl.data import ReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import gae, save_configs
from sheeprl.utils.utils import gae, normalize_tensor, save_configs
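The newly imported `normalize_tensor` is what implements the "Normalize adv a2c" item from the commit message: it standardizes the advantages inside `train()` before the policy update. A minimal sketch of what such a helper typically does; the actual implementation in `sheeprl.utils.utils` may differ, for example in its epsilon or reduction dimensions:

```python
import torch

def normalize_tensor_sketch(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Zero-mean, unit-variance standardization; eps guards against a zero std.
    return (x - x.mean()) / (x.std() + eps)
```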


def train(
fabric: Fabric,
agent: A2CAgent,
agent: PPOAgent,
optimizer: torch.optim.Optimizer,
data: Dict[str, torch.Tensor],
aggregator: MetricAggregator,
@@ -62,15 +63,21 @@ def train(
# This is achieved by accumulating the gradients and calling the backward method only at the end.
for i, batch_idxes in enumerate(sampler):
batch = {k: v[batch_idxes] for k, v in data.items()}
obs = {k: v for k, v in batch.items() if k in cfg.algo.mlp_keys.encoder}
normalized_obs = normalize_obs(
batch, cfg.algo.cnn_keys.encoder, cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder
)

# is_accumulating is True for every i except for the last one
is_accumulating = i < len(sampler) - 1

with fabric.no_backward_sync(agent.feature_extractor, enabled=is_accumulating), fabric.no_backward_sync(
agent.actor, enabled=is_accumulating
), fabric.no_backward_sync(agent.critic, enabled=is_accumulating):
_, logprobs, values = agent(obs, torch.split(batch["actions"], agent.actions_dim, dim=-1))
_, logprobs, entropy, new_values = agent(
normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1)
)
if cfg.algo.normalize_advantages:
batch["advantages"] = normalize_tensor(batch["advantages"])

# Policy loss
pg_loss = policy_loss(
@@ -81,12 +88,19 @@

# Value loss
v_loss = value_loss(
values,
new_values,
batch["values"],
batch["returns"],
0.0,
False,
cfg.algo.loss_reduction,
)

loss = pg_loss + v_loss
# Entropy loss
ent_loss = entropy_loss(entropy, cfg.algo.loss_reduction)
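The entropy term is new for A2C here and is combined below with the `vf_coef` and `ent_coef` weights into the total loss. A rough sketch of an entropy term of this kind; the actual `entropy_loss` in `sheeprl.algos.ppo.loss` may use a different sign convention or reduction:

```python
import torch

def entropy_loss_sketch(entropy: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    # Return the negative entropy so that minimizing the total loss
    # maximizes policy entropy and encourages exploration.
    neg_entropy = -entropy
    return neg_entropy.mean() if reduction == "mean" else neg_entropy.sum()
```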

# Total loss
loss = pg_loss + cfg.algo.vf_coef * v_loss + cfg.algo.ent_coef * ent_loss
fabric.backward(loss)

if not is_accumulating:
@@ -102,10 +116,18 @@

@register_algorithm(decoupled=False)
def main(fabric: Fabric, cfg: Dict[str, Any]):
if "minedojo" in cfg.env.wrapper._target_.lower():
raise ValueError(
"MineDojo is not currently supported by PPO agent, since it does not take "
"into consideration the action masks provided by the environment, but needed "
"in order to play correctly the game. "
"As an alternative you can use one of the Dreamers' agents."
)

# Initialize Fabric
rank = fabric.global_rank
world_size = fabric.world_size
device = fabric.device
fabric.seed_everything(cfg.seed)

# Resume from checkpoint
if cfg.checkpoint.resume_from:
@@ -139,19 +161,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

if not isinstance(observation_space, gym.spaces.Dict):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if len(cfg.algo.mlp_keys.encoder) == 0:
raise RuntimeError("You should specify at least one MLP key for the encoder: `algo.mlp_keys.encoder=[state]`")
for k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
if k in observation_space.keys() and len(observation_space[k].shape) > 1:
raise ValueError(
"Only environments with vector-only observations are supported by the A2C agent. "
f"The observation with key '{k}' has shape {observation_space[k].shape}. "
f"Provided environment: {cfg.env.id}"
)
if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
if cfg.metric.log_level > 0:
fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder)
fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder)
obs_keys = cfg.algo.mlp_keys.encoder
obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder

is_continuous = isinstance(envs.single_action_space, gym.spaces.Box)
is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete)
@@ -160,11 +178,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if is_continuous
else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
)

# Create the agent model: this should be a torch.nn.Module to be accelerated with Fabric
# Given that the environment has been created with the `make_env` method, the agent
# forward method must accept as input a dictionary like {"obs1_name": obs1, "obs2_name": obs2, ...}.
# The agent should be able to process both image and vector-like observations.
# Create the actor and critic models
agent, player = build_agent(
fabric,
actions_dim,
@@ -174,18 +188,30 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
state["agent"] if cfg.checkpoint.resume_from else None,
)

# the optimizer and set up it with Fabric
# Define the optimizer
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all")

if fabric.is_global_zero:
save_configs(cfg, log_dir)

# Load the state from the checkpoint
if cfg.checkpoint.resume_from:
optimizer.load_state_dict(state["optimizer"])

# Setup agent and optimizer with Fabric
optimizer = fabric.setup_optimizers(optimizer)
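The optimizer is built from the Hydra config node and only afterwards registered with Fabric. A small, self-contained example of how `hydra.utils.instantiate` resolves such a node; the `_target_` and hyperparameters here are illustrative, not the repo's actual defaults:

```python
import hydra
import torch
from omegaconf import OmegaConf

# Hypothetical optimizer node; sheeprl's a2c/ppo configs may differ.
optim_cfg = OmegaConf.create({"_target_": "torch.optim.Adam", "lr": 1e-3, "eps": 1e-5})
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = hydra.utils.instantiate(optim_cfg, params=params, _convert_="all")
print(type(optimizer).__name__)  # Adam
```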

# Create a metric aggregator to log the metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
if cfg.buffer.size < cfg.algo.rollout_steps:
raise ValueError(
f"The size of the buffer ({cfg.buffer.size}) cannot be lower "
f"than the rollout steps ({cfg.algo.rollout_steps})"
)
rb = ReplayBuffer(
cfg.buffer.size,
cfg.env.num_envs,
@@ -195,16 +221,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Global variables
last_log = 0
last_train = 0
train_step = 0
policy_step = 0
last_checkpoint = 0
start_iter = (
# + 1 because the checkpoint is at the end of the update step
# (when resuming from a checkpoint, the update at the checkpoint
# is ended and you have to start with the next one)
(state["iter_num"] // fabric.world_size) + 1
if cfg.checkpoint.resume_from
else 1
)
policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
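The bookkeeping above fixes how many environment steps one update consumes and how many updates the run performs in total. With illustrative numbers (not the repo defaults) the arithmetic works out as follows:

```python
# Illustrative values only: 4 envs per rank, 128 rollout steps, 2 ranks.
num_envs, rollout_steps, world_size = 4, 128, 2
total_steps = 1_000_000

policy_steps_per_iter = num_envs * rollout_steps * world_size  # 1024 env steps per iteration
total_iters = total_steps // policy_steps_per_iter             # 976 update iterations
print(policy_steps_per_iter, total_iters)
```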

# Warning for log and checkpoint every
if cfg.metric.log_every % policy_steps_per_iter != 0:
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
@@ -219,13 +255,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
"policy_steps_per_iter value."
)

# Linear learning rate scheduler
if cfg.algo.anneal_lr:
from torch.optim.lr_scheduler import PolynomialLR

scheduler = PolynomialLR(optimizer=optimizer, total_iters=total_iters, power=1.0)
if cfg.checkpoint.resume_from:
scheduler.load_state_dict(state["scheduler"])
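With `power=1.0`, `PolynomialLR` decays the learning rate linearly to zero over `total_iters` scheduler steps, and loading its state dict on resume restores the internal step counter. A quick standalone check of that behaviour:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1e-3)
sched = torch.optim.lr_scheduler.PolynomialLR(opt, total_iters=10, power=1.0)
for _ in range(5):
    opt.step()
    sched.step()
print(sched.get_last_lr())  # ~[0.0005]: half the iterations, half the base lr
```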

# Get the first environment observation and start the optimization
step_data = {}
next_obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs]
for k in obs_keys:
if k in cfg.algo.cnn_keys.encoder:
next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
step_data[k] = next_obs[k][np.newaxis]
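The reshape above is what makes pixel observations usable by the CNN encoder: any leading dimension between the env axis and the spatial axes (e.g. a frame stack) is folded into the channel axis, so the encoder always receives `[N_envs, C, H, W]`. A toy example with assumed shapes (2 envs, 4 stacked RGB frames of 64x64 pixels):

```python
import numpy as np

frames = np.zeros((2, 4, 3, 64, 64), dtype=np.uint8)  # [N_envs, stack, C, H, W]
# Same operation as `reshape(cfg.env.num_envs, -1, *shape[-2:])` in the loop above:
stacked = frames.reshape(2, -1, *frames.shape[-2:])
assert stacked.shape == (2, 12, 64, 64)               # stack and channels merged into one axis
```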

for iter_num in range(1, total_iters + 1):
for iter_num in range(start_iter, total_iters + 1):
with torch.inference_mode():
for _ in range(0, cfg.algo.rollout_steps):
policy_step += cfg.env.num_envs * world_size
@@ -234,16 +280,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = prepare_obs(
fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs
fabric, next_obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs
)
actions, _, values = player(torch_obs)
actions, logprobs, values = player(torch_obs)
if is_continuous:
real_actions = torch.stack(actions, -1).cpu().numpy()
else:
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy()
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()
actions = torch.cat(actions, -1).cpu().numpy()
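For discrete heads the sampled actions are assumed here to be one-hot tensors, so the `argmax` recovers the integer index the vectorized environment expects (the diff also switches from `axis=-1` to the canonical `dim=-1` keyword). A toy illustration under that one-hot assumption:

```python
import torch

# Two envs, one discrete head with three possible actions (assumed one-hot samples).
actions = [torch.tensor([[0.0, 1.0, 0.0],
                         [1.0, 0.0, 0.0]])]
real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).numpy()
print(real_actions)  # [[1] [0]] -> integer actions for the vectorized env
```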

# Single environment step
@@ -274,7 +318,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Update the step data
step_data["dones"] = dones[np.newaxis]
step_data["values"] = values.cpu().numpy()[np.newaxis]
step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1)
step_data["actions"] = actions[np.newaxis]
step_data["logprobs"] = logprobs.cpu().numpy()[np.newaxis]
step_data["rewards"] = rewards[np.newaxis]
if cfg.buffer.memmap:
step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
@@ -287,6 +332,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
next_obs = {}
for k in obs_keys:
_obs = obs[k]
if k in cfg.algo.cnn_keys.encoder:
_obs = _obs.reshape(cfg.env.num_envs, -1, *_obs.shape[-2:])
step_data[k] = _obs[np.newaxis]
next_obs[k] = _obs

@@ -306,7 +353,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.inference_mode():
torch_obs = prepare_obs(fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
@@ -317,14 +364,22 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
cfg.algo.gamma,
cfg.algo.gae_lambda,
)

# Add returns and advantages to the buffer
local_data["returns"] = returns.float()
local_data["advantages"] = advantages.float()

# Train the agent
if cfg.buffer.share_data and fabric.world_size > 1:
# Gather all the tensors from all the world and reshape them
gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(local_data)
# Flatten the first three dimensions: [World_Size, Buffer_Size, Num_Envs]
gathered_data = {k: v.flatten(start_dim=0, end_dim=2).float() for k, v in gathered_data.items()}
else:
# Flatten the first two dimensions: [Buffer_Size, Num_Envs]
gathered_data = {k: v.flatten(start_dim=0, end_dim=1).float() for k, v in local_data.items()}
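Training then consumes a single flat batch: the rollout and environment dimensions are collapsed, and when `buffer.share_data` is enabled, Fabric's `all_gather` first stacks the per-rank buffers along a leading world-size dimension, exactly as the comments above describe. Illustrative shapes only (the concrete numbers are assumptions, and real buffer entries carry an extra trailing feature dimension):

```python
import torch

local = {"advantages": torch.zeros(128, 4)}            # [rollout_steps, num_envs]
flat = {k: v.flatten(0, 1) for k, v in local.items()}
assert flat["advantages"].shape == (512,)              # one sample per (step, env) pair

shared = {"advantages": torch.zeros(2, 128, 4)}        # [world_size, rollout_steps, num_envs]
flat_shared = {k: v.flatten(0, 2) for k, v in shared.items()}
assert flat_shared["advantages"].shape == (1024,)
```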

with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
train(fabric, agent, optimizer, local_data, aggregator, cfg)
train(fabric, agent, optimizer, gathered_data, aggregator, cfg)
train_step += world_size

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run:
@@ -357,16 +412,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
last_train = train_step

# Checkpoint model
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or (iter_num == total_iters and cfg.checkpoint.save_last)
if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
iter_num == total_iters and cfg.checkpoint.save_last
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None,
"iter_num": iter_num * world_size,
"batch_size": cfg.algo.per_rank_batch_size * fabric.world_size,
"last_log": last_log,
"last_checkpoint": last_checkpoint,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)