From 70a7cc0433a08a13de7c9fd70c151c198fad72d1 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 15 Dec 2024 19:11:01 -0800
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 sota-implementations/td3/td3.py   | 11 ++++++++---
 sota-implementations/td3/utils.py | 17 ++++++++---------
 2 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py
index 36183655fe0..f8f24bec35a 100644
--- a/sota-implementations/td3/td3.py
+++ b/sota-implementations/td3/td3.py
@@ -49,7 +49,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
             device = torch.device("cuda:0")
         else:
             device = torch.device("cpu")
-    device = torch.device(device)
+    else:
+        device = torch.device(device)
 
     # Create logger
     exp_name = generate_exp_name("TD3", cfg.logger.exp_name)
@@ -72,7 +73,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     np.random.seed(cfg.env.seed)
 
     # Create environments
-    train_env, eval_env = make_environment(cfg, logger=logger)
+    train_env, eval_env = make_environment(cfg, logger=logger, device=device)
 
     # Create agent
     model, exploration_policy = make_td3_agent(cfg, train_env, eval_env, device)
@@ -91,7 +92,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Create off-policy collector
     collector = make_collector(
-        cfg, train_env, exploration_policy, compile_mode=compile_mode
+        cfg,
+        train_env,
+        exploration_policy,
+        compile_mode=compile_mode,
+        device=device,
     )
 
     # Create replay buffer
diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py
index 13a234e31be..04d11a913b3 100644
--- a/sota-implementations/td3/utils.py
+++ b/sota-implementations/td3/utils.py
@@ -75,13 +75,14 @@ def apply_env_transforms(env, max_episode_steps):
     return transformed_env
 
 
-def make_environment(cfg, logger=None):
+def make_environment(cfg, logger, device):
     """Make environments for training and evaluation."""
     partial = functools.partial(env_maker, cfg=cfg)
     parallel_env = ParallelEnv(
         cfg.collector.env_per_collector,
         EnvCreator(partial),
         serial_for_single=True,
+        device=device,
     )
     parallel_env.set_seed(cfg.env.seed)
 
@@ -98,6 +99,7 @@ def make_environment(cfg, logger=None):
             cfg.collector.env_per_collector,
             EnvCreator(partial),
             serial_for_single=True,
+            device=device,
         ),
         trsf_clone,
     )
@@ -109,14 +111,11 @@
 # ---------------------------
 
 
-def make_collector(cfg, train_env, actor_model_explore, compile_mode):
+def make_collector(cfg, train_env, actor_model_explore, compile_mode, device):
     """Make collector."""
-    device = cfg.collector.device
-    if device in ("", None):
-        if torch.cuda.is_available():
-            device = torch.device("cuda:0")
-        else:
-            device = torch.device("cpu")
+    collector_device = cfg.collector.device
+    if collector_device in ("", None):
+        collector_device = device
     collector = SyncDataCollector(
         train_env,
         actor_model_explore,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
         reset_at_each_iter=cfg.collector.reset_at_each_iter,
-        device=device,
+        device=collector_device,
         compile_policy={"mode": compile_mode} if compile_mode else False,
         cudagraph_policy=cfg.compile.cudagraphs,
     )
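
Note on the resulting behavior: td3.py now resolves the training device once from cfg.network.device (falling back to CUDA when available, otherwise CPU) and passes it to make_environment and make_collector, while cfg.collector.device still overrides the collector device when it is set. The snippet below is a standalone sketch of that resolution rule for reference only; the helper name resolve_device and the literal config values are illustrative stand-ins, not code from the patch.

import torch


def resolve_device(requested):
    """Empty or None means auto-detect; anything else is passed to torch.device as-is."""
    if requested in ("", None):
        return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    return torch.device(requested)


if __name__ == "__main__":
    # Hypothetical stand-ins for cfg.network.device and cfg.collector.device.
    network_device_cfg = ""      # unset -> auto-detect
    collector_device_cfg = None  # unset -> reuse the trainer device

    device = resolve_device(network_device_cfg)
    collector_device = (
        collector_device_cfg if collector_device_cfg not in ("", None) else device
    )
    print(f"trainer device: {device}, collector device: {collector_device}")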