torch ddp basic identical rep
wenzhangliu committed Oct 4, 2024
1 parent 3ed2781 commit edb9798
Showing 7 changed files with 102 additions and 54 deletions.
3 changes: 2 additions & 1 deletion xuance/common/common_tools.py
@@ -200,7 +200,8 @@ def get_runner(method,
raise AttributeError("Cannot find a deep learning toolbox named " + dl_toolbox)

if distributed_training:
print(f"Calculating device: Multi-GPU distributed training.")
if rank == 0:
print(f"Calculating device: Multi-GPU distributed training.")
else:
print(f"Calculating device: {device}")

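The change above gates the startup message so that only the rank-0 worker prints it when multi-GPU training is enabled. A minimal sketch of the same pattern, assuming torchrun has exported RANK for each process (the helper name log_once is illustrative, not part of the repository):

    import os

    def log_once(message: str) -> None:
        # torchrun sets RANK for every worker; default to 0 for single-process runs.
        rank = int(os.environ.get("RANK", "0"))
        if rank == 0:
            print(message)

    log_once("Calculating device: Multi-GPU distributed training.")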
30 changes: 21 additions & 9 deletions xuance/torch/policies/categorical.py
@@ -53,7 +53,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict]):
@@ -111,7 +112,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])

@@ -177,10 +179,15 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])
self.aux_critic = DistributedDataParallel(module=self.aux_critic, device_ids=[self.rank])
if self.actor_representation._get_name() != "Basic_Identical":
self.actor_representation = DistributedDataParallel(self.representation, device_ids=[self.rank])
if self.critic_representation._get_name() != "Basic_Identical":
self.critic_representation = DistributedDataParallel(self.representation, device_ids=[self.rank])
if self.aux_critic_representation._get_name() != "Basic_Identical":
self.aux_critic_representation = DistributedDataParallel(self.representation, device_ids=[self.rank])
self.actor = DistributedDataParallel(self.actor, device_ids=[self.rank])
self.critic = DistributedDataParallel(self.critic, device_ids=[self.rank])
self.aux_critic = DistributedDataParallel(self.aux_critic, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict]):
"""
@@ -258,9 +265,14 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
self.critic_1_representation = DistributedDataParallel(self.critic_1_representation, device_ids=[self.rank])
self.critic_2_representation = DistributedDataParallel(self.critic_2_representation, device_ids=[self.rank])
if self.actor_representation._get_name() != "Basic_Identical":
self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
if self.critic_1_representation._get_name() != "Basic_Identical":
self.critic_1_representation = DistributedDataParallel(self.critic_1_representation,
device_ids=[self.rank])
if self.critic_2_representation._get_name() != "Basic_Identical":
self.critic_2_representation = DistributedDataParallel(self.critic_2_representation,
device_ids=[self.rank])
self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
self.critic_1 = DistributedDataParallel(module=self.critic_1, device_ids=[self.rank])
self.critic_2 = DistributedDataParallel(module=self.critic_2, device_ids=[self.rank])
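Every policy change in this commit applies the same guard: a representation is wrapped in DistributedDataParallel only when it is not Basic_Identical, presumably because DDP raises an error when given a module with no trainable parameters. A minimal sketch of that guard, written here as a parameter check rather than the class-name check used in the commit (the class and function names are illustrative):

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    class BasicIdentical(nn.Module):
        # Parameter-free representation: just passes the observation through.
        def forward(self, x):
            return x

    def wrap_if_parametric(module: nn.Module, rank: int) -> nn.Module:
        # Wrapping a parameter-free module in DDP fails at construction time,
        # so identity-style representations are left untouched.
        if any(p.requires_grad for p in module.parameters()):
            return DistributedDataParallel(module, device_ids=[rank])
        return module

    # Usage (after torch.distributed has been initialised):
    #   representation = wrap_if_parametric(representation, rank)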
10 changes: 6 additions & 4 deletions xuance/torch/policies/categorical_marl.py
@@ -77,10 +77,12 @@ def __init__(self,
if self.distributed_training:
self.rank = int(os.environ["RANK"])
for key in self.model_keys:
self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
device_ids=[self.rank])
self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
device_ids=[self.rank])
if self.actor_representation[key]._get_name() != "Basic_Identical":
self.actor_representation[key] = DistributedDataParallel(self.actor_representation[key],
device_ids=[self.rank])
if self.critic_representation[key]._get_name() != "Basic_Identical":
self.critic_representation[key] = DistributedDataParallel(self.critic_representation[key],
device_ids=[self.rank])
self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
if self.mixer is not None:
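In the multi-agent policies the same guard is applied per model key, since each agent (or parameter-sharing group) keeps its own representation and heads in a module dictionary. A short sketch over an nn.ModuleDict, reusing the parameter-based check from the sketch above (names are illustrative):

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def wrap_module_dict(modules: nn.ModuleDict, rank: int) -> nn.ModuleDict:
        # Wrap each per-agent sub-module separately; skip parameter-free
        # modules such as an identity representation.
        for key in list(modules.keys()):
            module = modules[key]
            if any(p.requires_grad for p in module.parameters()):
                modules[key] = DistributedDataParallel(module, device_ids=[rank])
        return modules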
38 changes: 26 additions & 12 deletions xuance/torch/policies/deterministic.py
@@ -47,7 +47,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict]):
@@ -127,7 +128,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict]):
@@ -210,7 +212,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])

def update_noise(self, noisy_bound: float = 0.0):
@@ -316,7 +319,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.eval_Zhead = DistributedDataParallel(module=self.eval_Zhead, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict]):
@@ -401,7 +405,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.eval_Zhead = DistributedDataParallel(module=self.eval_Zhead, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict]):
@@ -499,9 +504,11 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
if self.actor_representation._get_name() != "Basic_Identical":
self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
if self.critic_representation._get_name() != "Basic_Identical":
self.critic_representation = DistributedDataParallel(self.critic_representation, device_ids=[self.rank])
self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
self.critic_representation = DistributedDataParallel(self.critic_representation, device_ids=[self.rank])
self.critic = DistributedDataParallel(module=self.critic, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict]):
@@ -616,9 +623,14 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
self.critic_A_representation = DistributedDataParallel(self.critic_A_representation, device_ids=[self.rank])
self.critic_B_representation = DistributedDataParallel(self.critic_B_representation, device_ids=[self.rank])
if self.actor_representation._get_name() != "Basic_Identical":
self.actor_representation = DistributedDataParallel(self.actor_representation, device_ids=[self.rank])
if self.critic_A_representation._get_name() != "Basic_Identical":
self.critic_A_representation = DistributedDataParallel(self.critic_A_representation,
device_ids=[self.rank])
if self.critic_B_representation._get_name() != "Basic_Identical":
self.critic_B_representation = DistributedDataParallel(self.critic_B_representation,
device_ids=[self.rank])
self.actor = DistributedDataParallel(module=self.actor, device_ids=[self.rank])
self.critic_A = DistributedDataParallel(module=self.critic_A, device_ids=[self.rank])
self.critic_B = DistributedDataParallel(module=self.critic_B, device_ids=[self.rank])
@@ -741,7 +753,8 @@ def __init__(self,
self.distributed_training = use_distributed_training
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.qnetwork = DistributedDataParallel(module=self.qnetwork, device_ids=[self.rank])
self.conactor = DistributedDataParallel(module=self.conactor, device_ids=[self.rank])

@@ -993,7 +1006,8 @@ def __init__(self,
self.distributed_training = kwargs['use_distributed_training']
if self.distributed_training:
self.rank = int(os.environ["RANK"])
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
if self.representation._get_name() != "Basic_Identical":
self.representation = DistributedDataParallel(module=self.representation, device_ids=[self.rank])
self.eval_Qhead = DistributedDataParallel(module=self.eval_Qhead, device_ids=[self.rank])

def forward(self, observation: Union[np.ndarray, dict], *rnn_hidden: Tensor):
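All of these policies read the process rank from the RANK environment variable, which implies the usual torchrun-style setup: the process group is initialised first, each worker binds to one GPU, and only then are the modules wrapped. A hedged sketch of that setup (the helper and script names are illustrative, not from the repository):

    import os
    import torch
    import torch.distributed as dist

    def setup_distributed() -> int:
        # torchrun exports RANK, LOCAL_RANK and WORLD_SIZE for every worker.
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        return int(os.environ["RANK"])

    # Typical launch, e.g. with 4 GPUs on one node:
    #   torchrun --nproc_per_node=4 train_script.py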
30 changes: 18 additions & 12 deletions xuance/torch/policies/deterministic_marl.py
@@ -66,8 +66,9 @@ def __init__(self,
if self.distributed_training:
self.rank = int(os.environ["RANK"])
for key in self.model_keys:
self.representation[key] = DistributedDataParallel(module=self.representation[key],
device_ids=[self.rank])
if self.representation[key]._get_name() != "Basic_Identical":
self.representation[key] = DistributedDataParallel(module=self.representation[key],
device_ids=[self.rank])
self.eval_Qhead[key] = DistributedDataParallel(module=self.eval_Qhead[key], device_ids=[self.rank])

@property
@@ -716,10 +717,12 @@ def __init__(self,
if self.distributed_training:
self.rank = int(os.environ["RANK"])
for key in self.model_keys:
self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
device_ids=[self.rank])
self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
device_ids=[self.rank])
if self.actor_representation[key]._get_name() != "Basic_Identical":
self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
device_ids=[self.rank])
if self.critic_representation[key]._get_name() != "Basic_Identical":
self.critic_representation[key] = DistributedDataParallel(module=self.critic_representation[key],
device_ids=[self.rank])
self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])

@@ -1141,12 +1144,15 @@ def __init__(self,
if self.distributed_training:
self.rank = int(os.environ["RANK"])
for key in self.model_keys:
self.actor_representation[key] = DistributedDataParallel(module=self.actor_representation[key],
device_ids=[self.rank])
self.critic_A_representation[key] = DistributedDataParallel(module=self.critic_A_representation[key],
device_ids=[self.rank])
self.critic_B_representation[key] = DistributedDataParallel(module=self.critic_B_representation[key],
device_ids=[self.rank])
if self.actor_representation[key]._get_name() != "Basic_Identical":
self.actor_representation[key] = DistributedDataParallel(self.actor_representation[key],
device_ids=[self.rank])
if self.critic_A_representation[key]._get_name() != "Basic_Identical":
self.critic_A_representation[key] = DistributedDataParallel(self.critic_A_representation[key],
device_ids=[self.rank])
if self.critic_B_representation[key]._get_name() != "Basic_Identical":
self.critic_B_representation[key] = DistributedDataParallel(self.critic_B_representation[key],
device_ids=[self.rank])
self.actor[key] = DistributedDataParallel(module=self.actor[key], device_ids=[self.rank])
self.critic_A[key] = DistributedDataParallel(module=self.critic_A[key], device_ids=[self.rank])
self.critic_B[key] = DistributedDataParallel(module=self.critic_B[key], device_ids=[self.rank])