
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into main
belerico committed Jul 18, 2023
2 parents (592338e + 9b2916b) · commit 59637d2
Showing 18 changed files with 1,382 additions and 71 deletions.
23 changes: 12 additions & 11 deletions README.md
@@ -9,17 +9,18 @@
An easy-to-use framework for reinforcement learning in PyTorch, accelerated with [Lightning Fabric](https://lightning.ai/docs/fabric/stable/).
The algorithms sheeped by sheeprl out-of-the-box are:

| Algorithm | Coupled | Decoupled | Recurrent | Pixel | Status |
| ---------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| A3C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :construction: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: |
| DroQ | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| Dreamer-V1 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Dreamer-V2 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Dreamer-V3 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| Plan2Explore | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Algorithm | Coupled | Decoupled | Recurrent | Pixel | Status |
| ------------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| A3C | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :construction: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: |
| DroQ | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| Dreamer-V1 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Dreamer-V2 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Dreamer-V3 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :construction: |
| Plan2Explore (Dreamer V1) | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Plan2Explore (Dreamer V2) | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

and more are coming soon! [Open a PR](https://github.com/Eclectic-Sheep/sheeprl/pulls) if you have any particular request :sheep:

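For contributors adding a new algorithm, the entry-point pattern used throughout this commit (see the p2e_dv1.py changes below) is a `main()` function wrapped in the `register_algorithm()` decorator. A minimal sketch of that pattern; the decorator's import path is an assumption, since the import is not shown in this diff:

```python
# Minimal sketch of a new algorithm entry point, mirroring the pattern used by
# p2e_dv1.py later in this commit. The import path of register_algorithm is an
# assumption (it is not shown in the diff).
from sheeprl.utils.registry import register_algorithm


@register_algorithm()
def main():
    # parse args, set up Fabric, build the models, then run the training loop
    ...
```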
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -1,3 +1,6 @@
[virtualenvs]
create = true
in-project = true
[build-system]
requires = ["setuptools >= 61.0.0"]
build-backend = "setuptools.build_meta"
3 changes: 2 additions & 1 deletion sheeprl/__init__.py
@@ -10,7 +10,8 @@
from sheeprl.algos.dreamer_v1 import dreamer_v1
from sheeprl.algos.dreamer_v2 import dreamer_v2
from sheeprl.algos.droq import droq
from sheeprl.algos.p2e.p2e_dv1 import p2e_dv1
from sheeprl.algos.p2e_dv1 import p2e_dv1
from sheeprl.algos.p2e_dv2 import p2e_dv2
from sheeprl.algos.ppo import ppo, ppo_decoupled
from sheeprl.algos.ppo_continuous import ppo_continuous
from sheeprl.algos.ppo_pixel import ppo_pixel_continuous
14 changes: 7 additions & 7 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -234,7 +234,7 @@ class Actor(nn.Module):
Default to 400.
dense_act (int): the activation function to apply after the dense layers.
Default to nn.ELU.
num_layers (int): the number of MLP layers.
mlp_layers (int): the number of MLP layers.
Default to 4.
"""

@@ -248,13 +248,13 @@ def __init__(
min_std: float = 1e-4,
dense_units: int = 400,
dense_act: nn.Module = nn.ELU,
num_layers: int = 4,
mlp_layers: int = 4,
) -> None:
super().__init__()
self.model = MLP(
input_dims=latent_state_size,
output_dim=np.sum(actions_dim) * 2 if is_continuous else np.sum(actions_dim),
hidden_sizes=[dense_units] * num_layers,
hidden_sizes=[dense_units] * mlp_layers,
activation=dense_act,
flatten_dim=None,
)
@@ -534,15 +534,15 @@ def build_models(
reward_model = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
)
if args.use_continues:
continue_model = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
)
@@ -562,12 +562,12 @@
args.actor_min_std,
args.dense_units,
dense_act,
args.num_layers,
args.mlp_layers,
)
critic = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
)
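For readers tracking the `num_layers` → `mlp_layers` rename above: only the argument name changes, not the architecture. A small sketch using the `MLP` signature shown in this diff (the input size is a hypothetical stand-in for `stochastic_size + recurrent_state_size`):

```python
import torch.nn as nn

from sheeprl.models.models import MLP

# With dense_units=400 and mlp_layers=4 (the defaults in args.py below),
# hidden_sizes expands to [400, 400, 400, 400].
critic = MLP(
    input_dims=230,        # hypothetical value standing in for stochastic_size + recurrent_state_size
    output_dim=1,
    hidden_sizes=[400] * 4,
    activation=nn.ELU,
    flatten_dim=None,
)
```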
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/args.py
@@ -45,7 +45,7 @@ class DreamerV1Args(StandardArgs):
actor_min_std: float = Arg(default=1e-4, help="the minimum standard deviation for the actions")
clip_gradients: float = Arg(default=100.0, help="how much to clip the gradient norms")
dense_units: int = Arg(default=400, help="the number of units in dense layers, must be greater than zero")
num_layers: int = Arg(
mlp_layers: int = Arg(
default=4,
help="the number of MLP layers for every model: actor, critic, reward and possibly the continue model",
)
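As a quick sanity check of the rename, the field is reachable as a plain dataclass attribute; a sketch only, assuming every `DreamerV1Args` (and `StandardArgs`) field carries a default so the dataclass can be instantiated without going through the CLI parser:

```python
# Sketch only: assumes all DreamerV1Args/StandardArgs fields have defaults, so the
# dataclass can be built directly instead of via the command-line parser.
from sheeprl.algos.dreamer_v1.args import DreamerV1Args

args = DreamerV1Args()
print(args.mlp_layers)  # 4 by default; the old num_layers attribute no longer exists
```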
5 changes: 0 additions & 5 deletions sheeprl/algos/dreamer_v2/agent.py
@@ -226,8 +226,6 @@ class RSSM(nn.Module):
For more information see [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
transition_model (nn.Module): the transition model described in [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
The model is composed of a multi-layer perceptron to predict the stochastic part of the latent state.
min_std (float, optional): the minimum value of the standard deviation computed by the transition model.
Default to 0.1.
discrete (int, optional): the size of the Categorical variables.
Defaults to 32.
"""
@@ -237,14 +235,12 @@ def __init__(
recurrent_model: nn.Module,
representation_model: nn.Module,
transition_model: nn.Module,
min_std: Optional[float] = 0.1,
discrete: Optional[int] = 32,
) -> None:
super().__init__()
self.recurrent_model = recurrent_model
self.representation_model = representation_model
self.transition_model = transition_model
self.min_std = min_std
self.discrete = discrete

def dynamic(
@@ -753,7 +749,6 @@ def build_models(
recurrent_model.apply(init_weights),
representation_model.apply(init_weights),
transition_model.apply(init_weights),
args.min_std,
args.discrete_size,
)
observation_model = MultiDecoder(
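The removed `min_std` argument applies to Gaussian latents; with Dreamer-V2's categorical stochastic state (`discrete` classes per variable) there is no standard deviation to clamp, which is presumably why it is dropped here. A schematic sketch, not the repository's code, of how such a discrete latent is typically sampled with straight-through gradients (shapes are hypothetical):

```python
import torch
from torch.distributions import OneHotCategoricalStraightThrough

# Hypothetical shapes: batch of 16 states, 32 categorical variables with 32 classes each.
logits = torch.randn(16, 32, 32)
posterior = OneHotCategoricalStraightThrough(logits=logits)
sample = posterior.rsample()                # one-hot samples with straight-through gradients
stochastic_state = sample.reshape(16, -1)   # flattened to (16, 32 * 32)
```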
1 change: 0 additions & 1 deletion sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -257,7 +257,6 @@ def train(
# predict values and rewards
predicted_target_values = target_critic(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)

if args.use_continues and world_model.continue_model:
done_mask = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) * args.gamma
10 changes: 5 additions & 5 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -67,11 +67,11 @@ def make_env(
task_id = "_".join(env_id.split("_")[1:])
start_position = (
{
"x": args.mine_start_position[0],
"y": args.mine_start_position[1],
"z": args.mine_start_position[2],
"pitch": args.mine_start_position[3],
"yaw": args.mine_start_position[4],
"x": float(args.mine_start_position[0]),
"y": float(args.mine_start_position[1]),
"z": float(args.mine_start_position[2]),
"pitch": float(args.mine_start_position[3]),
"yaw": float(args.mine_start_position[4]),
}
if args.mine_start_position is not None
else None
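The explicit `float(...)` casts above presumably guard against the start-position entries arriving as strings or ints from the CLI; a tiny worked example with hypothetical values:

```python
# Hypothetical CLI values: entries may be parsed as strings, so each one is cast explicitly.
mine_start_position = ["10.5", "64", "-3.2", "0", "180"]
keys = ("x", "y", "z", "pitch", "yaw")
start_position = {k: float(v) for k, v in zip(keys, mine_start_position)}
print(start_position)  # {'x': 10.5, 'y': 64.0, 'z': -3.2, 'pitch': 0.0, 'yaw': 180.0}
```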
File renamed without changes.
@@ -7,7 +7,7 @@
from torch import Tensor, nn

from sheeprl.algos.dreamer_v1.agent import RSSM, Actor, Encoder, RecurrentModel, WorldModel
from sheeprl.algos.p2e.p2e_dv1.args import P2EArgs
from sheeprl.algos.p2e_dv1.args import P2EDV1Args
from sheeprl.models.models import MLP, DeCNN
from sheeprl.utils.utils import init_weights

@@ -17,7 +17,7 @@ def build_models(
actions_dim: Sequence[int],
observation_shape: Tuple[int, ...],
is_continuous: bool,
args: P2EArgs,
args: P2EDV1Args,
world_model_state: Optional[Dict[str, Tensor]] = None,
actor_task_state: Optional[Dict[str, Tensor]] = None,
critic_task_state: Optional[Dict[str, Tensor]] = None,
@@ -31,7 +31,7 @@ def build_models(
action_dim (int): the dimension of the actions.
observation_shape (Tuple[int, ...]): the shape of the observations.
is_continuous (bool): whether or not the actions are continuous.
args (P2EArgs): the hyper-parameters of Dreamer_v1.
args (P2EDV1Args): the hyper-parameters of Dreamer_v1.
world_model_state (Dict[str, Tensor], optional): the state of the world model.
Default to None.
actor_task_state (Dict[str, Tensor], optional): the state of the actor_task.
@@ -119,15 +119,15 @@ def build_models(
reward_model = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
)
if args.use_continues:
continue_model = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
)
@@ -147,12 +147,12 @@
args.actor_min_std,
args.dense_units,
dense_act,
args.num_layers,
args.mlp_layers,
)
critic_task = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
)
@@ -168,12 +168,12 @@
args.actor_min_std,
args.dense_units,
dense_act,
args.num_layers,
args.mlp_layers,
)
critic_exploration = MLP(
input_dims=args.stochastic_size + args.recurrent_state_size,
output_dim=1,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
activation=dense_act,
flatten_dim=None,
)
@@ -5,7 +5,7 @@


@dataclass
class P2EArgs(DreamerV1Args):
class P2EDV1Args(DreamerV1Args):
# override
stochastic_size: int = Arg(default=60, help="the dimension of the stochastic state")
hidden_size: int = Arg(default=400, help="the hidden size for the transition and representation model")
@@ -25,8 +25,8 @@
from sheeprl.algos.dreamer_v1.agent import Player, WorldModel
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import cnn_forward, make_env, test
from sheeprl.algos.p2e.p2e_dv1.agent import build_models
from sheeprl.algos.p2e.p2e_dv1.args import P2EArgs
from sheeprl.algos.p2e_dv1.agent import build_models
from sheeprl.algos.p2e_dv1.args import P2EDV1Args
from sheeprl.data.buffers import SequentialReplayBuffer
from sheeprl.models.models import MLP
from sheeprl.utils.callback import CheckpointCallback
@@ -48,7 +48,7 @@ def train(
critic_task_optimizer: _FabricOptimizer,
data: TensorDictBase,
aggregator: MetricAggregator,
args: P2EArgs,
args: P2EDV1Args,
ensembles: _FabricModule,
ensemble_optimizer: _FabricOptimizer,
actor_exploration: _FabricModule,
@@ -174,7 +174,7 @@
loss = 0.0
ensemble_optimizer.zero_grad(set_to_none=True)
for ens in ensembles:
out = ens(torch.cat((priors.detach(), recurrent_states.detach(), data["actions"].detach()), -1))[:-1]
out = ens(torch.cat((posteriors.detach(), recurrent_states.detach(), data["actions"].detach()), -1))[:-1]
next_obs_embedding_dist = Independent(Normal(out, 1), 1)
loss -= next_obs_embedding_dist.log_prob(embedded_obs.detach()[1:]).mean()
loss.backward()
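The one-line fix above feeds `posteriors` (instead of `priors`) to the ensembles, aligning their inputs with the posterior states used elsewhere in this training step. For context, Plan2Explore derives its intrinsic reward from the disagreement of these ensemble predictions; a schematic, self-contained sketch, not the repository's exact code, with hypothetical shapes:

```python
import torch

# Hypothetical shapes: 10 ensemble members, horizon 15, 64 flattened latent states,
# 1024-dimensional observation embedding.
next_obs_embedding = torch.randn(10, 15, 64, 1024)

# Intrinsic reward = variance across ensemble members, averaged over the embedding
# dimension: it is high where the members disagree, i.e. where the world model is
# still uncertain about the next observation.
intrinsic_reward = next_obs_embedding.var(dim=0).mean(dim=-1, keepdim=True)  # (15, 64, 1)
```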
@@ -190,9 +190,9 @@
aggregator.update(f"Loss/ensemble_loss", loss.detach().cpu())

# Behaviour Learning Exploration
imagined_stochastic_state = posteriors.detach().reshape(1, -1, args.stochastic_size)
imagined_prior = posteriors.detach().reshape(1, -1, args.stochastic_size)
recurrent_state = recurrent_states.detach().reshape(1, -1, args.recurrent_state_size)
imagined_latent_states = torch.cat((imagined_stochastic_state, recurrent_state), -1)
imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
imagined_trajectories = torch.empty(
args.horizon, batch_size * sequence_length, args.stochastic_size + args.recurrent_state_size, device=device
)
@@ -203,21 +203,19 @@

# imagine trajectories in the latent space
for i in range(args.horizon):
actions = torch.cat(actor_exploration(imagined_latent_states.detach()), dim=-1)
actions = torch.cat(actor_exploration(imagined_latent_state.detach()), dim=-1)
imagined_actions[i] = actions
imagined_stochastic_state, recurrent_state = world_model.rssm.imagination(
imagined_stochastic_state, recurrent_state, actions
)
imagined_latent_states = torch.cat((imagined_stochastic_state, recurrent_state), -1)
imagined_trajectories[i] = imagined_latent_states
predicted_values = Independent(Normal(critic_exploration(imagined_trajectories), 1), 1).mean
imagined_prior, recurrent_state = world_model.rssm.imagination(imagined_prior, recurrent_state, actions)
imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
imagined_trajectories[i] = imagined_latent_state
predicted_values = critic_exploration(imagined_trajectories)

# Predict intrinsic reward
next_obs_embedding = torch.zeros(
len(ensembles),
args.horizon,
batch_size * sequence_length,
world_model.encoder.output_size,
embedded_obs.shape[-1],
device=device,
)
for i, ens in enumerate(ensembles):
@@ -284,22 +282,20 @@ def train(
world_optimizer.zero_grad(set_to_none=True)

# Behaviour Learning Task
imagined_stochastic_state = posteriors.detach().reshape(1, -1, args.stochastic_size)
imagined_prior = posteriors.detach().reshape(1, -1, args.stochastic_size)
recurrent_state = recurrent_states.detach().reshape(1, -1, args.recurrent_state_size)
imagined_latent_states = torch.cat((imagined_stochastic_state, recurrent_state), -1)
imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
imagined_trajectories = torch.empty(
args.horizon, batch_size * sequence_length, args.stochastic_size + args.recurrent_state_size, device=device
)
for i in range(args.horizon):
actions = torch.cat(actor_task(imagined_latent_states.detach()), dim=-1)
imagined_stochastic_state, recurrent_state = world_model.rssm.imagination(
imagined_stochastic_state, recurrent_state, actions
)
imagined_latent_states = torch.cat((imagined_stochastic_state, recurrent_state), -1)
imagined_trajectories[i] = imagined_latent_states
actions = torch.cat(actor_task(imagined_latent_state.detach()), dim=-1)
imagined_prior, recurrent_state = world_model.rssm.imagination(imagined_prior, recurrent_state, actions)
imagined_latent_state = torch.cat((imagined_prior, recurrent_state), -1)
imagined_trajectories[i] = imagined_latent_state

predicted_values = Independent(Normal(critic_task(imagined_trajectories), 1), 1).mean
predicted_rewards = Independent(Normal(world_model.reward_model(imagined_trajectories), 1), 1).mean
predicted_values = critic_task(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)
if args.use_continues and world_model.continue_model:
predicted_continues = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
else:
@@ -351,8 +347,8 @@

@register_algorithm()
def main():
parser = HfArgumentParser(P2EArgs)
args: P2EArgs = parser.parse_args_into_dataclasses()[0]
parser = HfArgumentParser(P2EDV1Args)
args: P2EDV1Args = parser.parse_args_into_dataclasses()[0]
args.num_envs = 1
torch.set_num_threads(1)

@@ -368,7 +364,7 @@ def main():
if args.checkpoint_path:
state = fabric.load(args.checkpoint_path)
state["args"]["checkpoint_path"] = args.checkpoint_path
args = P2EArgs(**state["args"])
args = P2EDV1Args(**state["args"])
args.per_rank_batch_size = state["batch_size"] // fabric.world_size
ckpt_path = pathlib.Path(args.checkpoint_path)

@@ -384,7 +380,7 @@ def main():
root_dir = (
args.root_dir
if args.root_dir is not None
else os.path.join("logs", "p2e", datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
else os.path.join("logs", "p2e_dv1", datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
)
run_name = (
args.run_name
@@ -450,7 +446,7 @@ def main():
MLP(
input_dims=int(np.sum(actions_dim) + args.recurrent_state_size + args.stochastic_size),
output_dim=world_model.encoder.output_size,
hidden_sizes=[args.dense_units] * args.num_layers,
hidden_sizes=[args.dense_units] * args.mlp_layers,
).apply(init_weights)
)
ensembles = nn.ModuleList(ens_list)
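One further simplification in this file's diff: predicted values and rewards are now read directly from `critic_task`, `critic_exploration`, and `reward_model` instead of being wrapped in `Independent(Normal(·, 1), 1).mean`. The two are numerically identical, since the mean of a unit-variance Normal is its location parameter; a quick check:

```python
import torch
from torch.distributions import Independent, Normal

x = torch.randn(15, 64, 1)  # hypothetical critic output over imagined trajectories
wrapped = Independent(Normal(x, 1), 1).mean
assert torch.allclose(wrapped, x)  # the distribution wrapper was a no-op for the mean
```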
File renamed without changes.