Skip to content

Commit

Permalink
[Feature] TD3 compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: fb94307557f2b8604403b48211e3da6fb2139e28
Pull Request resolved: #2658
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 87a59fb commit 016d5dd
Show file tree
Hide file tree
Showing 16 changed files with 501 additions and 502 deletions.
6 changes: 1 addition & 5 deletions sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
frames_per_batch = cfg.collector.frames_per_batch
evaluation_interval = cfg.logger.log_interval
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,7 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
frames_per_batch = cfg.collector.frames_per_batch
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_rollout_steps = cfg.env.max_episode_steps
eval_iter = cfg.logger.eval_iter
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/iql/discrete_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/iql/iql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,7 @@ def update(sampled_tensordict):
collected_frames = 0

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/sac/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
device:
env_per_collector: 1
env_per_collector: 8
reset_at_each_iter: False

# replay buffer
Expand Down
6 changes: 1 addition & 5 deletions sota-implementations/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,7 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

init_random_frames = cfg.collector.init_random_frames
num_updates = int(
cfg.collector.env_per_collector
* cfg.collector.frames_per_batch
* cfg.optim.utd_ratio
)
num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
prb = cfg.replay_buffer.prb
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch
Expand Down
7 changes: 6 additions & 1 deletion sota-implementations/td3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ collector:
frames_per_batch: 1000
reset_at_each_iter: False
device:
env_per_collector: 1
env_per_collector: 8
num_workers: 1

# replay buffer
Expand Down Expand Up @@ -52,3 +52,8 @@ logger:
mode: online
eval_iter: 25000
video: False

compile:
compile: False
compile_mode:
cudagraphs: False
Loading

0 comments on commit 016d5dd

Please sign in to comment.