diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index f8f24bec35a..46451bc087c 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -106,12 +106,17 @@ def main(cfg: "DictConfig"): # noqa: F821 buffer_size=cfg.replay_buffer.size, scratch_dir=cfg.replay_buffer.scratch_dir, device=device, + compile=bool(compile_mode), ) # Create optimizers optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module) - def update(sampled_tensordict, update_actor): + prb = cfg.replay_buffer.prb + + def update(update_actor, prb=prb): + sampled_tensordict = replay_buffer.sample() + # Compute loss q_loss, *_ = loss_module.value_loss(sampled_tensordict) @@ -133,6 +138,10 @@ def update(sampled_tensordict, update_actor): else: actor_loss = q_loss.new_zeros(()) + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + return q_loss.detach(), actor_loss.detach() if cfg.compile.compile: @@ -156,7 +165,6 @@ def update(sampled_tensordict, update_actor): * cfg.optim.utd_ratio ) delayed_updates = cfg.optim.policy_update_delay - prb = cfg.replay_buffer.prb eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch @@ -196,22 +204,14 @@ def update(sampled_tensordict, update_actor): update_counter += 1 update_actor = update_counter % delayed_updates == 0 - with timeit("rb - sample"): - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - with timeit("update"): torch.compiler.cudagraph_mark_step_begin() - q_loss, actor_loss = update(sampled_tensordict, update_actor) + q_loss, actor_loss = update(update_actor) q_losses.append(q_loss.clone()) if update_actor: actor_losses.append(actor_loss.clone()) - # Update priority - if prb: - replay_buffer.update_priority(sampled_tensordict) - episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 04d11a913b3..3f797064f35 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -138,6 +138,7 @@ def make_replay_buffer( scratch_dir: str | None = None, device: torch.device = "cpu", prefetch: int = 3, + compile: bool = False, ): with ( tempfile.TemporaryDirectory() @@ -145,7 +146,7 @@ def make_replay_buffer( else nullcontext(scratch_dir) ) as scratch_dir: storage_cls = ( - functools.partial(LazyTensorStorage, device=device) + functools.partial(LazyTensorStorage, device=device, compilable=compile) if not scratch_dir else functools.partial( LazyMemmapStorage, device="cpu", scratch_dir=scratch_dir @@ -160,6 +161,7 @@ def make_replay_buffer( prefetch=prefetch, storage=storage_cls(buffer_size), batch_size=batch_size, + compilable=compile, ) else: replay_buffer = TensorDictReplayBuffer( @@ -167,6 +169,7 @@ def make_replay_buffer( prefetch=prefetch, storage=storage_cls(buffer_size), batch_size=batch_size, + compilable=compile, ) if scratch_dir: replay_buffer.append_transform(lambda td: td.to(device)) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index fbb76b5a681..4ddf059d5b4 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -1094,6 +1094,9 @@ class TensorDictReplayBuffer(ReplayBuffer): .. warning:: As of now, the generator has no effect on the transforms. shared (bool, optional): whether the buffer will be shared using multiprocessing or not. Defaults to ``False``. + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. Examples: >>> import torch @@ -1437,6 +1440,9 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): .. warning:: As of now, the generator has no effect on the transforms. shared (bool, optional): whether the buffer will be shared using multiprocessing or not. Defaults to ``False``. + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. Examples: >>> import torch @@ -1512,6 +1518,7 @@ def __init__( dim_extend: int | None = None, generator: torch.Generator | None = None, shared: bool = False, + compilable: bool = False, ) -> None: if storage is None: storage = ListStorage(max_size=1_000) @@ -1530,6 +1537,7 @@ def __init__( dim_extend=dim_extend, generator=generator, shared=shared, + compilable=compilable, )