diff --git a/xuance/torch/agents/base/agent.py b/xuance/torch/agents/base/agent.py
index f39970c3..e33ad81d 100644
--- a/xuance/torch/agents/base/agent.py
+++ b/xuance/torch/agents/base/agent.py
@@ -38,9 +38,19 @@ def __init__(self,
             self.rank = int(os.environ['LOCAL_RANK'])
             master_port = config.master_port if hasattr(config, "master_port") else None
             init_distributed_mode(master_port=master_port)
+            if self.config.buffer_size < self.world_size:
+                raise AttributeError("The config.buffer_size is less than the number of GPUs.")
+            else:
+                self.buffer_size = self.config.buffer_size // self.world_size
+            if self.config.batch_size < self.world_size:
+                raise AttributeError("The config.batch_size is less than the number of GPUs.")
+            else:
+                self.batch_size = self.config.batch_size // self.world_size
         else:
-            self.rank = 0
             self.world_size = 1
+            self.rank = 0
+            self.buffer_size = self.config.buffer_size
+            self.batch_size = self.config.batch_size

         self.gamma = config.gamma
         self.start_training = config.start_training if hasattr(config, "start_training") else 1
diff --git a/xuance/torch/agents/core/off_policy.py b/xuance/torch/agents/core/off_policy.py
index 70aeb6c3..dceb20fe 100644
--- a/xuance/torch/agents/core/off_policy.py
+++ b/xuance/torch/agents/core/off_policy.py
@@ -41,8 +41,8 @@ def _build_memory(self, auxiliary_info_shape=None):
                             action_space=self.action_space,
                             auxiliary_shape=auxiliary_info_shape,
                             n_envs=self.n_envs,
-                            buffer_size=self.config.buffer_size,
-                            batch_size=self.config.batch_size)
+                            buffer_size=self.buffer_size,
+                            batch_size=self.batch_size)
         return Buffer(**input_buffer)

     def _build_policy(self) -> Module:
@@ -137,14 +137,13 @@ def train(self, train_steps):
                            self.ret_rms.update(self.returns[i:i + 1])
                            self.returns[i] = 0.0
                    self.current_episode[i] += 1
-                    if self.rank == 0:
-                        if self.use_wandb:
-                            step_info[f"Episode-Steps/env-{i}"] = infos[i]["episode_step"]
-                            step_info[f"Train-Episode-Rewards/env-{i}"] = infos[i]["episode_score"]
-                        else:
-                            step_info["Episode-Steps"] = {f"env-{i}": infos[i]["episode_step"]}
-                            step_info["Train-Episode-Rewards"] = {f"env-{i}": infos[i]["episode_score"]}
-                        self.log_infos(step_info, self.current_step)
+                    if self.use_wandb:
+                        step_info[f"Episode-Steps/rank_{self.rank}/env-{i}"] = infos[i]["episode_step"]
+                        step_info[f"Train-Episode-Rewards/rank_{self.rank}/env-{i}"] = infos[i]["episode_score"]
+                    else:
+                        step_info[f"Episode-Steps/rank_{self.rank}"] = {f"env-{i}": infos[i]["episode_step"]}
+                        step_info[f"Train-Episode-Rewards/rank_{self.rank}"] = {f"env-{i}": infos[i]["episode_score"]}
+                    self.log_infos(step_info, self.current_step)

            self.current_step += self.n_envs
            self._update_explore_factor()
diff --git a/xuance/torch/learners/learner.py b/xuance/torch/learners/learner.py
index ae842ab2..5be81e3e 100644
--- a/xuance/torch/learners/learner.py
+++ b/xuance/torch/learners/learner.py
@@ -47,28 +47,6 @@ def __init__(self,
         self.running_steps = config.running_steps
         self.iterations = 0

-    def build_training_data(self, samples: Optional[dict]):
-        batch_size = samples['batch_size']
-        samples_Tensor = {}
-        if self.world_size > 1:  # i.e., Multi-GPU settings.
-            rank = int(os.environ['RANK'])
-            batch_size_local = batch_size // self.world_size
-            if rank < self.world_size - 1:
-                indices = range(rank * batch_size_local, (rank + 1) * batch_size_local)
-            else:
-                indices = range(rank * batch_size_local, batch_size)
-            for k, v in samples.items():
-                if k == 'batch_size':
-                    continue
-                samples_Tensor[k] = torch.as_tensor(v[indices], device=self.device)
-        else:
-            for k, v in samples.items():
-                if k == 'batch_size':
-                    continue
-                samples_Tensor[k] = torch.as_tensor(v, device=self.device)
-
-        return samples_Tensor
-
     def save_model(self, model_path):
         torch.save(self.policy.state_dict(), model_path)
         if self.distributed_training:
diff --git a/xuance/torch/learners/policy_gradient/a2c_learner.py b/xuance/torch/learners/policy_gradient/a2c_learner.py
index d03c59bb..278b5276 100644
--- a/xuance/torch/learners/policy_gradient/a2c_learner.py
+++ b/xuance/torch/learners/policy_gradient/a2c_learner.py
@@ -22,11 +22,10 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
-        adv_batch = sample_Tensor['advantages']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)

         outputs, a_dist, v_pred = self.policy(obs_batch)
         log_prob = a_dist.log_prob(act_batch)
diff --git a/xuance/torch/learners/policy_gradient/ddpg_learner.py b/xuance/torch/learners/policy_gradient/ddpg_learner.py
index 8965b2fd..50978c1f 100644
--- a/xuance/torch/learners/policy_gradient/ddpg_learner.py
+++ b/xuance/torch/learners/policy_gradient/ddpg_learner.py
@@ -28,12 +28,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         # critic update
         action_q = self.policy.Qaction(obs_batch, act_batch).reshape([-1])
diff --git a/xuance/torch/learners/policy_gradient/mpdqn_learner.py b/xuance/torch/learners/policy_gradient/mpdqn_learner.py
index 2ac2c4c7..c70b8bf8 100644
--- a/xuance/torch/learners/policy_gradient/mpdqn_learner.py
+++ b/xuance/torch/learners/policy_gradient/mpdqn_learner.py
@@ -28,12 +28,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        hyact_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        next_batch = sample_Tensor['obs_next']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        hyact_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         disact_batch = hyact_batch[:, 0].long()
         conact_batch = hyact_batch[:, 1:]
diff --git a/xuance/torch/learners/policy_gradient/pdqn_learner.py b/xuance/torch/learners/policy_gradient/pdqn_learner.py
index 44f4b253..bd0b1403 100644
--- a/xuance/torch/learners/policy_gradient/pdqn_learner.py
+++ b/xuance/torch/learners/policy_gradient/pdqn_learner.py
@@ -28,12 +28,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        hyact_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        next_batch = sample_Tensor['obs_next']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        hyact_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         disact_batch = hyact_batch[:, 0].long()
         conact_batch = hyact_batch[:, 1:]
diff --git a/xuance/torch/learners/policy_gradient/pg_learner.py b/xuance/torch/learners/policy_gradient/pg_learner.py
index ed9318d5..333f898a 100644
--- a/xuance/torch/learners/policy_gradient/pg_learner.py
+++ b/xuance/torch/learners/policy_gradient/pg_learner.py
@@ -21,10 +21,9 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)

         _, a_dist, _ = self.policy(obs_batch)
         log_prob = a_dist.log_prob(act_batch)
diff --git a/xuance/torch/learners/policy_gradient/ppg_learner.py b/xuance/torch/learners/policy_gradient/ppg_learner.py
index f8958767..d99b1506 100644
--- a/xuance/torch/learners/policy_gradient/ppg_learner.py
+++ b/xuance/torch/learners/policy_gradient/ppg_learner.py
@@ -27,11 +27,10 @@ def __init__(self,

     def update_policy(self, **samples):
         self.policy_iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        adv_batch = sample_Tensor['advantages']
-        old_dist = merge_distributions(sample_Tensor['aux_batch']['old_dist'])
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)
+        old_dist = merge_distributions(samples['aux_batch']['old_dist'])
         old_logp_batch = old_dist.log_prob(act_batch).detach()

         outputs, a_dist, _, _ = self.policy(obs_batch)
@@ -73,9 +72,8 @@ def update_policy(self, **samples):

     def update_critic(self, **samples):
         self.value_iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        ret_batch = sample_Tensor['returns']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)

         _, _, v_pred, _ = self.policy(obs_batch)
         loss = self.mse_loss(v_pred, ret_batch)
@@ -92,10 +90,9 @@ def update_critic(self, **samples):
         return info

     def update_auxiliary(self, **samples):
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        ret_batch = sample_Tensor['returns']
-        old_dist = merge_distributions(sample_Tensor['aux_batch']['old_dist'])
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        old_dist = merge_distributions(samples['aux_batch']['old_dist'])

         outputs, a_dist, v, aux_v = self.policy(obs_batch)
         aux_loss = self.mse_loss(v.detach(), aux_v)
diff --git a/xuance/torch/learners/policy_gradient/ppoclip_learner.py b/xuance/torch/learners/policy_gradient/ppoclip_learner.py
index d88fd385..f3ec79b3 100644
--- a/xuance/torch/learners/policy_gradient/ppoclip_learner.py
+++ b/xuance/torch/learners/policy_gradient/ppoclip_learner.py
@@ -24,12 +24,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
-        adv_batch = sample_Tensor['advantages']
-        old_logp_batch = sample_Tensor['aux_batch']['old_logp']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)
+        old_logp_batch = samples['aux_batch']['old_logp']

         outputs, a_dist, v_pred = self.policy(obs_batch)
         log_prob = a_dist.log_prob(act_batch)
diff --git a/xuance/torch/learners/policy_gradient/ppokl_learner.py b/xuance/torch/learners/policy_gradient/ppokl_learner.py
index 9847ecfe..5c2f9a22 100644
--- a/xuance/torch/learners/policy_gradient/ppokl_learner.py
+++ b/xuance/torch/learners/policy_gradient/ppokl_learner.py
@@ -27,11 +27,10 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
-        adv_batch = sample_Tensor['advantages']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)
         old_dists = samples['aux_batch']['old_dist']

         _, a_dist, v_pred = self.policy(obs_batch)
diff --git a/xuance/torch/learners/policy_gradient/sac_learner.py b/xuance/torch/learners/policy_gradient/sac_learner.py
index 079e68bf..ebb4dbf9 100644
--- a/xuance/torch/learners/policy_gradient/sac_learner.py
+++ b/xuance/torch/learners/policy_gradient/sac_learner.py
@@ -37,12 +37,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         # actor update
         log_pi, policy_q_1, policy_q_2 = self.policy.Qpolicy(obs_batch)
diff --git a/xuance/torch/learners/policy_gradient/sacdis_learner.py b/xuance/torch/learners/policy_gradient/sacdis_learner.py
index 09370fe3..e955cfba 100644
--- a/xuance/torch/learners/policy_gradient/sacdis_learner.py
+++ b/xuance/torch/learners/policy_gradient/sacdis_learner.py
@@ -37,12 +37,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions'].unsqueeze(-1)
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards'].unsqueeze(-1)
-        ter_batch = sample_Tensor['terminals'].reshape([-1, 1])
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device).unsqueeze(-1)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device).unsqueeze(-1)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device).reshape([-1, 1])

         # actor update
         action_prob, log_pi, policy_q_1, policy_q_2 = self.policy.Qpolicy(obs_batch)
diff --git a/xuance/torch/learners/policy_gradient/spdqn_learner.py b/xuance/torch/learners/policy_gradient/spdqn_learner.py
index 39bb9069..a3309c20 100644
--- a/xuance/torch/learners/policy_gradient/spdqn_learner.py
+++ b/xuance/torch/learners/policy_gradient/spdqn_learner.py
@@ -28,12 +28,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        hyact_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        next_batch = sample_Tensor['obs_next']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        hyact_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         disact_batch = hyact_batch[:, 0].long()
         conact_batch = hyact_batch[:, 1:]
diff --git a/xuance/torch/learners/policy_gradient/td3_learner.py b/xuance/torch/learners/policy_gradient/td3_learner.py
index 9a1fdf71..a1e8920e 100644
--- a/xuance/torch/learners/policy_gradient/td3_learner.py
+++ b/xuance/torch/learners/policy_gradient/td3_learner.py
@@ -30,12 +30,11 @@ def __init__(self,
     def update(self, **samples):
         self.iterations += 1
         info = {}
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         # critic update
         action_q_A, action_q_B = self.policy.Qaction(obs_batch, act_batch)
diff --git a/xuance/torch/learners/qlearning_family/c51_learner.py b/xuance/torch/learners/qlearning_family/c51_learner.py
index 70a48684..2f854334 100644
--- a/xuance/torch/learners/qlearning_family/c51_learner.py
+++ b/xuance/torch/learners/qlearning_family/c51_learner.py
@@ -23,12 +23,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         _, _, evalZ = self.policy(obs_batch)
         _, targetA, targetZ = self.policy.target(next_batch)
diff --git a/xuance/torch/learners/qlearning_family/ddqn_learner.py b/xuance/torch/learners/qlearning_family/ddqn_learner.py
index fcfd5200..a15dd114 100644
--- a/xuance/torch/learners/qlearning_family/ddqn_learner.py
+++ b/xuance/torch/learners/qlearning_family/ddqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         _, _, evalQ = self.policy(obs_batch)
         _, targetA, targetQ = self.policy.target(next_batch)
diff --git a/xuance/torch/learners/qlearning_family/dqn_learner.py b/xuance/torch/learners/qlearning_family/dqn_learner.py
index 6e8d019c..8c9aaae5 100644
--- a/xuance/torch/learners/qlearning_family/dqn_learner.py
+++ b/xuance/torch/learners/qlearning_family/dqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         _, _, evalQ = self.policy(obs_batch)
         _, _, targetQ = self.policy.target(next_batch)
diff --git a/xuance/torch/learners/qlearning_family/drqn_learner.py b/xuance/torch/learners/qlearning_family/drqn_learner.py
index f02500bd..d895ba3d 100644
--- a/xuance/torch/learners/qlearning_family/drqn_learner.py
+++ b/xuance/torch/learners/qlearning_family/drqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
-        batch_size = sample_Tensor['batch_size']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
+        batch_size = samples['batch_size']

         rnn_hidden = self.policy.init_hidden(batch_size)
         _, _, evalQ, _ = self.policy(obs_batch[:, 0:-1], *rnn_hidden)
diff --git a/xuance/torch/learners/qlearning_family/dueldqn_learner.py b/xuance/torch/learners/qlearning_family/dueldqn_learner.py
index 7cf53bdf..0f348e1d 100644
--- a/xuance/torch/learners/qlearning_family/dueldqn_learner.py
+++ b/xuance/torch/learners/qlearning_family/dueldqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         _, _, evalQ = self.policy(obs_batch)
         _, _, targetQ = self.policy.target(next_batch)
diff --git a/xuance/torch/learners/qlearning_family/perdqn_learner.py b/xuance/torch/learners/qlearning_family/perdqn_learner.py
index 353fb2f1..22c05626 100644
--- a/xuance/torch/learners/qlearning_family/perdqn_learner.py
+++ b/xuance/torch/learners/qlearning_family/perdqn_learner.py
@@ -26,36 +26,13 @@ def __init__(self,
         self.one_hot = nn.functional.one_hot
         self.n_actions = self.policy.action_dim

-    def build_training_data(self, samples: Optional[dict]):
-        batch_size = samples['batch_size']
-        samples_Tensor = {}
-        if self.world_size > 1:  # i.e., Multi-GPU settings.
-            rank = int(os.environ['RANK'])
-            batch_size_local = batch_size // self.world_size
-            if rank < self.world_size - 1:
-                indices = range(rank * batch_size_local, (rank + 1) * batch_size_local)
-            else:
-                indices = range(rank * batch_size_local, batch_size)
-            for k, v in samples.items():
-                if k in ['batch_size', 'weights', 'step_choices']:
-                    continue
-                samples_Tensor[k] = torch.as_tensor(v[indices], device=self.device)
-        else:
-            for k, v in samples.items():
-                if k in ['batch_size', 'weights', 'step_choices']:
-                    continue
-                samples_Tensor[k] = torch.as_tensor(v, device=self.device)
-
-        return samples_Tensor
-
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         _, _, evalQ = self.policy(obs_batch)
         _, _, targetQ = self.policy.target(next_batch)
diff --git a/xuance/torch/learners/qlearning_family/qrdqn_learner.py b/xuance/torch/learners/qlearning_family/qrdqn_learner.py
index 825c2474..4a99edbf 100644
--- a/xuance/torch/learners/qlearning_family/qrdqn_learner.py
+++ b/xuance/torch/learners/qlearning_family/qrdqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,

     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)

         _, _, evalZ = self.policy(obs_batch)
         _, targetA, targetZ = self.policy(next_batch)
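
For context, a minimal sketch (not part of the diff) of the data flow these changes introduce: the agent now splits config.buffer_size and config.batch_size across ranks, and each learner converts its rank-local sample dict directly with torch.as_tensor instead of slicing a global batch in the removed build_training_data(). The config values and the torchrun launch line below are illustrative assumptions, not taken from the repository:

# Sketch only; mirrors the per-rank split added in agent.py under stated assumptions.
import os

import torch


def split_per_rank(buffer_size: int, batch_size: int, world_size: int):
    # Same guard as agent.py: every rank must get at least one sample.
    if buffer_size < world_size:
        raise AttributeError("The config.buffer_size is less than the number of GPUs.")
    if batch_size < world_size:
        raise AttributeError("The config.batch_size is less than the number of GPUs.")
    return buffer_size // world_size, batch_size // world_size


# WORLD_SIZE is set by a DDP launcher, e.g. `torchrun --nproc_per_node=4 train.py` (assumed command).
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_buffer_size, local_batch_size = split_per_rank(buffer_size=20000, batch_size=256, world_size=world_size)

# Each rank then samples local_batch_size transitions from its own buffer and,
# as in the updated update() methods, moves them to the device in one step:
device = "cuda" if torch.cuda.is_available() else "cpu"
samples = {"obs": [[0.0] * 4] * local_batch_size}  # placeholder batch from the rank-local replay buffer
obs_batch = torch.as_tensor(samples["obs"], device=device)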