Merge pull request #42 from Eclectic-Sheep/fix/sac_ae
Fix/sac ae
belerico authored Jun 26, 2023
2 parents c210971 + 92e0f82 commit 304fe5b
Showing 4 changed files with 164 additions and 4 deletions.
6 changes: 4 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
```diff
@@ -460,15 +460,17 @@ def main():
     buffer_size = (
         args.buffer_size // int(args.num_envs * fabric.world_size * args.action_repeat) if not args.dry_run else 2
     )
-    rb = SequentialReplayBuffer(buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer)
+    rb = SequentialReplayBuffer(
+        buffer_size, args.num_envs, device=fabric.device if args.memmap_buffer else "cpu", memmap=args.memmap_buffer
+    )
     if args.checkpoint_path and args.checkpoint_buffer:
         if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]):
             rb = state["rb"][fabric.global_rank]
         elif isinstance(state["rb"], SequentialReplayBuffer):
             rb = state["rb"]
         else:
             raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated")
-    step_data = TensorDict({}, batch_size=[args.num_envs], device="cpu")
+    step_data = TensorDict({}, batch_size=[args.num_envs], device=fabric.device if args.memmap_buffer else "cpu")
     expl_decay_steps = state["expl_decay_steps"] if args.checkpoint_path else 0
 
     # Global variables
```
155 changes: 155 additions & 0 deletions sheeprl/algos/sac_pixel/README.md
# SAC-AutoEncoder (SAC-AE)
Images are everywhere, so effective RL approaches that can use pixels as input could enable solutions for a wide range of real-world applications, for example robotics and video games. In [SAC-AutoEncoder](https://arxiv.org/abs/1910.01741) the standard [SAC](https://arxiv.org/abs/1801.01290) agent is enriched with a convolutional encoder, which maps images into features shared between the actor and the critic. Moreover, to improve the quality of the extracted features, a convolutional decoder is used to reconstruct the input images from those features, effectively creating an encoder-decoder architecture.

The architecture is depicted in the following figure:

![](https://eclecticsheep.ai/assets/images/sac_ae.png)

Since learning directly from images can be cumbersome, as the authors found out, a few tricks must be taken into account:

1. Deterministic autoencoder: the encoder-decoder architecture is a standard deterministic one, meaning that we do not learn a distribution over the extracted features conditioned on the input images.
2. The encoder receives gradients from the critic but not from the actor: letting the actor's gradients flow into the encoder would also change the Q-function during the actor update, since the encoder is shared between the actor and the critic.
3. To overcome the slowdown in the encoder update caused by 2., the convolutional weights of the target Q-function are updated faster than the rest of the network's parameters, effectively using $\tau_{\text{enc}} > \tau_{\text{Q}}$ (see the sketch right after this list).
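
A minimal sketch of this soft (EMA) update; the `soft_update` helper and the $\tau$ values below are illustrative, not the exact code used in the repository:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    for target_param, online_param in zip(target.parameters(), online.parameters()):
        target_param.data.mul_(1.0 - tau).add_(tau * online_param.data)


# Illustrative usage: the shared convolutional encoder tracks the online network
# faster than the Q-function heads (tau_enc > tau_Q)
# soft_update(target_critic.encoder, critic.encoder, tau=0.05)  # tau_enc
# soft_update(target_critic.qfs, critic.qfs, tau=0.01)          # tau_Q
```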

Since we are learning in a distributed setting, we also need to be careful about how gradients are reduced during the backward pass. For all these reasons the models are initialized as follows:

```python
# Define the encoder and decoder and set them up with fabric.
# The actor and the critic will then use the unwrapped encoder module:
# we do not need it wrapped with the strategy inside the actor and the critic
encoder = Encoder(in_channels=in_channels, features_dim=args.features_dim, screen_size=args.screen_size)
decoder = Decoder(
    encoder.conv_output_shape,
    features_dim=args.features_dim,
    screen_size=args.screen_size,
    out_channels=in_channels,
)
encoder = fabric.setup_module(encoder)
decoder = fabric.setup_module(decoder)

# Set up the actor and critic. Both will be initialized with orthogonal weights
actor = SACPixelContinuousActor(
    encoder=copy.deepcopy(encoder.module),  # Unwrap the strategy and deepcopy the encoder module
    action_dim=act_dim,
    hidden_size=args.actor_hidden_size,
    action_low=envs.single_action_space.low,
    action_high=envs.single_action_space.high,
)
qfs = [
    SACPixelQFunction(
        input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=args.critic_hidden_size, output_dim=1
    )
    for _ in range(args.num_critics)
]
# Unwrapped encoder module: its parameters are already tied to the wrapped encoder
critic = SACPixelCritic(encoder=encoder.module, qfs=qfs)
actor = fabric.setup_module(actor)
critic = fabric.setup_module(critic)

# The agent ties the convolutional weights between the actor's and the critic's encoders
agent = SACPixelAgent(
    actor,
    critic,
    target_entropy,
    alpha=args.alpha,
    tau=args.tau,
    encoder_tau=args.encoder_tau,
    device=fabric.device,
)

# Optimizers
qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer = fabric.setup_optimizers(
    Adam(agent.critic.parameters(), lr=args.q_lr),
    Adam(agent.actor.parameters(), lr=args.policy_lr),
    Adam([agent.log_alpha], lr=args.alpha_lr, betas=(0.5, 0.999)),
    Adam(encoder.parameters(), lr=args.encoder_lr),
    Adam(decoder.parameters(), lr=args.decoder_lr, weight_decay=args.decoder_wd),
)
```

The three losses of SAC-AE are the same ones used for SAC, implemented in the `sheeprl/algos/sac/loss.py` file.
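As a rough, non-authoritative sketch of what those losses compute (the signatures are chosen to match their usage in the training snippet below; the reference implementation is the one in `sheeprl/algos/sac/loss.py`):

```python
import torch
import torch.nn.functional as F


def critic_loss(qf_values: torch.Tensor, next_qf_value: torch.Tensor, num_critics: int) -> torch.Tensor:
    # Sum of the MSE between every Q head and the shared Bellman target
    return sum(F.mse_loss(qf_values[..., i : i + 1], next_qf_value) for i in range(num_critics))


def policy_loss(alpha: torch.Tensor, logprobs: torch.Tensor, min_qf_values: torch.Tensor) -> torch.Tensor:
    # Maximize the minimum Q-value while keeping the policy entropy high
    return ((alpha * logprobs) - min_qf_values).mean()


def entropy_loss(log_alpha: torch.Tensor, logprobs: torch.Tensor, target_entropy: float) -> torch.Tensor:
    # Adjust alpha so that the policy entropy tracks the target entropy
    return (-log_alpha * (logprobs + target_entropy)).mean()
```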
To account for points 2. and 3. above, the training function is the following:

```python
# Prevent OOM for both CPU and GPU memory. Data is memory-mapped if args.memmap_buffer=True (which is recommended)
data = data.to(fabric.device)
normalized_obs = data["observations"] / 255.0
normalized_next_obs = data["next_observations"] / 255.0

# Update the soft-critic
next_target_qf_value = agent.get_next_target_q_values(
    normalized_next_obs, data["rewards"], data["dones"], args.gamma
)
qf_values = agent.get_q_values(normalized_obs, data["actions"])
qf_loss = critic_loss(qf_values, next_target_qf_value, agent.num_critics)
qf_optimizer.zero_grad(set_to_none=True)
fabric.backward(qf_loss)
qf_optimizer.step()
aggregator.update("Loss/value_loss", qf_loss)

# Update the target networks with EMA. `args.target_network_frequency` is set to 2 by default
if global_step % args.target_network_frequency == 0:
    agent.critic_target_ema()  # Target update of the qfs only
    agent.critic_encoder_target_ema()  # Target update of the encoder only

# Update the actor. `args.actor_network_frequency` is set to 2 by default.
# Here the features extracted by the encoder are detached from the computational graph to prevent
# the actor's gradients from updating the encoder
if global_step % args.actor_network_frequency == 0:
    actions, logprobs = agent.get_actions_and_log_probs(normalized_obs, detach_encoder_features=True)
    qf_values = agent.get_q_values(normalized_obs, actions, detach_encoder_features=True)
    min_qf_values = torch.min(qf_values, dim=-1, keepdim=True)[0]
    actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values)
    actor_optimizer.zero_grad(set_to_none=True)
    fabric.backward(actor_loss)
    actor_optimizer.step()
    aggregator.update("Loss/policy_loss", actor_loss)

    # Update the entropy coefficient. Since `log_alpha` is not wrapped by the strategy,
    # its gradient is manually all-reduced across processes
    alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy)
    alpha_optimizer.zero_grad(set_to_none=True)
    fabric.backward(alpha_loss)
    agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad, group=group)
    alpha_optimizer.step()
    aggregator.update("Loss/alpha_loss", alpha_loss)

# Update the encoder/decoder. Since the convolutional weights are tied,
# this update is also reflected in the `agent.critic.encoder` module
if global_step % args.decoder_update_freq == 0:
    hidden = encoder(normalized_obs)
    reconstruction = decoder(hidden)
    reconstruction_loss = (
        F.mse_loss(preprocess_obs(data["observations"], bits=5), reconstruction)  # Reconstruction loss
        + args.decoder_l2_lambda * (0.5 * hidden.pow(2).sum(1)).mean()  # L2 penalty on the hidden state
    )
    encoder_optimizer.zero_grad(set_to_none=True)
    decoder_optimizer.zero_grad(set_to_none=True)
    fabric.backward(reconstruction_loss)
    encoder_optimizer.step()
    decoder_optimizer.step()
    aggregator.update("Loss/reconstruction_loss", reconstruction_loss)
```
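
The reconstruction target is passed through `preprocess_obs`, which follows the paper's trick of quantizing images down to 5 bits before comparing them with the decoder output. The following is a minimal sketch of what such a preprocessing typically looks like, assuming `uint8` image tensors; the exact implementation in the repository may differ:

```python
import torch


def preprocess_obs(obs: torch.Tensor, bits: int = 5) -> torch.Tensor:
    """Quantize images to `bits` bits, rescale them and add dequantization noise (sketch)."""
    bins = 2 ** bits
    obs = obs.float()
    if bits < 8:
        obs = torch.floor(obs / 2 ** (8 - bits))  # keep the `bits` most significant bits
    obs = obs / bins
    obs = obs + torch.rand_like(obs) / bins  # uniform dequantization noise
    obs = obs - 0.5  # center around zero
    return obs
```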

## Agent
The models of the SAC-AE agent are defined in the `agent.py` file, in order to have a clearer definition of the components of the agent. Our implementation of SAC-AE assumes that the observations are images of shape `(3, 64, 64)`, while both the encoder and the decoder are fixed as specified in the paper.
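
For reference, the encoder described in the paper is a stack of four 3x3 convolutions with 32 filters (the first with stride 2), followed by a linear projection to the feature dimension, layer normalization and a `tanh`. Below is an illustrative sketch under those assumptions; the class name and defaults are hypothetical and do not correspond to the exact classes in `agent.py`:

```python
import torch
import torch.nn as nn


class PaperLikeEncoder(nn.Module):
    """Sketch of a SAC-AE-style pixel encoder (illustrative, not the repository's class)."""

    def __init__(self, in_channels: int = 3, features_dim: int = 50, screen_size: int = 64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2), nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1), nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1), nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1), nn.ReLU(),
        )
        with torch.no_grad():
            conv_out = self.conv(torch.zeros(1, in_channels, screen_size, screen_size))
        self.fc = nn.Linear(conv_out.flatten(1).shape[1], features_dim)
        self.ln = nn.LayerNorm(features_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv(x).flatten(1)
        return torch.tanh(self.ln(self.fc(h)))
```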

## Packages
In order to use a broader set of environments provided by [Gymnasium](https://gymnasium.farama.org/), it is necessary to install the optional packages:

* Mujoco environments: `pip install gymnasium[mujoco]`
* Atari environments: `pip install gymnasium[atari]` and `pip install gymnasium[accept-rom-license]`

## Hyper-parameters
For SAC-AE we decided to fix the number of environments to `1`, in order to have a clearer and more understandable management of the environment interaction. In addition, we would like to give a recommendation about the `per_rank_batch_size` hyper-parameter: the recommended batch size for the SAC-AE agent is 128 for single-process training; if you want to use distributed training, divide the batch size by the number of processes and set `per_rank_batch_size` accordingly (e.g., `per_rank_batch_size=64` with 2 processes).

## Atari environments
There are two versions of most Atari environments: one uses the *frame skip* property by default, whereas the other does not implement it. If the first version is selected, then the value of the `action_repeat` hyper-parameter must be `1`; to select an environment without *frame skip*, it is necessary to insert `NoFrameskip` in the environment id and remove the `ALE/` prefix from it. For instance, the environment `ALE/AirRaid-v5` must be instantiated with `action_repeat=1`, whereas its version without *frame skip* is `AirRaidNoFrameskip-v4` and can be instantiated with any value of `action_repeat` greater than zero.
For more information see the official documentation of [Gymnasium Atari environments](https://gymnasium.farama.org/environments/atari/).

## DMC environments
It is possible to use the environments provided by the [DeepMind Control suite](https://www.deepmind.com/open-source/deepmind-control-suite). To use such environments it is necessary to specify the "dmc" domain in the `env_id` hyper-parameter, e.g., `env_id = dmc_walker_walk` will create an instance of the walker walk environment. For more information about all the environments, check their [paper](https://arxiv.org/abs/1801.00690).

When running SAC-AE in a DMC environment on a server (or a PC without a video terminal) it could be necessary to prepend two variables to the command that launches the script: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa <command>`. For instance, to run walker walk with SAC-AE on two GPUs (2 and 3) it is necessary to run the following command: `PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa CUDA_VISIBLE_DEVICES="2,3" lightning run model --devices=2 --accelerator=gpu sheeprl.py sac_pixel --env_id=dmc_walker_walk --action_repeat=2 --capture_video --checkpoint_every=80000 --seed=1`.
Other possibilities for the `MUJOCO_GL` variable are `GLFW`, for rendering to an X11 window, and `EGL`, for hardware-accelerated headless rendering (for more information, see [here](https://mujoco.readthedocs.io/en/stable/programming/index.html#using-opengl)).

## Recommendations
Since SAC-AE requires a huge number of steps, and consequently a large buffer size, we recommend keeping the buffer on CPU instead of moving it to CUDA, while memory-mapping it by setting the flag `--memmap_buffer=True` when launching the script. Furthermore, in order to limit memory usage, we recommend storing the observations in `uint8` format and normalizing them just before training, one batch at a time. Finally, it is important to remind the user that SAC-AE works only with observations in pixel form; therefore, only environments whose observation space is an instance of `gym.spaces.Box` can be selected when used with Gymnasium.
6 changes: 4 additions & 2 deletions sheeprl/algos/sac_pixel/sac_pixel_continuous.py
```diff
@@ -271,8 +271,10 @@ def main():
 
     # Local data
     buffer_size = args.buffer_size // int(args.num_envs * fabric.world_size) if not args.dry_run else 1
-    rb = ReplayBuffer(buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer)
-    step_data = TensorDict({}, batch_size=[args.num_envs], device="cpu")
+    rb = ReplayBuffer(
+        buffer_size, args.num_envs, device=fabric.device if args.memmap_buffer else "cpu", memmap=args.memmap_buffer
+    )
+    step_data = TensorDict({}, batch_size=[args.num_envs], device=fabric.device if args.memmap_buffer else "cpu")
 
     # Global variables
     start_time = time.time()
```
1 change: 1 addition & 0 deletions sheeprl/data/buffers.py
```diff
@@ -94,6 +94,7 @@ def add(self, data: Union["ReplayBuffer", TensorDictBase]) -> None:
                 "`data` must have 2 batch dimensions: [sequence_length, n_envs]. "
                 "`sequence_length` and `n_envs` should be 1. Shape is: {}".format(data.shape)
             )
+        data = data.to(self.device)
         data_len = data.shape[0]
         next_pos = (self._pos + data_len) % self._buffer_size
         if next_pos < self._pos or (data_len >= self._buffer_size and not self._full):
```
