[Feature] Lightning integration example #2057

Open · wants to merge 9 commits into base: main

Changes from 1 commit
1 change: 1 addition & 0 deletions .gitignore
@@ -50,6 +50,7 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
pytest_artifacts

# Translations
*.mo
58 changes: 58 additions & 0 deletions test/test_lightning_integration.py
Contributor: I think this should belong to test_libs.py or test_trainer.py
@@ -0,0 +1,58 @@
"""Tests for Lightning integration."""

import pytest

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger

from torchrl.lightning.ppo import PPOPendulum


def test_ppo() -> None:
"""Test PPO on InvertedDoublePendulum."""
frame_skip = 1
frames_per_batch = frame_skip * 5
total_frames = 100
model = PPOPendulum(
frame_skip=frame_skip,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
n_mlp_layers=4,
use_checkpoint_callback=True,
)
# Rollout
rollout = model.env.rollout(3)
print(f"Rollout of three steps: {rollout}")
print(f"Shape of the rollout TensorDict: {rollout.batch_size}")
print(f"Env reset: {model.env.reset()}")
print(f"Running policy: {model.policy_module(model.env.reset())}")
print(f"Running value: {model.value_module(model.env.reset())}")
Contributor: do we need these prints?

# Collector
model.setup()
collector = model.train_dataloader()
for _, tensordict_data in enumerate(collector):
Contributor: How long is this training loop?

print(f"Tensordict data:\n{tensordict_data}")
batch_size = int(tensordict_data.batch_size[0])
rollout_size = int(tensordict_data.batch_size[1])
assert rollout_size == int(frames_per_batch // frame_skip)
assert batch_size == model.num_envs
break
# Training
max_steps = 2
trainer = pl.Trainer(
accelerator="cpu",
max_steps=max_steps,
val_check_interval=2,
log_every_n_steps=1,
logger=CSVLogger(
save_dir="pytest_artifacts",
name=model.__class__.__name__,
),
)
trainer.fit(model)
    # Check that training stopped because the max number of steps was reached
    assert trainer.global_step <= max_steps


if __name__ == "__main__":
pytest.main([__file__, "-x", "-s"])
1 change: 1 addition & 0 deletions torchrl/lightning/__init__.py
@@ -0,0 +1 @@
from .loops import *
205 changes: 205 additions & 0 deletions torchrl/lightning/_base.py
@@ -0,0 +1,205 @@
"""Creates a helper class for more complex models."""
Contributor: Missing headers


__all__ = ["BaseRL"]
Contributor: we don't use __all__ but import relevant classes in __init__.py
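
For reference, a minimal sketch of the convention the reviewer describes, applied to this PR's modules (the exact set of re-exports is up to the author):

```python
# torchrl/lightning/__init__.py — explicit re-exports instead of __all__/star imports
from ._base import BaseRL
from .loops import RLTrainingLoop
```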


import typing as ty
from loguru import logger
Contributor: unless this is packed with lightning, we won't be using it. torchrl has a logger under torchrl._utils

Author: Yep, forgot to remove loguru, which is used in the project I copied most of the code from. I will check out torchrl's logger or remove logging entirely for this.
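
A minimal sketch of that swap, assuming the `torchrl._utils` logger the reviewer mentions (a standard `logging.Logger`):

```python
# Sketch: torchrl's internal logger in place of loguru; no extra dependency needed.
from torchrl._utils import logger

logger.debug("This goes through the standard logging machinery.")
```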


import torch
from tensordict.nn import TensorDictModule # type: ignore
from tensordict.nn.distributions import NormalParamExtractor # type: ignore

from torchrl.envs import EnvBase, GymEnv
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives.value import GAE
from torchrl.objectives import ClipPPOLoss, CQLLoss, SoftUpdate

from .loops import RLTrainingLoop


class BaseRL(RLTrainingLoop):
"""Base for RL Model. See: https://pytorch.org/rl/tutorials/coding_ppo.html#training-loop"""

def __init__(
self,
actor_nn: torch.nn.Module,
value_nn: torch.nn.Module,
env_name: str = "InvertedDoublePendulum-v4",
model: str = "ppo",
gamma: float = 0.99,
lmbda: float = 0.95,
entropy_eps: float = 1e-4,
clip_epsilon: float = 0.2,
alpha_init: float = 1,
loss_function: str = "smooth_l1",
flatten_state: bool = False,
tau: float = 1e-2,
**kwargs: ty.Any,
) -> None:
"""
Args:
env (ty.Union[str, EnvBase], optional): _description_. Defaults to "InvertedDoublePendulum-v4".
num_cells (int, optional): _description_. Defaults to 256.
lr (float, optional): _description_. Defaults to 3e-4.
max_grad_norm (float, optional): _description_. Defaults to 1.0.
frame_skip (int, optional): _description_. Defaults to 1.
frames_per_batch (int, optional): _description_. Defaults to 100.
total_frames (int, optional): _description_. Defaults to 100_000.
accelerator (ty.Union[str, torch.device], optional): _description_. Defaults to "cpu".
sub_batch_size (int, optional):
Cardinality of the sub-samples gathered from the current data in the inner loop.
Defaults to `1`.
clip_epsilon (float, optional): _description_. Defaults to 0.2.
gamma (float, optional): _description_. Defaults to 0.99.
lmbda (float, optional): _description_. Defaults to 0.95.
entropy_eps (float, optional): _description_. Defaults to 1e-4.
lr_monitor (str, optional): _description_. Defaults to "loss/train".
lr_monitor_strict (bool, optional): _description_. Defaults to False.
rollout_max_steps (int, optional): _description_. Defaults to 1000.
n_mlp_layers (int, optional): _description_. Defaults to 3.
flatten (bool, optional): _description_. Defaults to False.
flatten_start_dim (int, optional): _description_. Defaults to 0.
legacy (bool, optional): _description_. Defaults to False.
automatic_optimization (bool, optional): _description_. Defaults to True.
"""
self.save_hyperparameters(
ignore=[
"base_env",
"env",
"loss_module",
"policy_module",
"value_module",
"actor_nn",
"value_nn",
]
)
self.gamma = gamma
self.lmbda = lmbda
self.entropy_eps = entropy_eps
self.env_name = env_name
self.device_info = kwargs.get("device", "cpu")
self.frame_skip = kwargs.get("frame_skip", 1)
# Environment
base_env = self.make_env()
# Env transformations
env = self.transformed_env(base_env)
# Specs
observation_spec = base_env.observation_spec["observation"]
action_space = base_env.action_spec
# Sanity check
logger.debug(f"observation_spec: {observation_spec}")
logger.debug(f"reward_spec: {base_env.reward_spec}")
logger.debug(f"done_spec: {base_env.done_spec}")
logger.debug(f"action_spec: {base_env.action_spec}")
logger.debug(f"state_spec: {base_env.state_spec}")
# Actor
out_features = action_space.shape[-1]
logger.debug(f"MLP out_shape: {out_features}")
actor_net = torch.nn.Sequential(
torch.nn.Flatten(0) if flatten_state else torch.nn.Identity(),
actor_nn,
NormalParamExtractor(),
)
logger.debug(f"Initialized actor: {actor_net}")
policy_module = TensorDictModule(
actor_net,
in_keys=["observation"],
out_keys=["loc", "scale"],
)
td = env.reset()
policy_module(td)
policy_module = ProbabilisticActor(
module=policy_module,
spec=env.action_spec,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={
"min": 0, # env.action_spec.space.minimum,
"max": 1, # env.action_spec.space.maximum,
},
return_log_prob=True, # we'll need the log-prob for the numerator of the importance weights
)
logger.debug(f"Initialized policy: {policy_module}")
# Critic and loss depend on the model
target_net_updater = None
if model in ["cql"]:
advantage_module = None
# Q-Value
value_module = ValueOperator(
module=value_nn,
in_keys=["observation", "action"],
out_keys=["state_action_value"],
)
td = env.reset()
td = env.rand_action(td)
td = env.step(td)
td = value_module(td)
logger.debug(f"Initialized value_module: {td}")
# Loss CQL
loss_module = CQLLoss(
actor_network=policy_module,
qvalue_network=value_module,
action_spec=env.action_spec,
alpha_init=alpha_init,
loss_function=loss_function,
)
loss_module.make_value_estimator(gamma=gamma)
target_net_updater = SoftUpdate(loss_module, tau=tau)
elif model in ["ppo"]:
# Value
value_net = torch.nn.Sequential(
torch.nn.Flatten(1) if flatten_state else torch.nn.Identity(),
value_nn,
)
value_module = ValueOperator(
module=value_net,
in_keys=["observation"],
)
td = env.reset()
value_module(td)
# Loss PPO
advantage_module = GAE(
gamma=gamma,
lmbda=lmbda,
value_network=value_module,
average_gae=True,
)
loss_module = ClipPPOLoss(
actor=policy_module,
critic=value_module,
clip_epsilon=clip_epsilon,
entropy_bonus=bool(entropy_eps),
entropy_coef=entropy_eps,
# these keys match by default but we set this for completeness
critic_coef=1.0,
# gamma=0.99,
loss_critic_type=loss_function,
)
loss_module.set_keys(value_target=advantage_module.value_target_key)
else:
raise ValueError(f"Unrecognized model {model}")
# Call superclass
super().__init__(
loss_module=loss_module,
policy_module=policy_module,
value_module=value_module,
target_net_updater=target_net_updater,
**kwargs,
)
self.advantage_module = advantage_module

def make_env(self) -> EnvBase:
"""Utility function to init an env.

Args:
env (ty.Union[str, EnvBase]): _description_

Returns:
EnvBase: _description_
"""
env = GymEnv(
self.env_name,
device=self.device_info,
frame_skip=self.frame_skip,
)
return env
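
Note that `transformed_env`, called in `__init__` above, is defined in `.loops` and is not part of this commit. A minimal sketch of what such a helper could look like, with a purely illustrative transform stack:

```python
from torchrl.envs import Compose, DoubleToFloat, EnvBase, StepCounter, TransformedEnv


def transformed_env(base_env: EnvBase) -> EnvBase:
    """Hypothetical stand-in: the real transforms live in `.loops` and may differ."""
    return TransformedEnv(base_env, Compose(DoubleToFloat(), StepCounter()))
```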
33 changes: 33 additions & 0 deletions torchrl/lightning/accelerators.py
@@ -0,0 +1,33 @@
__all__ = ["find_device"]
Contributor: missing headers; we don't use __all__


import typing as ty

import torch
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.accelerators.mps import MPSAccelerator


def find_device(accelerator: ty.Optional[ty.Union[torch.device, str]] = None) -> torch.device:
"""Automatically finds system's device for PyTorch."""
if accelerator is None:
accelerator = "auto"
if isinstance(accelerator, torch.device):
return accelerator # pragma: no cover
device = _choose_auto_accelerator(accelerator)
assert device in ("cpu", "mps", "cuda")
Contributor: we don't use assert in the main code base, only in tests. Raise a ValueError if the device isn't the one expected
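
A sketch of that suggestion applied to the line above:

```python
# Replacing the assert in find_device with an explicit check:
if device not in ("cpu", "mps", "cuda"):
    raise ValueError(f"Unexpected device {device!r}; expected 'cpu', 'mps' or 'cuda'.")
```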

return torch.device(device)


def _choose_auto_accelerator(accelerator_flag: str) -> str:
    """Choose the accelerator type (str) based on availability when `accelerator='auto'`."""
    accelerator_flag = accelerator_flag.lower()
    assert accelerator_flag in ("auto", "cpu", "mps", "cuda")
    if accelerator_flag == "auto":
        try:
            if MPSAccelerator.is_available():
                return "mps"
            if CUDAAccelerator.is_available():  # pragma: no cover
                return "cuda"  # pragma: no cover
        except Exception:  # pragma: no cover
            pass  # pragma: no cover
        return "cpu"
    # Explicit flags ("cpu", "mps", "cuda") are returned as-is.
    return accelerator_flag
83 changes: 83 additions & 0 deletions torchrl/lightning/collector.py
@@ -0,0 +1,83 @@
# pylint: disable=abstract-method
Contributor: missing headers; why the pylint annotation?

Author: I get `Method '__getitem__' is abstract in class 'Dataset' but is not overridden in child class 'CollectorDataset'` (Pylint W0223 abstract-method: https://pylint.readthedocs.io/en/latest/user_guide/messages/warning/abstract-method.html).
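
If the suppression stays, a class-scoped disable would be a lighter-touch alternative to the module-level pragma (a sketch, not part of this commit):

```python
from torch.utils.data import IterableDataset


class CollectorDataset(IterableDataset):  # pylint: disable=abstract-method
    """Iterable-only dataset: `__iter__` is implemented; `Dataset.__getitem__`
    is deliberately not overridden, hence the scoped disable."""
```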

"""Create the `CollectorDataset` needed by the `pl.Trainer`."""

__all__ = ["CollectorDataset"]
Contributor: we don't use __all__


import typing as ty

import torch
from tensordict.nn import TensorDictModule # type: ignore
from tensordict import TensorDict # type: ignore

from torch.utils.data import IterableDataset
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import EnvBase

from .accelerators import find_device


class CollectorDataset(IterableDataset):
"""Iterable Dataset containing the `ReplayBuffer` which will be
updated with new experiences during training, and the `SyncDataCollector`."""

def __init__(
self,
collector: ty.Optional[SyncDataCollector] = None,
env: ty.Optional[EnvBase] = None,
        policy_module: ty.Optional[TensorDictModule] = None,
frames_per_batch: int = 1,
total_frames: int = 1000,
device: torch.device = find_device(),
split_trajs: bool = False,
batch_size: int = 1,
init_random_frames: int = 1,
) -> None:
# Attributes
self.batch_size = batch_size
self.device = device
self.frames_per_batch = frames_per_batch
self.total_frames = total_frames
# Collector
if collector is None:
if env is None:
raise ValueError(
"Please provide an environment when not providing a collector."
)
self.collector = SyncDataCollector(
env,
policy_module,
frames_per_batch=self.frames_per_batch,
total_frames=self.total_frames,
device=self.device,
storing_device=self.device,
split_trajs=split_trajs,
init_random_frames=init_random_frames,
)
else:
self.collector = collector
# ReplayBuffer
self.replay_buffer = ReplayBuffer(
storage=LazyTensorStorage(frames_per_batch),
sampler=SamplerWithoutReplacement(),
batch_size=self.batch_size,
)
# States
self.length: ty.Optional[int] = None

def __iter__(self) -> ty.Iterator[TensorDict]:
"""Yield experiences from `SyncDataCollector` and store them in `ReplayBuffer`."""
i = 0
for i, tensordict_data in enumerate(self.collector):
assert isinstance(tensordict_data, TensorDict)
Contributor: no assert in the main codebase

data_view: TensorDict = tensordict_data.reshape(-1)
self.replay_buffer.extend(data_view.cpu())
yield tensordict_data.to(self.device)
self.length = i

def sample(self, **kwargs: ty.Any) -> TensorDict:
"""Sample from `ReplayBuffer`."""
data: TensorDict = self.replay_buffer.sample(**kwargs)
return data.to(self.device)
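
A usage sketch of how this dataset might be wired into a dataloader; `env` and `policy` here are placeholders for the surrounding model's environment and policy module:

```python
from torch.utils.data import DataLoader

# Hypothetical wiring: `env` and `policy` come from the surrounding LightningModule.
dataset = CollectorDataset(
    env=env,
    policy_module=policy,
    frames_per_batch=10,
    total_frames=100,
)
# batch_size=None because the collector already yields batched TensorDicts;
# the identity collate_fn keeps them as-is.
loader = DataLoader(dataset, batch_size=None, collate_fn=lambda batch: batch)
```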