[Feature] Lightning integration example #2057

Open · wants to merge 9 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -50,6 +50,7 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
pytest_artifacts

# Translations
*.mo
158 changes: 158 additions & 0 deletions examples/lightning/train_ppo_on_pendulum_with_lightning.py
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import typing as ty

import lightning.pytorch as pl
import torch
from lightning.pytorch.loggers import CSVLogger

from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torchrl.envs import EnvBase, GymEnv
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from torchrl.trainers import RLTrainingLoop


def make_env(env_name: str = "InvertedDoublePendulum-v4", **kwargs: ty.Any) -> EnvBase:
"""Utility function to init an env."""
env = GymEnv(env_name, **kwargs)
return env


def create_nets(base_env: EnvBase) -> ty.Tuple[torch.nn.Module, torch.nn.Module]:
    out_features = base_env.action_spec.shape[-1]
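    # The actor net outputs 2 * out_features values; NormalParamExtractor
    # later splits them into the loc and scale of a TanhNormal distribution.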
    actor_nn = MLP(
        out_features=2 * out_features,
        depth=3,
        num_cells=256,
        dropout=True,
    )
    value_nn = MLP(
        out_features=1,
        depth=3,
        num_cells=256,
        dropout=True,
    )
    return actor_nn, value_nn


def make_actor(env: EnvBase, actor_nn: torch.nn.Module) -> ProbabilisticActor:
    # Actor
    actor_net = torch.nn.Sequential(
        actor_nn,
        NormalParamExtractor(),
    )
    policy_module = TensorDictModule(
        actor_net,
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ).double()
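    # Run a dummy forward pass on a reset tensordict so that the MLP's
    # lazy layers are initialized before the module is wrapped.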
    td = env.reset()
    policy_module(td)
    policy_module = ProbabilisticActor(
        module=policy_module,
        spec=env.action_spec,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "min": 0,  # env.action_spec.space.minimum
            "max": 1,  # env.action_spec.space.maximum
        },
        return_log_prob=True,  # we'll need the log-prob for the numerator of the importance weights
    )
    return policy_module


def make_critic(
    env: EnvBase,
    value_nn: torch.nn.Module,
    gamma: float = 0.99,
    lmbda: float = 0.95,
) -> ty.Tuple[ValueOperator, GAE]:
    value_module = ValueOperator(
        module=value_nn,
        in_keys=["observation"],
    ).double()
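    # As above, a dummy forward pass initializes the value net's lazy layers.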
    td = env.reset()
    value_module(td)
    advantage_module = GAE(
        gamma=gamma,
        lmbda=lmbda,
        value_network=value_module,
        average_gae=True,
    )
    return value_module, advantage_module


def make_loss_module(
    policy_module: ProbabilisticActor,
    value_module: ValueOperator,
    advantage_module: GAE,
    entropy_eps: float = 1e-4,
    clip_epsilon: float = 0.2,
    loss_function: str = "smooth_l1",
) -> ClipPPOLoss:
    loss_module = ClipPPOLoss(
        actor=policy_module,
        critic=value_module,
        clip_epsilon=clip_epsilon,
        entropy_bonus=bool(entropy_eps),
        entropy_coef=entropy_eps,
        critic_coef=1.0,
        loss_critic_type=loss_function,
    )
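    # Point the loss at the value target produced by the GAE module.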
    loss_module.set_keys(value_target=advantage_module.value_target_key)
    return loss_module


class PPOPendulum(RLTrainingLoop):
    def make_env(self) -> EnvBase:
        """Required override: takes no inputs and returns the environment
        to train on."""
        return make_env()


def create_model() -> pl.LightningModule:
    env = make_env()
    actor_nn, critic_nn = create_nets(env)
    policy_module = make_actor(env, actor_nn)
    value_module, advantage_module = make_critic(env, critic_nn)
    loss_module = make_loss_module(policy_module, value_module, advantage_module)
    frame_skip = 1
    frames_per_batch = frame_skip * 5
    total_frames = 100
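    # Deliberately tiny budget so the example finishes in seconds;
    # increase total_frames for a real training run.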
    model = PPOPendulum(
        loss_module=loss_module,
        policy_module=policy_module,
        advantage_module=advantage_module,
        frame_skip=frame_skip,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        use_checkpoint_callback=True,
    )
    return model


def main() -> None:
    model = create_model()
    trainer = pl.Trainer(
        accelerator="cpu",
        max_steps=4,
        val_check_interval=2,
        log_every_n_steps=1,
        logger=CSVLogger(
            save_dir="pytest_artifacts",
            name=model.__class__.__name__,
        ),
    )
    trainer.fit(model)


if __name__ == "__main__":
    main()
67 changes: 67 additions & 0 deletions test/test_lightning_integration.py
Contributor: I think this should belong to test_libs.py or test_trainer.py
@@ -0,0 +1,67 @@
"""Tests for Lightning integration."""

import lightning.pytorch as pl
import pytest
from lightning.pytorch.loggers import CSVLogger

from torchrl.trainers.ppo import PPOPendulum


def test_example_ppo_pl() -> None:
"""Tray to run the example from here,
to make sure it is tested."""
import os, sys

sys.path.append(os.path.join("examples", "lightning"))
from train_ppo_on_pendulum_with_lightning import main

main()


def test_ppo() -> None:
"""Test PPO on InvertedDoublePendulum."""
frame_skip = 1
frames_per_batch = frame_skip * 5
total_frames = 100
model = PPOPendulum(
frame_skip=frame_skip,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
n_mlp_layers=4,
use_checkpoint_callback=True,
)
    # Rollout
    rollout = model.env.rollout(3)
    print(f"Rollout of three steps: {rollout}")
    print(f"Shape of the rollout TensorDict: {rollout.batch_size}")
    print(f"Env reset: {model.env.reset()}")
    print(f"Running policy: {model.policy_module(model.env.reset())}")
    # Collector
    model.setup()
    collector = model.train_dataloader()
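    # Each batch from the collector is a TensorDict whose leading dims are
    # (num_envs, frames_per_batch // frame_skip), as the asserts below check.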
    for _, tensordict_data in enumerate(collector):
Contributor: How long is this training loop?
print(f"Tensordict data:\n{tensordict_data}")
batch_size = int(tensordict_data.batch_size[0])
rollout_size = int(tensordict_data.batch_size[1])
assert rollout_size == int(frames_per_batch // frame_skip)
assert batch_size == model.num_envs
break
    # Training
    max_steps = 2
    trainer = pl.Trainer(
        accelerator="cpu",
        max_steps=max_steps,
        val_check_interval=2,
        log_every_n_steps=1,
        logger=CSVLogger(
            save_dir="pytest_artifacts",
            name=model.__class__.__name__,
        ),
    )
    trainer.fit(model)
    # Check that training stopped because the max number of steps was reached
    assert max_steps >= trainer.global_step


if __name__ == "__main__":
    pytest.main([__file__, "-x", "-s"])
2 changes: 2 additions & 0 deletions torchrl/trainers/__init__.py
@@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .accelerators import find_device
from .loops import RLTrainingLoop
from .trainers import (
    BatchSubSampler,
    ClearCudaCache,