diff --git a/xuance/common/common_tools.py b/xuance/common/common_tools.py
index 3eaa831d..a1f18986 100644
--- a/xuance/common/common_tools.py
+++ b/xuance/common/common_tools.py
@@ -200,7 +200,8 @@ def get_runner(method,
         raise AttributeError("Cannot find a deep learning toolbox named " + dl_toolbox)
 
     if distributed_training:
-        print(f"Calculating device: Multi-GPU distributed training.")
+        if rank == 0:
+            print(f"Calculating device: Multi-GPU distributed training.")
     else:
         print(f"Calculating device: {device}")
 
diff --git a/xuance/torch/policies/categorical.py b/xuance/torch/policies/categorical.py
index 50165aeb..e5fad0fa 100644
--- a/xuance/torch/policies/categorical.py
+++ b/xuance/torch/policies/categorical.py
@@ -53,7 +53,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
@@ -111,7 +112,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
             self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])
 
@@ -177,10 +179,15 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
-            self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
-            self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])
-            self.aux_critic = DistributedDataParallel(module=self.aux_critic, device_ids=[self.rank])
+            if self.actor_representation._get_name() != "Basic_Identical":
+                self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
+            if self.critic_representation._get_name() != "Basic_Identical":
+                self.critic_representation = DistributedDataParallel(self.critic_representation, device_ids=[self.rank])
+            if self.aux_critic_representation._get_name() != "Basic_Identical":
+                self.aux_critic_representation = DistributedDataParallel(self.aux_critic_representation, device_ids=[self.rank])
+            self.actor = DistributedDataParallel(self.actor, device_ids=[self.rank])
+            self.critic = DistributedDataParallel(self.critic, device_ids=[self.rank])
+            self.aux_critic = DistributedDataParallel(self.aux_critic, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
         """
@@ -258,9 +265,14 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
-            self.critic_1_representation = DistributedDataParallel(self.critic_1_representation, device_ids=[self.rank])
-            self.critic_2_representation = DistributedDataParallel(self.critic_2_representation, device_ids=[self.rank])
+            if self.actor_representation._get_name() != "Basic_Identical":
+                self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
+            if self.critic_1_representation._get_name() != "Basic_Identical":
+                self.critic_1_representation = DistributedDataParallel(self.critic_1_representation,
+                                                                       device_ids=[self.rank])
+            if self.critic_2_representation._get_name() != "Basic_Identical":
+                self.critic_2_representation = DistributedDataParallel(self.critic_2_representation,
+                                                                       device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
             self.critic_1 = DistributedDataParallel(module=self.critic_1, device_ids=[self.rank])
             self.critic_2 = DistributedDataParallel(module=self.critic_2, device_ids=[self.rank])
diff --git a/xuance/torch/policies/categorical_marl.py b/xuance/torch/policies/categorical_marl.py
index ed8396c8..f66a6703 100644
--- a/xuance/torch/policies/categorical_marl.py
+++ b/xuance/torch/policies/categorical_marl.py
@@ -77,10 +77,12 @@ def __init__(self,
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
             for key in self.model_keys:
-                self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
-                                                                         device_ids=[self.rank])
-                self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
-                                                                          device_ids=[self.rank])
+                if self.actor_representation[key]._get_name() != "Basic_Identical":
+                    self.actor_representation[key] = DistributedDataParallel(self.actor_representation[key],
+                                                                             device_ids=[self.rank])
+                if self.critic_representation[key]._get_name() != "Basic_Identical":
+                    self.critic_representation[key] = DistributedDataParallel(self.critic_representation[key],
+                                                                              device_ids=[self.rank])
                 self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                 self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
         if self.mixer is not None:
diff --git a/xuance/torch/policies/deterministic.py b/xuance/torch/policies/deterministic.py
index 7a5b8b68..620ffe44 100644
--- a/xuance/torch/policies/deterministic.py
+++ b/xuance/torch/policies/deterministic.py
@@ -47,7 +47,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
@@ -127,7 +128,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
@@ -210,7 +212,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])
 
     def update_noise(self, noisy_bound: float = 0.0):
@@ -316,7 +319,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.eval_Zhead = DistributedDataParallel(module=self.eval_Zhead, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
@@ -401,7 +405,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.eval_Zhead = DistributedDataParallel(module=self.eval_Zhead, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
@@ -499,9 +504,11 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
+            if self.actor_representation._get_name() != "Basic_Identical":
+                self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
+            if self.critic_representation._get_name() != "Basic_Identical":
+                self.critic_representation = DistributedDataParallel(self.critic_representation, device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
-            self.critic_representation = DistributedDataParallel(self.critic_representation, device_ids=[self.rank])
             self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
@@ -616,9 +623,14 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
-            self.critic_A_representation = DistributedDataParallel(self.critic_A_representation, device_ids=[self.rank])
-            self.critic_B_representation = DistributedDataParallel(self.critic_B_representation, device_ids=[self.rank])
+            if self.actor_representation._get_name() != "Basic_Identical":
+                self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
+            if self.critic_A_representation._get_name() != "Basic_Identical":
+                self.critic_A_representation = DistributedDataParallel(self.critic_A_representation,
+                                                                       device_ids=[self.rank])
+            if self.critic_B_representation._get_name() != "Basic_Identical":
+                self.critic_B_representation = DistributedDataParallel(self.critic_B_representation,
+                                                                       device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
             self.critic_A = DistributedDataParallel(module=self.critic_A, device_ids=[self.rank])
             self.critic_B = DistributedDataParallel(module=self.critic_B, device_ids=[self.rank])
@@ -741,7 +753,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.qnetwork = DistributedDataParallel(module=self.qnetwork, device_ids=[self.rank])
             self.conactor = DistributedDataParallel(module=self.conactor, device_ids=[self.rank])
 
@@ -993,7 +1006,8 @@ def __init__(self,
         self.distributed_training = kwargs['use_distributed_training']
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict], *rnn_hidden: Tensor):
diff --git a/xuance/torch/policies/deterministic_marl.py b/xuance/torch/policies/deterministic_marl.py
index 0f7492ab..4ad2fe1a 100644
--- a/xuance/torch/policies/deterministic_marl.py
+++ b/xuance/torch/policies/deterministic_marl.py
@@ -66,8 +66,9 @@ def __init__(self,
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
             for key in self.model_keys:
-                self.representation[key] = DistributedDataParallel(module=self.representation[key],
-                                                                   device_ids=[self.rank])
+                if self.representation[key]._get_name() != "Basic_Identical":
+                    self.representation[key] = DistributedDataParallel(module=self.representation[key],
+                                                                       device_ids=[self.rank])
                 self.eval_Qhead[key] = DistributedDataParallel(module=self.eval_Qhead[key], device_ids=[self.rank])
 
     @property
@@ -716,10 +717,12 @@ def __init__(self,
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
             for key in self.model_keys:
-                self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
-                                                                         device_ids=[self.rank])
-                self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
-                                                                          device_ids=[self.rank])
+                if self.actor_representation[key]._get_name() != "Basic_Identical":
+                    self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
+                                                                             device_ids=[self.rank])
+                if self.critic_representation[key]._get_name() != "Basic_Identical":
+                    self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
+                                                                              device_ids=[self.rank])
                 self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                 self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
 
@@ -1141,12 +1144,15 @@ def __init__(self,
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
             for key in self.model_keys:
-                self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
-                                                                         device_ids=[self.rank])
-                self.critic_A_representation[key] = DistributedDataParallel(module=self.critic_A_representation[key],
-                                                                            device_ids=[self.rank])
-                self.critic_B_representation[key] = DistributedDataParallel(module=self.critic_B_representation[key],
-                                                                            device_ids=[self.rank])
+                if self.actor_representation[key]._get_name() != "Basic_Identical":
+                    self.actor_representation[key] = DistributedDataParallel(self.actor_representation[key],
+                                                                             device_ids=[self.rank])
+                if self.critic_A_representation[key]._get_name() != "Basic_Identical":
+                    self.critic_A_representation[key] = DistributedDataParallel(self.critic_A_representation[key],
+                                                                                device_ids=[self.rank])
+                if self.critic_B_representation[key]._get_name() != "Basic_Identical":
+                    self.critic_B_representation[key] = DistributedDataParallel(self.critic_B_representation[key],
+                                                                                device_ids=[self.rank])
                 self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                 self.critic_A[key] = DistributedDataParallel(module=self.critic_A[key], device_ids=[self.rank])
                 self.critic_B[key] = DistributedDataParallel(module=self.critic_B[key], device_ids=[self.rank])
diff --git a/xuance/torch/policies/gaussian.py b/xuance/torch/policies/gaussian.py
index 7634a482..575df473 100644
--- a/xuance/torch/policies/gaussian.py
+++ b/xuance/torch/policies/gaussian.py
@@ -48,7 +48,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
 
     def forward(self, observation: Union[np.ndarray, dict]):
@@ -108,7 +109,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
             self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])
 
@@ -174,7 +176,8 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
+            if self.representation._get_name() != "Basic_Identical":
+                self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
             self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])
             self.aux_critic = DistributedDataParallel(module=self.aux_critic, device_ids=[self.rank])
@@ -256,9 +259,14 @@ def __init__(self,
         self.distributed_training = use_distributed_training
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
-            self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
-            self.critic_1_representation = DistributedDataParallel(self.critic_1_representation, device_ids=[self.rank])
-            self.critic_2_representation = DistributedDataParallel(self.critic_2_representation, device_ids=[self.rank])
+            if self.actor_representation._get_name() != "Basic_Identical":
+                self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
+            if self.critic_1_representation._get_name() != "Basic_Identical":
+                self.critic_1_representation = DistributedDataParallel(self.critic_1_representation,
+                                                                       device_ids=[self.rank])
+            if self.critic_2_representation._get_name() != "Basic_Identical":
+                self.critic_2_representation = DistributedDataParallel(self.critic_2_representation,
+                                                                       device_ids=[self.rank])
             self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
             self.critic_1 = DistributedDataParallel(module=self.critic_1, device_ids=[self.rank])
             self.critic_2 = DistributedDataParallel(module=self.critic_2, device_ids=[self.rank])
diff --git a/xuance/torch/policies/gaussian_marl.py b/xuance/torch/policies/gaussian_marl.py
index 530da25e..b9ffcfe2 100644
--- a/xuance/torch/policies/gaussian_marl.py
+++ b/xuance/torch/policies/gaussian_marl.py
@@ -72,10 +72,12 @@ def __init__(self,
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
             for key in self.model_keys:
-                self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
-                                                                         device_ids=[self.rank])
-                self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
-                                                                          device_ids=[self.rank])
+                if self.actor_representation[key]._get_name() != "Basic_Identical":
+                    self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
+                                                                             device_ids=[self.rank])
+                if self.critic_representation[key]._get_name() != "Basic_Identical":
+                    self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
+                                                                              device_ids=[self.rank])
                 self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                 self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
 
@@ -246,12 +248,15 @@ def __init__(self,
         if self.distributed_training:
             self.rank = int(os.environ["RANK"])
             for key in self.model_keys:
-                self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
-                                                                         device_ids=[self.rank])
-                self.critic_1_representation[key] = DistributedDataParallel(module=self.critic_1_representation[key],
-                                                                            device_ids=[self.rank])
-                self.critic_2_representation[key] = DistributedDataParallel(module=self.critic_2_representation[key],
-                                                                            device_ids=[self.rank])
+                if self.actor_representation[key]._get_name() != "Basic_Identical":
+                    self.actor_representation[key] = DistributedDataParallel(self.actor_representation[key],
+                                                                             device_ids=[self.rank])
+                if self.critic_1_representation[key]._get_name() != "Basic_Identical":
+                    self.critic_1_representation[key] = DistributedDataParallel(self.critic_1_representation[key],
+                                                                                device_ids=[self.rank])
+                if self.critic_2_representation[key]._get_name() != "Basic_Identical":
+                    self.critic_2_representation[key] = DistributedDataParallel(self.critic_2_representation[key],
+                                                                                device_ids=[self.rank])
                 self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
                 self.critic_1[key] = DistributedDataParallel(module=self.critic_1[key], device_ids=[self.rank])
                 self.critic_2[key] = DistributedDataParallel(module=self.critic_2[key], device_ids=[self.rank])
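Note on the recurring guard above: torch's DistributedDataParallel raises a RuntimeError when asked to wrap a module that has no parameters requiring gradients, and xuance's Basic_Identical representation is a parameter-less identity module, which is why every policy now checks `_get_name() != "Basic_Identical"` before wrapping its representation. (The common_tools.py change is independent: it gates the startup message to rank 0 so it prints once instead of once per process.) Below is a minimal sketch of the same guard written as a reusable helper; the name `wrap_ddp` is hypothetical and not part of this patch, and the sketch assumes the process group is already initialized and the RANK environment variable is set, as in the patched code.

    import os
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def wrap_ddp(module: nn.Module) -> nn.Module:
        """Wrap `module` in DDP unless it has nothing to synchronize.

        Hypothetical helper, assuming torch.distributed.init_process_group()
        has already been called and os.environ["RANK"] is set, mirroring the
        patched __init__ methods above.
        """
        rank = int(os.environ["RANK"])
        # Testing for trainable parameters directly is a more general form of
        # the `_get_name() != "Basic_Identical"` check used in the patch: DDP
        # refuses to wrap a module with no parameters that require gradients.
        if any(p.requires_grad for p in module.parameters()):
            return DistributedDataParallel(module, device_ids=[rank])
        return module

    # Usage matching the patched policies, e.g. inside a policy __init__:
    #     self.representation = wrap_ddp(self.representation)
    #     self.actor = wrap_ddp(self.actor)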