diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index b61556874c3..03bdf6a493f 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -159,11 +159,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb frames_per_batch = cfg.collector.frames_per_batch evaluation_interval = cfg.logger.log_interval diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index e6a710f1f4b..35238c5c6ab 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -140,11 +140,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index 5f6d762d644..07de3e26175 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -179,11 +179,7 @@ def update(sampled_tensordict: TensorDict, update_actor: bool): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index 9d06dc2ff75..6e2a749c3f1 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -145,11 +145,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb frames_per_batch = cfg.collector.frames_per_batch eval_iter = cfg.logger.eval_iter diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index a5dad120a60..b7910c4e578 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -144,11 +144,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 805cfc6e23d..e56661acf0c 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -148,11 +148,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 28b35099286..7ec2a30dfd9 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -145,11 +145,7 @@ def update(sampled_tensordict): collected_frames = 0 init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml index a1ecb90aeba..d6cb09382aa 100644 --- a/sota-implementations/sac/config.yaml +++ b/sota-implementations/sac/config.yaml @@ -13,7 +13,7 @@ collector: frames_per_batch: 1000 init_env_steps: 1000 device: - env_per_collector: 1 + env_per_collector: 8 reset_at_each_iter: False # replay buffer diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index b97fed3091c..a1ec631fe39 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -143,11 +143,7 @@ def update(sampled_tensordict): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) prb = cfg.replay_buffer.prb eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index 8207a41a9f5..31fa52b72f3 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -14,7 +14,7 @@ collector: frames_per_batch: 1000 reset_at_each_iter: False device: - env_per_collector: 1 + env_per_collector: 8 num_workers: 1 # replay buffer diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index a977a7caebe..bcbe6b879da 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -154,11 +154,7 @@ def update(sampled_tensordict, update_actor, prb=prb): pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames - num_updates = int( - cfg.collector.env_per_collector - * cfg.collector.frames_per_batch - * cfg.optim.utd_ratio - ) + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio) delayed_updates = cfg.optim.policy_update_delay eval_rollout_steps = cfg.env.max_episode_steps eval_iter = cfg.logger.eval_iter diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index cdf52927158..9562da65450 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -96,7 +96,7 @@ def make_environment(cfg, logger, device): ) eval_env = TransformedEnv( ParallelEnv( - cfg.collector.env_per_collector, + 1, EnvCreator(partial), serial_for_single=True, device=device,