diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 46451bc087c..a977a7caebe 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -114,8 +114,7 @@ def main(cfg: "DictConfig"): # noqa: F821 prb = cfg.replay_buffer.prb - def update(update_actor, prb=prb): - sampled_tensordict = replay_buffer.sample() + def update(sampled_tensordict, update_actor, prb=prb): # Compute loss q_loss, *_ = loss_module.value_loss(sampled_tensordict) @@ -138,10 +137,6 @@ def update(update_actor, prb=prb): else: actor_loss = q_loss.new_zeros(()) - # Update priority - if prb: - replay_buffer.update_priority(sampled_tensordict) - return q_loss.detach(), actor_loss.detach() if cfg.compile.compile: @@ -204,9 +199,16 @@ def update(update_actor, prb=prb): update_counter += 1 update_actor = update_counter % delayed_updates == 0 + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample() with timeit("update"): torch.compiler.cudagraph_mark_step_begin() - q_loss, actor_loss = update(update_actor) + q_loss, actor_loss = update(sampled_tensordict, update_actor) + + # Update priority + if prb: + with timeit("rb - priority"): + replay_buffer.update_priority(sampled_tensordict) q_losses.append(q_loss.clone()) if update_actor: