
Commit

Update
[ghstack-poisoned]
vmoens committed Dec 16, 2024
1 parent 3cff702 commit c17fea0
Showing 3 changed files with 23 additions and 12 deletions.
22 changes: 11 additions & 11 deletions sota-implementations/td3/td3.py
@@ -106,12 +106,17 @@ def main(cfg: "DictConfig"):  # noqa: F821
buffer_size=cfg.replay_buffer.size,
scratch_dir=cfg.replay_buffer.scratch_dir,
device=device,
+compile=bool(compile_mode),
)

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)

-def update(sampled_tensordict, update_actor):
+prb = cfg.replay_buffer.prb
+
+def update(update_actor, prb=prb):
+    sampled_tensordict = replay_buffer.sample()
+
# Compute loss
q_loss, *_ = loss_module.value_loss(sampled_tensordict)

@@ -133,6 +138,10 @@ def update(sampled_tensordict, update_actor):
else:
actor_loss = q_loss.new_zeros(())

+# Update priority
+if prb:
+    replay_buffer.update_priority(sampled_tensordict)
+
return q_loss.detach(), actor_loss.detach()

if cfg.compile.compile:
@@ -156,7 +165,6 @@ def update(sampled_tensordict, update_actor):
* cfg.optim.utd_ratio
)
delayed_updates = cfg.optim.policy_update_delay
-prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
@@ -196,22 +204,14 @@ def update(sampled_tensordict, update_actor):
update_counter += 1
update_actor = update_counter % delayed_updates == 0

-with timeit("rb - sample"):
-    # Sample from replay buffer
-    sampled_tensordict = replay_buffer.sample()
-
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
-q_loss, actor_loss = update(sampled_tensordict, update_actor)
+q_loss, actor_loss = update(update_actor)

q_losses.append(q_loss.clone())
if update_actor:
actor_losses.append(actor_loss.clone())

-# Update priority
-if prb:
-    replay_buffer.update_priority(sampled_tensordict)
-
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
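
Note: the refactor above moves sampling and the priority update inside update, so a single callable covers sample, loss, optimizer step and priority update before the script's compile/cudagraph wrapping is applied. The snippet below is a minimal, self-contained sketch of that pattern, not the script itself: value_net, the toy data, the td_error computation and the plain torch.compile call are illustrative stand-ins, and whether the body traces without graph breaks depends on the PyTorch/TorchRL versions in use.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

# Toy stand-ins for the script's loss module and optimizers.
value_net = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-4)

replay_buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    storage=LazyTensorStorage(10_000, compilable=True),
    batch_size=64,
    compilable=True,
)
replay_buffer.extend(TensorDict({"observation": torch.randn(256, 4)}, batch_size=[256]))

prb = True

def update(update_actor, prb=prb):
    # Sampling now happens inside the function that gets compiled.
    sampled = replay_buffer.sample()
    pred = value_net(sampled["observation"]).squeeze(-1)
    q_loss = pred.pow(2).mean()
    optimizer.zero_grad()
    q_loss.backward()
    optimizer.step()
    # Actor update elided; `update_actor` only mirrors the script's signature.
    # The priority update also lives inside the function now, using the
    # buffer's default priority key ("td_error"), as in the diff above.
    if prb:
        sampled["td_error"] = pred.detach().abs()
        replay_buffer.update_priority(sampled)
    return q_loss.detach()

update = torch.compile(update)

for i in range(5):
    update(update_actor=(i % 2 == 0))
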
5 changes: 4 additions & 1 deletion sota-implementations/td3/utils.py
@@ -138,14 +138,15 @@ def make_replay_buffer(
scratch_dir: str | None = None,
device: torch.device = "cpu",
prefetch: int = 3,
+compile: bool = False,
):
with (
tempfile.TemporaryDirectory()
if scratch_dir is None
else nullcontext(scratch_dir)
) as scratch_dir:
storage_cls = (
-functools.partial(LazyTensorStorage, device=device)
+functools.partial(LazyTensorStorage, device=device, compilable=compile)
if not scratch_dir
else functools.partial(
LazyMemmapStorage, device="cpu", scratch_dir=scratch_dir
@@ -160,13 +161,15 @@ def make_replay_buffer(
prefetch=prefetch,
storage=storage_cls(buffer_size),
batch_size=batch_size,
+compilable=compile,
)
else:
replay_buffer = TensorDictReplayBuffer(
pin_memory=False,
prefetch=prefetch,
storage=storage_cls(buffer_size),
batch_size=batch_size,
+compilable=compile,
)
if scratch_dir:
replay_buffer.append_transform(lambda td: td.to(device))
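
For context, here is a trimmed-down, hypothetical stand-in for the updated helper showing where the new compile flag lands: the in-memory LazyTensorStorage path receives compilable, the memmap path does not, and the buffer itself is built with compilable as in the diff. build_buffer, its defaults and the toy data are illustrative; prefetching and the prioritized branch are omitted.

import functools

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer

def build_buffer(buffer_size, batch_size, scratch_dir=None, device="cpu", compile=False):
    if scratch_dir is None:
        # In-memory path: storage lives on `device` and is marked compilable.
        storage_cls = functools.partial(LazyTensorStorage, device=device, compilable=compile)
    else:
        # Memmap path: disk-backed, kept on CPU; no compilable flag here.
        storage_cls = functools.partial(LazyMemmapStorage, device="cpu", scratch_dir=scratch_dir)
    rb = TensorDictReplayBuffer(
        storage=storage_cls(buffer_size),
        batch_size=batch_size,
        compilable=compile,
    )
    if scratch_dir is not None:
        # Move sampled batches to the training device, as the helper does.
        rb.append_transform(lambda td: td.to(device))
    return rb

rb = build_buffer(buffer_size=1_000, batch_size=32, compile=True)
rb.extend(TensorDict({"obs": torch.zeros(100, 3)}, batch_size=[100]))
print(rb.sample()["obs"].shape)  # torch.Size([32, 3])
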
8 changes: 8 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -1094,6 +1094,9 @@ class TensorDictReplayBuffer(ReplayBuffer):
.. warning:: As of now, the generator has no effect on the transforms.
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
Defaults to ``False``.
+compilable (bool, optional): whether the writer is compilable.
+    If ``True``, the writer cannot be shared between multiple processes.
+    Defaults to ``False``.
Examples:
>>> import torch
@@ -1437,6 +1440,9 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
.. warning:: As of now, the generator has no effect on the transforms.
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
Defaults to ``False``.
+compilable (bool, optional): whether the writer is compilable.
+    If ``True``, the writer cannot be shared between multiple processes.
+    Defaults to ``False``.
Examples:
>>> import torch
@@ -1512,6 +1518,7 @@ def __init__(
dim_extend: int | None = None,
generator: torch.Generator | None = None,
shared: bool = False,
+compilable: bool = False,
) -> None:
if storage is None:
storage = ListStorage(max_size=1_000)
@@ -1530,6 +1537,7 @@
dim_extend=dim_extend,
generator=generator,
shared=shared,
+compilable=compilable,
)


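
A possible usage sketch for the compilable flag documented above: the buffer and its storage are marked compilable so that buffer operations can sit inside a torch.compile'd function. Names and sizes are illustrative, graph breaks may still occur depending on the PyTorch/TorchRL versions, and, per the new docstring, a compilable writer cannot be shared between processes, so compilable=True should not be combined with shared=True.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1_000, compilable=True),
    batch_size=16,
    compilable=True,  # new argument documented in this commit
)

@torch.compile
def add_and_sample(td):
    # Both the write and the read happen inside the compiled region.
    rb.extend(td)
    return rb.sample()

batch = add_and_sample(TensorDict({"reward": torch.randn(32, 1)}, batch_size=[32]))
print(batch["reward"].shape)  # torch.Size([16, 1])
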
