[Feature] Lightning integration example #2057
base: main
The first change adds `pytest_artifacts` to the ignore list:

```diff
@@ -50,6 +50,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+pytest_artifacts

 # Translations
 *.mo
```
New test file (58 lines):

```python
"""Tests for Lightning integration."""

import pytest

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger

from torchrl.lightning.ppo import PPOPendulum


def test_ppo() -> None:
    """Test PPO on InvertedDoublePendulum."""
    frame_skip = 1
    frames_per_batch = frame_skip * 5
    total_frames = 100
    model = PPOPendulum(
        frame_skip=frame_skip,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        n_mlp_layers=4,
        use_checkpoint_callback=True,
    )
    # Rollout
    rollout = model.env.rollout(3)
    print(f"Rollout of three steps: {rollout}")
    print(f"Shape of the rollout TensorDict: {rollout.batch_size}")
    print(f"Env reset: {model.env.reset()}")
    print(f"Running policy: {model.policy_module(model.env.reset())}")
    print(f"Running value: {model.value_module(model.env.reset())}")
    # Collector
    model.setup()
    collector = model.train_dataloader()
    for _, tensordict_data in enumerate(collector):
        print(f"Tensordict data:\n{tensordict_data}")
        batch_size = int(tensordict_data.batch_size[0])
        rollout_size = int(tensordict_data.batch_size[1])
        assert rollout_size == int(frames_per_batch // frame_skip)
        assert batch_size == model.num_envs
        break
    # Training
    max_steps = 2
    trainer = pl.Trainer(
        accelerator="cpu",
        max_steps=max_steps,
        val_check_interval=2,
        log_every_n_steps=1,
        logger=CSVLogger(
            save_dir="pytest_artifacts",
            name=model.__class__.__name__,
        ),
    )
    trainer.fit(model)
    # Check that training stopped because the max number of steps was reached
    assert max_steps >= trainer.global_step


if __name__ == "__main__":
    pytest.main([__file__, "-x", "-s"])
```

Reviewer comments on this file:

- On the `print` calls: "do we need these prints?"
- On the collector loop: "How long is this training loop?"
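To try the integration outside pytest, here is a minimal training sketch assembled from the same pieces. It is a sketch only: it assumes the `PPOPendulum` constructor arguments used in the test above, and keeps `max_steps` small for illustration.

```python
import lightning.pytorch as pl

from torchrl.lightning.ppo import PPOPendulum

# Small budget so the run finishes quickly; tune for real training.
model = PPOPendulum(
    frame_skip=1,
    frames_per_batch=5,
    total_frames=100,
    n_mlp_layers=4,
)
trainer = pl.Trainer(accelerator="cpu", max_steps=10, log_every_n_steps=1)
trainer.fit(model)
```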
New file (1 line):

```python
from .loops import *
```
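The star import re-exports whatever `loops` defines. A hedged alternative, assuming `RLTrainingLoop` is the public class defined in `loops.py` (it is the one imported by `BaseRL` below), would be to re-export names explicitly:

```python
# Explicit re-export instead of a star import; assumes RLTrainingLoop
# is the public entry point defined in loops.py.
from .loops import RLTrainingLoop

__all__ = ["RLTrainingLoop"]
```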
New file (205 lines):

```python
"""Creates a helper class for more complex models."""

__all__ = ["BaseRL"]


import typing as ty
from loguru import logger

import torch
from tensordict.nn import TensorDictModule  # type: ignore
from tensordict.nn.distributions import NormalParamExtractor  # type: ignore

from torchrl.envs import EnvBase, GymEnv
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives.value import GAE
from torchrl.objectives import ClipPPOLoss, CQLLoss, SoftUpdate

from .loops import RLTrainingLoop


class BaseRL(RLTrainingLoop):
    """Base class for RL models. See: https://pytorch.org/rl/tutorials/coding_ppo.html#training-loop"""

    def __init__(
        self,
        actor_nn: torch.nn.Module,
        value_nn: torch.nn.Module,
        env_name: str = "InvertedDoublePendulum-v4",
        model: str = "ppo",
        gamma: float = 0.99,
        lmbda: float = 0.95,
        entropy_eps: float = 1e-4,
        clip_epsilon: float = 0.2,
        alpha_init: float = 1,
        loss_function: str = "smooth_l1",
        flatten_state: bool = False,
        tau: float = 1e-2,
        **kwargs: ty.Any,
    ) -> None:
        """
        Args:
            actor_nn (torch.nn.Module):
                Network producing the policy's distribution parameters; its
                output is fed to `NormalParamExtractor`.
            value_nn (torch.nn.Module):
                Network used to build the value operator (PPO) or the
                Q-value operator (CQL).
            env_name (str, optional): Gym environment ID.
                Defaults to "InvertedDoublePendulum-v4".
            model (str, optional): Either "ppo" or "cql". Defaults to "ppo".
            gamma (float, optional): Discount factor. Defaults to 0.99.
            lmbda (float, optional): GAE lambda. Defaults to 0.95.
            entropy_eps (float, optional): Entropy bonus coefficient.
                Defaults to 1e-4.
            clip_epsilon (float, optional): PPO clip parameter. Defaults to 0.2.
            alpha_init (float, optional): Initial alpha for the CQL loss.
                Defaults to 1.
            loss_function (str, optional): Critic loss type.
                Defaults to "smooth_l1".
            flatten_state (bool, optional): Whether to flatten observations
                before feeding them to the networks. Defaults to False.
            tau (float, optional): Soft-update rate for the CQL target
                network. Defaults to 1e-2.
            **kwargs: Forwarded to `RLTrainingLoop` (e.g. `device`, `frame_skip`).
        """
        self.save_hyperparameters(
            ignore=[
                "base_env",
                "env",
                "loss_module",
                "policy_module",
                "value_module",
                "actor_nn",
                "value_nn",
            ]
        )
        self.gamma = gamma
        self.lmbda = lmbda
        self.entropy_eps = entropy_eps
        self.env_name = env_name
        self.device_info = kwargs.get("device", "cpu")
        self.frame_skip = kwargs.get("frame_skip", 1)
        # Environment
        base_env = self.make_env()
        # Env transformations
        env = self.transformed_env(base_env)
        # Specs
        observation_spec = base_env.observation_spec["observation"]
        action_space = base_env.action_spec
        # Sanity check
        logger.debug(f"observation_spec: {observation_spec}")
        logger.debug(f"reward_spec: {base_env.reward_spec}")
        logger.debug(f"done_spec: {base_env.done_spec}")
        logger.debug(f"action_spec: {base_env.action_spec}")
        logger.debug(f"state_spec: {base_env.state_spec}")
        # Actor
        out_features = action_space.shape[-1]
        logger.debug(f"MLP out_shape: {out_features}")
        actor_net = torch.nn.Sequential(
            torch.nn.Flatten(0) if flatten_state else torch.nn.Identity(),
            actor_nn,
            NormalParamExtractor(),
        )
        logger.debug(f"Initialized actor: {actor_net}")
        policy_module = TensorDictModule(
            actor_net,
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        )
        td = env.reset()
        policy_module(td)
        policy_module = ProbabilisticActor(
            module=policy_module,
            spec=env.action_spec,
            in_keys=["loc", "scale"],
            distribution_class=TanhNormal,
            distribution_kwargs={
                "min": 0,  # env.action_spec.space.minimum
                "max": 1,  # env.action_spec.space.maximum
            },
            return_log_prob=True,  # we'll need the log-prob for the numerator of the importance weights
        )
        logger.debug(f"Initialized policy: {policy_module}")
        # Critic and loss depend on the model
        target_net_updater = None
        if model in ["cql"]:
            advantage_module = None
            # Q-Value
            value_module = ValueOperator(
                module=value_nn,
                in_keys=["observation", "action"],
                out_keys=["state_action_value"],
            )
            td = env.reset()
            td = env.rand_action(td)
            td = env.step(td)
            td = value_module(td)
            logger.debug(f"Initialized value_module: {td}")
            # Loss CQL
            loss_module = CQLLoss(
                actor_network=policy_module,
                qvalue_network=value_module,
                action_spec=env.action_spec,
                alpha_init=alpha_init,
                loss_function=loss_function,
            )
            loss_module.make_value_estimator(gamma=gamma)
            target_net_updater = SoftUpdate(loss_module, tau=tau)
        elif model in ["ppo"]:
            # Value
            value_net = torch.nn.Sequential(
                torch.nn.Flatten(1) if flatten_state else torch.nn.Identity(),
                value_nn,
            )
            value_module = ValueOperator(
                module=value_net,
                in_keys=["observation"],
            )
            td = env.reset()
            value_module(td)
            # Loss PPO
            advantage_module = GAE(
                gamma=gamma,
                lmbda=lmbda,
                value_network=value_module,
                average_gae=True,
            )
            loss_module = ClipPPOLoss(
                actor=policy_module,
                critic=value_module,
                clip_epsilon=clip_epsilon,
                entropy_bonus=bool(entropy_eps),
                entropy_coef=entropy_eps,
                # these keys match by default but we set this for completeness
                critic_coef=1.0,
                # gamma=0.99,
                loss_critic_type=loss_function,
            )
            loss_module.set_keys(value_target=advantage_module.value_target_key)
        else:
            raise ValueError(f"Unrecognized model {model}")
        # Call superclass
        super().__init__(
            loss_module=loss_module,
            policy_module=policy_module,
            value_module=value_module,
            target_net_updater=target_net_updater,
            **kwargs,
        )
        self.advantage_module = advantage_module

    def make_env(self) -> EnvBase:
        """Instantiate the Gym environment named by `self.env_name`.

        Returns:
            EnvBase: The (untransformed) environment.
        """
        env = GymEnv(
            self.env_name,
            device=self.device_info,
            frame_skip=self.frame_skip,
        )
        return env
```

Reviewer comments on this file:

- On the module docstring: "Missing headers"
- On `__all__ = ["BaseRL"]`: "we don't use …"
- On the `loguru` import: "unless this is packed with lightning, we won't be using it." Author reply: "Yep, forgot to remove"
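For orientation, a minimal instantiation sketch. This is hedged: it assumes the observation and action dimensionalities of `InvertedDoublePendulum-v4` (11 and 1), that `RLTrainingLoop` accepts the remaining defaults, and the import path is hypothetical. The actor's output size is `2 * action_dim` because `NormalParamExtractor` splits it into `loc` and `scale`.

```python
import torch

from torchrl.lightning.base import BaseRL  # hypothetical module path

obs_dim, action_dim = 11, 1  # assumed sizes for InvertedDoublePendulum-v4

# Actor emits loc and scale, hence 2 * action_dim output features.
actor_nn = torch.nn.Sequential(
    torch.nn.Linear(obs_dim, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 2 * action_dim),
)
# Critic emits a single state value.
value_nn = torch.nn.Sequential(
    torch.nn.Linear(obs_dim, 64),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 1),
)
model = BaseRL(actor_nn=actor_nn, value_nn=value_nn, model="ppo")
```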
New file (33 lines):

```python
__all__ = ["find_device"]


import typing as ty

import torch
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.accelerators.mps import MPSAccelerator


def find_device(
    accelerator: ty.Optional[ty.Union[torch.device, str]] = None,
) -> torch.device:
    """Automatically finds the system's device for PyTorch."""
    if accelerator is None:
        accelerator = "auto"
    if isinstance(accelerator, torch.device):
        return accelerator  # pragma: no cover
    device = _choose_auto_accelerator(accelerator)
    assert device in ("cpu", "mps", "cuda")
    return torch.device(device)


def _choose_auto_accelerator(accelerator_flag: str) -> str:
    """Choose the accelerator type (str) based on availability when `accelerator='auto'`."""
    accelerator_flag = accelerator_flag.lower()
    assert accelerator_flag in ("auto", "cpu", "mps", "cuda")
    try:
        if accelerator_flag == "auto":
            if MPSAccelerator.is_available():
                return "mps"
            if CUDAAccelerator.is_available():  # pragma: no cover
                return "cuda"  # pragma: no cover
            return "cpu"  # pragma: no cover
        # An explicit flag ("cpu", "mps", "cuda") is returned as-is.
        return accelerator_flag
    except Exception:  # pragma: no cover
        return "cpu"  # pragma: no cover
```

Reviewer comments on this file:

- On the module header: "missing headers"
- On the `assert`: "we don't use …"
New file (83 lines):

```python
# pylint: disable=abstract-method
"""Create the `CollectorDataset` needed by the `pl.Trainer`."""

__all__ = ["CollectorDataset"]


import typing as ty

import torch
from tensordict.nn import TensorDictModule  # type: ignore
from tensordict import TensorDict  # type: ignore

from torch.utils.data import IterableDataset
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import EnvBase

from .accelerators import find_device


class CollectorDataset(IterableDataset):
    """Iterable dataset containing the `ReplayBuffer`, which is updated
    with new experiences during training, and the `SyncDataCollector`."""

    def __init__(
        self,
        collector: ty.Optional[SyncDataCollector] = None,
        env: ty.Optional[EnvBase] = None,
        policy_module: ty.Optional[TensorDictModule] = None,
        frames_per_batch: int = 1,
        total_frames: int = 1000,
        device: torch.device = find_device(),
        split_trajs: bool = False,
        batch_size: int = 1,
        init_random_frames: int = 1,
    ) -> None:
        # Attributes
        self.batch_size = batch_size
        self.device = device
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        # Collector
        if collector is None:
            if env is None:
                raise ValueError(
                    "Please provide an environment when not providing a collector."
                )
            self.collector = SyncDataCollector(
                env,
                policy_module,
                frames_per_batch=self.frames_per_batch,
                total_frames=self.total_frames,
                device=self.device,
                storing_device=self.device,
                split_trajs=split_trajs,
                init_random_frames=init_random_frames,
            )
        else:
            self.collector = collector
        # ReplayBuffer
        self.replay_buffer = ReplayBuffer(
            storage=LazyTensorStorage(frames_per_batch),
            sampler=SamplerWithoutReplacement(),
            batch_size=self.batch_size,
        )
        # States
        self.length: ty.Optional[int] = None

    def __iter__(self) -> ty.Iterator[TensorDict]:
        """Yield experiences from the `SyncDataCollector` and store them in the `ReplayBuffer`."""
        i = 0
        for i, tensordict_data in enumerate(self.collector):
            assert isinstance(tensordict_data, TensorDict)
            data_view: TensorDict = tensordict_data.reshape(-1)
            self.replay_buffer.extend(data_view.cpu())
            yield tensordict_data.to(self.device)
        self.length = i

    def sample(self, **kwargs: ty.Any) -> TensorDict:
        """Sample from the `ReplayBuffer`."""
        data: TensorDict = self.replay_buffer.sample(**kwargs)
        return data.to(self.device)
```

Reviewer comments on this file:

- On the `pylint` directive: "missing headers" Author reply: "I get …"
- On `__all__ = ["CollectorDataset"]`: "we don't use …"
- On the `assert` in `__iter__`: "no assert in the main codebase"
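A minimal standalone usage sketch for `CollectorDataset`. This is hedged: it assumes a Gym environment is installed and that passing `policy_module=None` makes `SyncDataCollector` fall back to a random policy, as TorchRL does when no policy is given.

```python
from torchrl.envs import GymEnv

env = GymEnv("InvertedDoublePendulum-v4")
dataset = CollectorDataset(
    env=env,
    policy_module=None,  # collector falls back to a random policy
    frames_per_batch=10,
    total_frames=100,
    batch_size=4,
)
for td in dataset:  # each item is a TensorDict of freshly collected frames
    print(td.batch_size)
    break
sample = dataset.sample()  # draw a batch from the replay buffer
```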
Reviewer comment: "I think this should belong to test_libs.py or test_trainer.py"