-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Lightning integration example #2057
Open
svnv-svsv-jm
wants to merge
9
commits into
pytorch:main
Choose a base branch
from
svnv-svsv-jm:feature/lightning-integration
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
fdeeb9a
passed tests
svnv-svsv-jm 36211ff
Use pipe symbol instead of Union
svnv-svsv-jm 7fdddca
Use | instead of Union
svnv-svsv-jm 3db2aef
Use | instead of Union
svnv-svsv-jm 48f040f
Use | instead of Union
svnv-svsv-jm 6f59d8d
handle pr comments
svnv-svsv-jm 7775755
add examples + test for it
svnv-svsv-jm 394571a
allow "from torchrl.trainers import RLTrainingLoop"
svnv-svsv-jm 3b767ce
remove "lightning" module and move all to "trainers"
svnv-svsv-jm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,7 @@ coverage.xml | |
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
pytest_artifacts | ||
|
||
# Translations | ||
*.mo | ||
|
158 changes: 158 additions & 0 deletions
158
examples/lightning/train_ppo_on_pendulum_with_lightning.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import typing as ty | ||
|
||
import lightning.pytorch as pl | ||
import torch | ||
from lightning.pytorch.loggers import CSVLogger | ||
|
||
from tensordict.nn import TensorDictModule | ||
from tensordict.nn.distributions import NormalParamExtractor | ||
|
||
from torchrl.envs import EnvBase, GymEnv | ||
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator | ||
from torchrl.objectives import ClipPPOLoss | ||
from torchrl.objectives.value import GAE | ||
from torchrl.trainers import RLTrainingLoop | ||
|
||
|
||
def make_env(env_name: str = "InvertedDoublePendulum-v4", **kwargs: ty.Any) -> EnvBase: | ||
"""Utility function to init an env.""" | ||
env = GymEnv(env_name, **kwargs) | ||
return env | ||
|
||
|
||
def create_nets(base_env: EnvBase) -> ty.Tuple[torch.nn.Module, torch.nn.Module]: | ||
out_features = base_env.action_spec.shape[-1] | ||
actor_nn = MLP( | ||
out_features=2 * out_features, | ||
depth=3, | ||
num_cells=256, | ||
dropout=True, | ||
) | ||
value_nn = MLP( | ||
out_features=1, | ||
depth=3, | ||
num_cells=256, | ||
dropout=True, | ||
) | ||
return actor_nn, value_nn | ||
|
||
|
||
def make_actor(env: EnvBase, actor_nn: torch.nn.Module) -> ProbabilisticActor: | ||
# Actor | ||
actor_net = torch.nn.Sequential( | ||
actor_nn, | ||
NormalParamExtractor(), | ||
) | ||
policy_module = TensorDictModule( | ||
actor_net, | ||
in_keys=["observation"], | ||
out_keys=["loc", "scale"], | ||
).double() | ||
td = env.reset() | ||
policy_module(td) | ||
policy_module = ProbabilisticActor( | ||
module=policy_module, | ||
spec=env.action_spec, | ||
in_keys=["loc", "scale"], | ||
distribution_class=TanhNormal, | ||
distribution_kwargs={ | ||
"min": 0, # env.action_spec.space.minimum, | ||
"max": 1, # env.action_spec.space.maximum, | ||
}, | ||
return_log_prob=True, # we'll need the log-prob for the numerator of the importance weights | ||
) | ||
return policy_module | ||
|
||
|
||
def make_critic( | ||
env: EnvBase, | ||
value_nn: torch.nn.Module, | ||
gamma: float = 0.99, | ||
lmbda: float = 0.95, | ||
) -> ty.Tuple[ValueOperator, GAE]: | ||
value_module = ValueOperator( | ||
module=value_nn, | ||
in_keys=["observation"], | ||
).double() | ||
td = env.reset() | ||
value_module(td) | ||
advantage_module = GAE( | ||
gamma=gamma, | ||
lmbda=lmbda, | ||
value_network=value_module, | ||
average_gae=True, | ||
) | ||
return value_module, advantage_module | ||
|
||
|
||
def make_loss_module( | ||
policy_module, | ||
value_module, | ||
advantage_module, | ||
entropy_eps: float = 1e-4, | ||
clip_epsilon: float = 0.2, | ||
loss_function: str = "smooth_l1", | ||
) -> ClipPPOLoss: | ||
loss_module = ClipPPOLoss( | ||
actor=policy_module, | ||
critic=value_module, | ||
clip_epsilon=clip_epsilon, | ||
entropy_bonus=bool(entropy_eps), | ||
entropy_coef=entropy_eps, | ||
critic_coef=1.0, | ||
loss_critic_type=loss_function, | ||
) | ||
loss_module.set_keys(value_target=advantage_module.value_target_key) | ||
return loss_module | ||
|
||
|
||
class PPOPendulum(RLTrainingLoop): | ||
def make_env(self) -> EnvBase: | ||
"""You have to implement this method, which has to take no inputs and return | ||
your environment.""" | ||
return make_env() | ||
|
||
|
||
def create_model() -> pl.LightningModule: | ||
env = make_env() | ||
actorn_nn, critic_nn = create_nets(env) | ||
policy_module = make_actor(env, actorn_nn) | ||
value_module, advantage_module = make_critic(env, critic_nn) | ||
loss_module = make_loss_module(policy_module, value_module, advantage_module) | ||
frame_skip = 1 | ||
frames_per_batch = frame_skip * 5 | ||
total_frames = 100 | ||
model = PPOPendulum( | ||
loss_module=loss_module, | ||
policy_module=policy_module, | ||
advantage_module=advantage_module, | ||
frame_skip=frame_skip, | ||
frames_per_batch=frames_per_batch, | ||
total_frames=total_frames, | ||
use_checkpoint_callback=True, | ||
) | ||
return model | ||
|
||
|
||
def main() -> None: | ||
model = create_model() | ||
trainer = pl.Trainer( | ||
accelerator="cpu", | ||
max_steps=4, | ||
val_check_interval=2, | ||
log_every_n_steps=1, | ||
logger=CSVLogger( | ||
save_dir="pytest_artifacts", | ||
name=model.__class__.__name__, | ||
), | ||
) | ||
trainer.fit(model) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""Tests for Lightning integration.""" | ||
|
||
import lightning.pytorch as pl | ||
import pytest | ||
from lightning.pytorch.loggers import CSVLogger | ||
|
||
from torchrl.trainers.ppo import PPOPendulum | ||
|
||
|
||
def test_example_ppo_pl() -> None: | ||
"""Tray to run the example from here, | ||
to make sure it is tested.""" | ||
import os, sys | ||
|
||
sys.path.append(os.path.join("examples", "lightning")) | ||
from train_ppo_on_pendulum_with_lightning import main | ||
|
||
main() | ||
|
||
|
||
def test_ppo() -> None: | ||
"""Test PPO on InvertedDoublePendulum.""" | ||
frame_skip = 1 | ||
frames_per_batch = frame_skip * 5 | ||
total_frames = 100 | ||
model = PPOPendulum( | ||
frame_skip=frame_skip, | ||
frames_per_batch=frames_per_batch, | ||
total_frames=total_frames, | ||
n_mlp_layers=4, | ||
use_checkpoint_callback=True, | ||
) | ||
# Rollout | ||
rollout = model.env.rollout(3) | ||
print(f"Rollout of three steps: {rollout}") | ||
print(f"Shape of the rollout TensorDict: {rollout.batch_size}") | ||
print(f"Env reset: {model.env.reset()}") | ||
print(f"Running policy: {model.policy_module(model.env.reset())}") | ||
# Collector | ||
model.setup() | ||
collector = model.train_dataloader() | ||
for _, tensordict_data in enumerate(collector): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How long is this training loop? |
||
print(f"Tensordict data:\n{tensordict_data}") | ||
batch_size = int(tensordict_data.batch_size[0]) | ||
rollout_size = int(tensordict_data.batch_size[1]) | ||
assert rollout_size == int(frames_per_batch // frame_skip) | ||
assert batch_size == model.num_envs | ||
break | ||
# Training | ||
max_steps = 2 | ||
trainer = pl.Trainer( | ||
accelerator="cpu", | ||
max_steps=max_steps, | ||
val_check_interval=2, | ||
log_every_n_steps=1, | ||
logger=CSVLogger( | ||
save_dir="pytest_artifacts", | ||
name=model.__class__.__name__, | ||
), | ||
) | ||
trainer.fit(model) | ||
# Test we stopped quite because the max number of steps was reached | ||
assert max_steps >= trainer.global_step | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__, "-x", "-s"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should belong to test_libs.py or test_trainer.py