Feature/p2e #44

Merged
merged 22 commits on Jun 26, 2023
Changes from all commits (22 commits)
226d774
feat: added p2e
michele-milesi Jun 19, 2023
0aee1d5
Fix p2e to be conditioned on both recurrent and stochastic
belerico Jun 20, 2023
1d525b2
Minedojo wrapper for p2e
michele-milesi Jun 20, 2023
f752bbc
Merge branch 'feature/p2e' of github.com:Eclectic-Sheep/sheeprl into …
michele-milesi Jun 20, 2023
cf4df24
fix: minedojo wrapper and p2e hyper-parameters
michele-milesi Jun 21, 2023
c668663
fix(p2e): indices in train() function
michele-milesi Jun 21, 2023
c1b1e82
feat: added encoder wrapper in dreamer_v1
michele-milesi Jun 22, 2023
d66f77c
fix: dreamer variable renaming
michele-milesi Jun 22, 2023
f9deecb
fear: added zero-shot learning + tests
michele-milesi Jun 22, 2023
23a30ee
feat: added zero-shot test
michele-milesi Jun 22, 2023
429d52e
feat: added one-shot p2e
michele-milesi Jun 22, 2023
ea3d582
feat: p2e few shot
michele-milesi Jun 22, 2023
53842f6
fix: few-shot learning
michele-milesi Jun 23, 2023
5711c75
fix: ensemble loss + continue_loss + variable renaming
michele-milesi Jun 23, 2023
9887a4c
feat: added sticky attack and jump in minedojo environment
michele-milesi Jun 23, 2023
ff60c9a
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Jun 26, 2023
f9ab494
fix: minedojo env
michele-milesi Jun 26, 2023
e719105
Separate hidden_size for transition and representation model
belerico Jun 26, 2023
c3873c7
Single p2e for both zero-shot and finetuning
belerico Jun 26, 2023
2fb051e
Merge branch 'feature/p2e' of https://github.com/Eclectic-Sheep/sheep…
belerico Jun 26, 2023
f7cdcb0
Removed benchmark.py
belerico Jun 26, 2023
0485325
Fix imports
belerico Jun 26, 2023
1 change: 1 addition & 0 deletions sheeprl/__init__.py
@@ -9,6 +9,7 @@

from sheeprl.algos.dreamer_v1 import dreamer_v1
from sheeprl.algos.droq import droq
from sheeprl.algos.p2e import p2e
from sheeprl.algos.ppo import ppo, ppo_decoupled
from sheeprl.algos.ppo_continuous import ppo_continuous
from sheeprl.algos.ppo_pixel import ppo_pixel_continuous
95 changes: 73 additions & 22 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
@@ -10,6 +10,50 @@
from sheeprl.algos.dreamer_v1.args import DreamerV1Args
from sheeprl.algos.dreamer_v1.utils import cnn_forward, compute_stochastic_state, init_weights
from sheeprl.models.models import CNN, MLP, DeCNN
from sheeprl.utils.model import ArgsType, ModuleType


class Encoder(nn.Module):
"""The wrapper class for the encoder.

Args:
input_channels (int): the number of channels in input.
hidden_channels (Sequence[int]): the hidden channels of the CNN.
layer_args (ArgsType): the args of the layers of the CNN.
activation (Optional[Union[ModuleType, Sequence[ModuleType]]]): the activation function to use in the CNN.
Default nn.ReLU.
observation_shape (Tuple[int, int, int]): the dimension of the observations, channels first.
Default to (3, 64, 64).
"""

def __init__(
self,
input_channels: int,
hidden_channels: Sequence[int],
layer_args: ArgsType,
activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
observation_shape: Tuple[int, int, int] = (3, 64, 64),
) -> None:
super().__init__()
self._module = nn.Sequential(
CNN(
input_channels=input_channels,
hidden_channels=hidden_channels,
layer_args=layer_args,
activation=activation,
),
nn.Flatten(-3, -1),
)
self._observation_shape = observation_shape
with torch.no_grad():
self._output_size = self._module(torch.zeros(*observation_shape)).shape[-1]

@property
def output_size(self) -> int:
return self._output_size

def forward(self, x) -> Tensor:
return self._module(x)
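
A minimal usage sketch of the new wrapper (not part of the diff; the import path follows the file being edited, and the shapes assume the default 64x64 RGB observations):

```python
import torch
import torch.nn as nn

from sheeprl.algos.dreamer_v1.agent import Encoder

encoder = Encoder(
    input_channels=3,
    hidden_channels=[32, 64, 128, 256],          # cnn_channels_multiplier=32 -> [1, 2, 4, 8] * 32
    layer_args={"kernel_size": 4, "stride": 2},
    activation=nn.ReLU,
    observation_shape=(3, 64, 64),
)
obs = torch.zeros(8, 3, 64, 64)                  # (batch, channels, height, width)
features = encoder(obs)                          # CNN features, flattened over the last 3 dims
assert features.shape == (8, encoder.output_size)
```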


class RecurrentModel(nn.Module):
@@ -181,6 +225,8 @@ class Actor(nn.Module):
Default to 400.
dense_act (nn.Module): the activation function to apply after the dense layers.
Default to nn.ELU.
num_layers (int): the number of MLP layers.
Default to 4.
"""

def __init__(
@@ -193,12 +239,13 @@ def __init__(
min_std: float = 1e-4,
dense_units: int = 400,
dense_act: nn.Module = nn.ELU,
num_layers: int = 4,
) -> None:
super().__init__()
self.model = MLP(
input_dims=latent_state_size,
output_dim=action_dim * 2 if is_continuous else action_dim,
hidden_sizes=[dense_units, dense_units, dense_units, dense_units],
hidden_sizes=[dense_units] * num_layers,
activation=dense_act,
flatten_dim=None,
)
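
For reference, a small sketch (the values are the documented defaults plus an assumed action space, not taken from the diff) of what the `hidden_sizes` expansion produces:

```python
# Assumed values: dense_units=400, num_layers=4, continuous actions with action_dim=6.
dense_units, num_layers, action_dim = 400, 4, 6
hidden_sizes = [dense_units] * num_layers    # [400, 400, 400, 400]
output_dim = action_dim * 2                  # mean and std parameters per action
```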
@@ -384,13 +431,19 @@ def build_models(
observation_shape (Tuple[int, ...]): the shape of the observations.
is_continuous (bool): whether or not the actions are continuous.
args (DreamerV1Args): the hyper-parameters of Dreamer_v1.
world_model_state (Dict[str, Tensor]): the state of the world model.
Default to None.
actor_state (Dict[str, Tensor]): the state of the actor.
Default to None.
critic_state (Dict[str, Tensor]): the state of the critic.
Default to None.

Returns:
The world model (WorldModel): composed by the encoder, rssm, observation and reward models and the continue model.
The actor (_FabricModule).
The critic (_FabricModule).
"""
n_obs_channels = 1 if args.grayscale_obs else 3
n_obs_channels = 1 if args.grayscale_obs and "minedojo" not in args.env_id.lower() else 3
if args.cnn_channels_multiplier <= 0:
raise ValueError(f"cnn_channels_multiplier must be greater than zero, given {args.cnn_channels_multiplier}")
if args.dense_units <= 0:
@@ -411,29 +464,26 @@
)

# Define models
encoder = nn.Sequential(
CNN(
input_channels=n_obs_channels,
hidden_channels=(torch.tensor([1, 2, 4, 8]) * args.cnn_channels_multiplier).tolist(),
layer_args={"kernel_size": 4, "stride": 2},
activation=cnn_act,
),
nn.Flatten(-3, -1),
encoder = Encoder(
input_channels=n_obs_channels,
hidden_channels=(torch.tensor([1, 2, 4, 8]) * args.cnn_channels_multiplier).tolist(),
layer_args={"kernel_size": 4, "stride": 2},
activation=cnn_act,
observation_shape=observation_shape,
)
with torch.no_grad():
encoder_output_size = encoder(torch.zeros(*observation_shape)).shape[-1]

recurrent_model = RecurrentModel(action_dim + args.stochastic_size, args.recurrent_state_size)
representation_model = MLP(
input_dims=args.recurrent_state_size + encoder_output_size,
input_dims=args.recurrent_state_size + encoder.output_size,
output_dim=args.stochastic_size * 2,
hidden_sizes=[args.recurrent_state_size],
hidden_sizes=[args.hidden_size],
activation=dense_act,
flatten_dim=None,
)
transition_model = MLP(
input_dims=args.recurrent_state_size,
output_dim=args.stochastic_size * 2,
hidden_sizes=[args.recurrent_state_size],
hidden_sizes=[args.hidden_size],
activation=dense_act,
flatten_dim=None,
)
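
A hedged sketch of the sizes implied by the defaults (recurrent_state_size=200, stochastic_size=30, hidden_size=200, and an assumed encoder output of 1024 for a 64x64 RGB observation):

```python
# Assumed default sizes; the point is that the representation and transition
# models now use args.hidden_size instead of reusing recurrent_state_size.
recurrent_state_size, stochastic_size, hidden_size = 200, 30, 200
encoder_output_size = 1024                                         # flattened CNN features (assumption)
representation_in = recurrent_state_size + encoder_output_size     # 1224
transition_in = recurrent_state_size                               # 200
state_params_out = 2 * stochastic_size                             # 60: mean and std of the stochastic state
```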
@@ -444,10 +494,10 @@
args.min_std,
)
observation_model = nn.Sequential(
nn.Linear(args.stochastic_size + args.recurrent_state_size, encoder_output_size),
nn.Unflatten(1, (encoder_output_size, 1, 1)),
nn.Linear(args.stochastic_size + args.recurrent_state_size, encoder.output_size),
nn.Unflatten(1, (encoder.output_size, 1, 1)),
DeCNN(
input_channels=encoder_output_size,
input_channels=encoder.output_size,
hidden_channels=(torch.tensor([4, 2, 1]) * args.cnn_channels_multiplier).tolist() + [n_obs_channels],
layer_args=[
{"kernel_size": 5, "stride": 2},
@@ -461,15 +511,15 @@
reward_model = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units, args.dense_units],
hidden_sizes=[args.dense_units] * args.num_layers,
activation=dense_act,
flatten_dim=None,
)
if args.use_continues:
continue_model = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units, args.dense_units, args.dense_units],
hidden_sizes=[args.dense_units] * args.num_layers,
activation=dense_act,
flatten_dim=None,
)
@@ -489,11 +539,12 @@
args.actor_min_std,
args.dense_units,
dense_act,
args.num_layers,
)
critic = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units, args.dense_units, args.dense_units],
hidden_sizes=[args.dense_units] * args.num_layers,
activation=dense_act,
flatten_dim=None,
)
12 changes: 11 additions & 1 deletion sheeprl/algos/dreamer_v1/args.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

from sheeprl.algos.args import StandardArgs
from sheeprl.utils.parser import Arg
@@ -32,6 +32,7 @@ class DreamerV1Args(StandardArgs):
lmbda: float = Arg(default=0.95, help="the lambda for the TD lambda values")
use_continues: bool = Arg(default=False, help="whether or not to use the continue predictor")
stochastic_size: int = Arg(default=30, help="the dimension of the stochastic state")
hidden_size: int = Arg(default=200, help="the hidden size for the transition and representation model")
recurrent_state_size: int = Arg(default=200, help="the dimension of the recurrent state")
kl_free_nats: float = Arg(default=3.0, help="the minimum value for the kl divergence")
kl_regularizer: float = Arg(default=1.0, help="the scale factor for the kl divergence")
@@ -46,6 +47,10 @@ class DreamerV1Args(StandardArgs):
actor_min_std: float = Arg(default=1e-4, help="the minimum standard deviation for the actions")
clip_gradients: float = Arg(default=100.0, help="how much to clip the gradient norms")
dense_units: int = Arg(default=400, help="the number of units in dense layers, must be greater than zero")
num_layers: int = Arg(
default=4,
help="the number of MLP layers for every model: actor, critic, reward and possibly the continue model",
)
cnn_channels_multiplier: int = Arg(default=32, help="cnn width multiplication factor, must be greater than zero")
dense_act: str = Arg(
default="ELU",
@@ -74,3 +79,8 @@ class DreamerV1Args(StandardArgs):
)
clip_rewards: bool = Arg(default=False, help="whether or not to clip rewards using tanh")
grayscale_obs: bool = Arg(default=False, help="whether or not the observations are grayscale")
mine_min_pitch: int = Arg(default=-60, help="The minimum value of pitch in Minecraft environments.")
mine_max_pitch: int = Arg(default=60, help="The maximum value of pitch in Minecraft environments.")
mine_start_position: Optional[List[float]] = Arg(
default=None, help="The starting position of the agent in the Minecraft environment (x, y, z, pitch, yaw)."
)
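
A small, hedged sketch of the new fields with their defaults (this assumes the inherited StandardArgs fields all have defaults, so the dataclass can be constructed directly):

```python
from sheeprl.algos.dreamer_v1.args import DreamerV1Args

args = DreamerV1Args()
print(args.hidden_size)                          # 200, transition/representation model hidden size
print(args.num_layers)                           # 4, depth of actor/critic/reward (and continue) MLPs
print(args.mine_min_pitch, args.mine_max_pitch)  # -60 60, Minecraft camera pitch limits
print(args.mine_start_position)                  # None, optional (x, y, z, pitch, yaw)
```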
39 changes: 20 additions & 19 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -12,11 +12,11 @@
from lightning.fabric.fabric import _is_using_cli
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric.plugins.collectives import TorchCollective
from lightning.fabric.wrappers import _FabricModule
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
from tensordict import TensorDict
from tensordict.tensordict import TensorDictBase
from torch.distributions import Bernoulli, Independent, Normal
from torch.optim import Adam, Optimizer
from torch.optim import Adam
from torch.utils.data import BatchSampler
from torchmetrics import MeanMetric

@@ -41,9 +41,9 @@ def train(
world_model: WorldModel,
actor: _FabricModule,
critic: _FabricModule,
world_optimizer: Optimizer,
actor_optimizer: Optimizer,
critic_optimizer: Optimizer,
world_optimizer: _FabricOptimizer,
actor_optimizer: _FabricOptimizer,
critic_optimizer: _FabricOptimizer,
data: TensorDictBase,
aggregator: MetricAggregator,
args: DreamerV1Args,
@@ -85,9 +85,9 @@ def train(
world_model (_FabricModule): the world model wrapped with Fabric.
actor (_FabricModule): the actor model wrapped with Fabric.
critic (_FabricModule): the critic model wrapped with Fabric.
world_optimizer (Optimizer): the world optimizer.
actor_optimizer (Optimizer): the actor optimizer.
critic_optimizer (Optimizer): the critic optimizer.
world_optimizer (_FabricOptimizer): the world optimizer.
actor_optimizer (_FabricOptimizer): the actor optimizer.
critic_optimizer (_FabricOptimizer): the critic optimizer.
data (TensorDictBase): the batch of data to use for training.
aggregator (MetricAggregator): the aggregator to print the metrics.
args (DreamerV1Args): the configs.
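
The signature now documents the Fabric-wrapped types explicitly. A minimal sketch (not sheeprl code, just the usual Lightning Fabric pattern) of how such objects are typically produced:

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# setup() returns a _FabricModule and a _FabricOptimizer wrapping the originals.
model, optimizer = fabric.setup(model, optimizer)
```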
@@ -164,7 +164,7 @@ def train(

# compute predictions for terminal steps, if required
if args.use_continues and world_model.continue_model:
qc = Bernoulli(logits=world_model.continue_model(latent_states))
qc = Independent(Bernoulli(logits=world_model.continue_model(latent_states), validate_args=False), 1)
continue_targets = (1 - data["dones"]) * args.gamma
else:
qc = continue_targets = None
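
Wrapping the Bernoulli in `Independent(..., 1)` makes `log_prob` treat the trailing feature dimension as a single event. A toy sketch of the resulting shapes (the shapes are illustrative, not taken from the diff):

```python
import torch
from torch.distributions import Bernoulli, Independent

logits = torch.zeros(4, 10, 1)                    # (sequence, batch, 1), illustrative
qc = Independent(Bernoulli(logits=logits, validate_args=False), 1)
continue_targets = torch.ones(4, 10, 1) * 0.99    # (1 - dones) * gamma with gamma=0.99
print(qc.log_prob(continue_targets).shape)        # torch.Size([4, 10]): summed over the event dim
```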
@@ -244,18 +244,19 @@ def train(
predicted_values = Independent(Normal(critic(imagined_trajectories), 1), 1).mean
predicted_rewards = Independent(Normal(world_model.reward_model(imagined_trajectories), 1), 1).mean

# predict the probability that the episode will continue in the imagined states
if args.use_continues and world_model.continue_model:
done_mask = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
predicted_continues = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
else:
done_mask = torch.ones_like(predicted_rewards.detach()) * args.gamma
predicted_continues = torch.ones_like(predicted_rewards.detach()) * args.gamma

# compute the lambda_values, by passing as last values the values of the last imagined state
# the dimensions of the lambda_values tensor are
# (horizon, batch_size * sequence_length, recurrent_state_size + stochastic_size)
lambda_values = compute_lambda_values(
predicted_rewards,
predicted_values,
done_mask,
predicted_continues,
last_values=predicted_values[-1],
horizon=args.horizon,
lmbda=args.lmbda,
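
For readers unfamiliar with the quantity being computed here, a hedged sketch of a standard TD(lambda) recursion; this is illustrative only and not necessarily identical to sheeprl's `compute_lambda_values`:

```python
import torch

def lambda_values_sketch(rewards, values, continues, last_values, lmbda=0.95):
    # Bootstrap with the value of the state after the last imagined step, then fold
    # V_t = r_t + c_t * ((1 - lmbda) * v_{t+1} + lmbda * V_{t+1}) backwards over the horizon,
    # where c_t are the predicted continues (already scaled by gamma).
    next_values = torch.cat((values[1:], last_values[None]), 0)
    inputs = rewards + continues * (1 - lmbda) * next_values
    outputs, carry = [], last_values
    for t in reversed(range(rewards.shape[0])):
        carry = inputs[t] + continues[t] * lmbda * carry
        outputs.append(carry)
    return torch.stack(outputs[::-1], 0)
```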
@@ -270,32 +271,32 @@
# in [https://doi.org/10.48550/arXiv.1912.01603](https://doi.org/10.48550/arXiv.1912.01603)
#
# Suppose the case in which the continue model is not used and gamma = .99
# done_mask.shape = (15, 2500, 1)
# done_mask = [
# predicted_continues.shape = (15, 2500, 1)
# predicted_continues = [
# [ [.99], ..., [.99] ], (2500 columns)
# ...
# ] (15 rows)
# torch.ones_like(done_mask[:1]) = [
# torch.ones_like(predicted_continues[:1]) = [
# [ [1.], ..., [1.] ]
# ] (1 row and 2500 columns), the discount of the time step 0 is 1.
# done_mask[:-2] = [
# predicted_continues[:-2] = [
# [ [.99], ..., [.99] ], (2500 columns)
# ...
# ] (13 rows)
# torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-2]), 0) = [
# torch.cat((torch.ones_like(predicted_continues[:1]), predicted_continues[:-2]), 0) = [
# [ [1.], ..., [1.] ], (2500 columns)
# [ [.99], ..., [.99] ],
# ...,
# [ [.99], ..., [.99] ],
# ] (14 rows), the total number of imagined steps is 15, but one is lost because of the values computation
# torch.cumprod(torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-2]), 0), 0) = [
# torch.cumprod(torch.cat((torch.ones_like(predicted_continues[:1]), predicted_continues[:-2]), 0), 0) = [
# [ [1.], ..., [1.] ], (2500 columns)
# [ [.99], ..., [.99] ],
# [ [.9801], ..., [.9801] ],
# ...,
# [ [.8775], ..., [.8775] ],
# ] (14 rows)
discount = torch.cumprod(torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-2]), 0), 0)
discount = torch.cumprod(torch.cat((torch.ones_like(predicted_continues[:1]), predicted_continues[:-2]), 0), 0)
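
The comment above already walks through the cumulative product; the same computation as a tiny runnable snippet (horizon shortened to keep the output readable, values assumed):

```python
import torch

gamma = 0.99
predicted_continues = torch.full((5, 3, 1), gamma)   # (horizon, batch * seq_len, 1), no continue model
discount = torch.cumprod(
    torch.cat((torch.ones_like(predicted_continues[:1]), predicted_continues[:-2]), 0), 0
)
print(discount[:, 0, 0])   # tensor([1.0000, 0.9900, 0.9801, 0.9703])
```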

# actor optimization step
actor_optimizer.zero_grad(set_to_none=True)
11 changes: 5 additions & 6 deletions sheeprl/algos/dreamer_v1/loss.py
@@ -1,7 +1,6 @@
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Distribution
from torch.distributions.kl import kl_divergence
@@ -45,7 +44,7 @@ def reconstruction_loss(
kl_free_nats: float = 3.0,
kl_regularizer: float = 1.0,
qc: Optional[Distribution] = None,
dones: Optional[Tensor] = None,
continue_targets: Optional[Tensor] = None,
continue_scale_factor: float = 10.0,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
@@ -65,7 +64,7 @@
qc (Bernoulli, optional): the predicted Bernoulli distribution of the terminal steps.
0s for the entries that are relative to a terminal step, 1s otherwise.
Default to None.
dones (Tensor, optional): 1s for the entries that are relative to a terminal step, 0s otherwise.
continue_targets (Tensor, optional): the targets for the continue predictions, i.e. (1 - dones) * gamma.
Default to None.
continue_scale_factor (float): the scale factor for the continue loss.
Default to 10.
@@ -78,11 +77,11 @@
reconstruction_loss (Tensor): the value of the overall reconstruction loss.
"""
device = observations.device
continue_loss = torch.tensor(0, device=device)
observation_loss = -qo.log_prob(observations).mean()
reward_loss = -qr.log_prob(rewards).mean()
state_loss = torch.max(torch.tensor(kl_free_nats, device=device), kl_divergence(p, q).mean())
if qc is not None and dones is not None:
continue_loss = continue_scale_factor * F.binary_cross_entropy(qc.probs, dones)
continue_loss = torch.tensor(0, device=device)
if qc is not None and continue_targets is not None:
continue_loss = continue_scale_factor * qc.log_prob(continue_targets)
reconstruction_loss = kl_regularizer * state_loss + observation_loss + reward_loss + continue_loss
return reconstruction_loss, state_loss, reward_loss, observation_loss, continue_loss
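
A minimal sketch (toy tensors and assumed default hyper-parameters, not sheeprl code) of the free-nats clamp used in the `state_loss` above:

```python
import torch
from torch.distributions import Independent, Normal
from torch.distributions.kl import kl_divergence

posterior = Independent(Normal(torch.zeros(2500, 30), torch.ones(2500, 30)), 1)
prior = Independent(Normal(0.1 * torch.ones(2500, 30), torch.ones(2500, 30)), 1)
kl = kl_divergence(posterior, prior).mean()       # ~0.15 nats, well below the floor
state_loss = torch.max(torch.tensor(3.0), kl)     # clamped to kl_free_nats, so no gradient flows
```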