
Commit dd078fc
samples data for parallel training
wenzhangliu committed Sep 26, 2024
1 parent c779d1a commit dd078fc
Showing 22 changed files with 115 additions and 171 deletions.
12 changes: 11 additions & 1 deletion xuance/torch/agents/base/agent.py
@@ -38,9 +38,19 @@ def __init__(self,
             self.rank = int(os.environ['LOCAL_RANK'])
             master_port = config.master_port if hasattr(config, "master_port") else None
             init_distributed_mode(master_port=master_port)
+            if self.config.buffer_size < self.world_size:
+                raise AttributeError("The config.buffer_size is less than the number of GPUs.")
+            else:
+                self.buffer_size = self.config.buffer_size // self.world_size
+            if self.config.batch_size < self.world_size:
+                raise AttributeError("The config.batch_size is less than the number of GPUs.")
+            else:
+                self.batch_size = self.config.batch_size // self.world_size
         else:
-            self.rank = 0
             self.world_size = 1
+            self.rank = 0
+            self.buffer_size = self.config.buffer_size
+            self.batch_size = self.config.batch_size
 
         self.gamma = config.gamma
         self.start_training = config.start_training if hasattr(config, "start_training") else 1
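
Note for readers of this hunk: with distributed training enabled, each process now keeps only a 1/world_size share of the replay buffer and of the training batch. Below is a minimal sketch of that split, assuming a hypothetical config object and the WORLD_SIZE / LOCAL_RANK environment variables that torchrun sets; it is not part of the commit.

# Sketch only: per-GPU buffer/batch split as introduced in agent.py above.
import os

class DummyConfig:  # hypothetical stand-in for the real config object
    buffer_size = 100000
    batch_size = 256

config = DummyConfig()
world_size = int(os.environ.get("WORLD_SIZE", "1"))   # set by torchrun in DDP mode
rank = int(os.environ.get("LOCAL_RANK", "0"))

if config.buffer_size < world_size or config.batch_size < world_size:
    raise AttributeError("buffer_size and batch_size must be >= the number of GPUs.")
local_buffer_size = config.buffer_size // world_size  # transitions stored per rank
local_batch_size = config.batch_size // world_size    # samples consumed per rank per update

print(f"rank {rank}: buffer={local_buffer_size}, batch={local_batch_size}")
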
19 changes: 9 additions & 10 deletions xuance/torch/agents/core/off_policy.py
@@ -41,8 +41,8 @@ def _build_memory(self, auxiliary_info_shape=None):
                             action_space=self.action_space,
                             auxiliary_shape=auxiliary_info_shape,
                             n_envs=self.n_envs,
-                            buffer_size=self.config.buffer_size,
-                            batch_size=self.config.batch_size)
+                            buffer_size=self.buffer_size,
+                            batch_size=self.batch_size)
         return Buffer(**input_buffer)
 
     def _build_policy(self) -> Module:
@@ -137,14 +137,13 @@ def train(self, train_steps):
                         self.ret_rms.update(self.returns[i:i + 1])
                         self.returns[i] = 0.0
                         self.current_episode[i] += 1
-                        if self.rank == 0:
-                            if self.use_wandb:
-                                step_info[f"Episode-Steps/env-{i}"] = infos[i]["episode_step"]
-                                step_info[f"Train-Episode-Rewards/env-{i}"] = infos[i]["episode_score"]
-                            else:
-                                step_info["Episode-Steps"] = {f"env-{i}": infos[i]["episode_step"]}
-                                step_info["Train-Episode-Rewards"] = {f"env-{i}": infos[i]["episode_score"]}
-                            self.log_infos(step_info, self.current_step)
+                        if self.use_wandb:
+                            step_info[f"Episode-Steps/rank_{self.rank}/env-{i}"] = infos[i]["episode_step"]
+                            step_info[f"Train-Episode-Rewards/rank_{self.rank}/env-{i}"] = infos[i]["episode_score"]
+                        else:
+                            step_info[f"Episode-Steps/rank_{self.rank}"] = {f"env-{i}": infos[i]["episode_step"]}
+                            step_info[f"Train-Episode-Rewards/rank_{self.rank}"] = {f"env-{i}": infos[i]["episode_score"]}
+                        self.log_infos(step_info, self.current_step)
 
             self.current_step += self.n_envs
             self._update_explore_factor()
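
Side note on the logging change above: instead of only rank 0 writing episode statistics, every rank now logs under a rank-tagged key. A minimal sketch of the resulting key scheme, with a hypothetical log_infos stub in place of the framework's logger:

# Sketch only: illustrates the per-rank logging keys used in off_policy.py above.
def log_infos(step_info, step):  # hypothetical stand-in for the real logger
    print(step, step_info)

rank, i = 1, 3                          # process rank and environment index (example values)
episode_step, episode_score = 512, 187.5
use_wandb = True

step_info = {}
if use_wandb:
    step_info[f"Episode-Steps/rank_{rank}/env-{i}"] = episode_step
    step_info[f"Train-Episode-Rewards/rank_{rank}/env-{i}"] = episode_score
else:  # grouped (tensorboard-style) logging
    step_info[f"Episode-Steps/rank_{rank}"] = {f"env-{i}": episode_step}
    step_info[f"Train-Episode-Rewards/rank_{rank}"] = {f"env-{i}": episode_score}
log_infos(step_info, step=1000)
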
22 changes: 0 additions & 22 deletions xuance/torch/learners/learner.py
@@ -47,28 +47,6 @@ def __init__(self,
         self.running_steps = config.running_steps
         self.iterations = 0
 
-    def build_training_data(self, samples: Optional[dict]):
-        batch_size = samples['batch_size']
-        samples_Tensor = {}
-        if self.world_size > 1:  # i.e., Multi-GPU settings.
-            rank = int(os.environ['RANK'])
-            batch_size_local = batch_size // self.world_size
-            if rank < self.world_size - 1:
-                indices = range(rank * batch_size_local, (rank + 1) * batch_size_local)
-            else:
-                indices = range(rank * batch_size_local, batch_size)
-            for k, v in samples.items():
-                if k == 'batch_size':
-                    continue
-                samples_Tensor[k] = torch.as_tensor(v[indices], device=self.device)
-        else:
-            for k, v in samples.items():
-                if k == 'batch_size':
-                    continue
-                samples_Tensor[k] = torch.as_tensor(v, device=self.device)
-
-        return samples_Tensor
-
     def save_model(self, model_path):
         torch.save(self.policy.state_dict(), model_path)
         if self.distributed_training:
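
For context: the removed build_training_data helper sliced one shared batch across ranks. After this commit each rank's buffer already holds a 1/world_size share, so the learners convert the sampled arrays directly with torch.as_tensor. A minimal sketch of that new pattern, using made-up array shapes; not part of the commit.

# Sketch only: direct conversion of a sampled batch, as the learners below now do.
import numpy as np
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
samples = {                                   # made-up batch as a replay buffer might return it
    "obs": np.random.randn(64, 4).astype(np.float32),
    "actions": np.random.randint(0, 2, size=(64,)),
    "rewards": np.random.randn(64).astype(np.float32),
}

obs_batch = torch.as_tensor(samples["obs"], device=device)      # no per-rank slicing needed
act_batch = torch.as_tensor(samples["actions"], device=device)
rew_batch = torch.as_tensor(samples["rewards"], device=device)
print(obs_batch.shape, act_batch.shape, rew_batch.shape)
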
9 changes: 4 additions & 5 deletions xuance/torch/learners/policy_gradient/a2c_learner.py
@@ -22,11 +22,10 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
-        adv_batch = sample_Tensor['advantages']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        adv_batch = torch.as_tensor(samples['adv_batch'], device=self.device)
 
         outputs, a_dist, v_pred = self.policy(obs_batch)
         log_prob = a_dist.log_prob(act_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/ddpg_learner.py
@@ -28,12 +28,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
 
         # critic update
         action_q = self.policy.Qaction(obs_batch, act_batch).reshape([-1])
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/mpdqn_learner.py
@@ -28,12 +28,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        hyact_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        next_batch = sample_Tensor['obs_next']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        hyact_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
         disact_batch = hyact_batch[:, 0].long()
         conact_batch = hyact_batch[:, 1:]
 
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/pdqn_learner.py
@@ -28,12 +28,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        hyact_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        next_batch = sample_Tensor['obs_next']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        hyact_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
         disact_batch = hyact_batch[:, 0].long()
         conact_batch = hyact_batch[:, 1:]
 
7 changes: 3 additions & 4 deletions xuance/torch/learners/policy_gradient/pg_learner.py
@@ -21,10 +21,9 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
 
         _, a_dist, _ = self.policy(obs_batch)
         log_prob = a_dist.log_prob(act_batch)
21 changes: 9 additions & 12 deletions xuance/torch/learners/policy_gradient/ppg_learner.py
@@ -27,11 +27,10 @@ def __init__(self,
 
     def update_policy(self, **samples):
         self.policy_iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        adv_batch = sample_Tensor['advantages']
-        old_dist = merge_distributions(sample_Tensor['aux_batch']['old_dist'])
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)
+        old_dist = merge_distributions(samples['aux_batch']['old_dist'])
         old_logp_batch = old_dist.log_prob(act_batch).detach()
 
         outputs, a_dist, _, _ = self.policy(obs_batch)
@@ -73,9 +72,8 @@ def update_policy(self, **samples):
 
     def update_critic(self, **samples):
         self.value_iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        ret_batch = sample_Tensor['returns']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
 
         _, _, v_pred, _ = self.policy(obs_batch)
         loss = self.mse_loss(v_pred, ret_batch)
@@ -92,10 +90,9 @@ def update_critic(self, **samples):
         return info
 
     def update_auxiliary(self, **samples):
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        ret_batch = sample_Tensor['returns']
-        old_dist = merge_distributions(sample_Tensor['aux_batch']['old_dist'])
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        old_dist = merge_distributions(samples['aux_batch']['old_dist'])
 
         outputs, a_dist, v, aux_v = self.policy(obs_batch)
         aux_loss = self.mse_loss(v.detach(), aux_v)
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/ppoclip_learner.py
@@ -24,12 +24,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
-        adv_batch = sample_Tensor['advantages']
-        old_logp_batch = sample_Tensor['aux_batch']['old_logp']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)
+        old_logp_batch = samples['aux_batch']['old_logp']
 
         outputs, a_dist, v_pred = self.policy(obs_batch)
         log_prob = a_dist.log_prob(act_batch)
9 changes: 4 additions & 5 deletions xuance/torch/learners/policy_gradient/ppokl_learner.py
@@ -27,11 +27,10 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        ret_batch = sample_Tensor['returns']
-        adv_batch = sample_Tensor['advantages']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        ret_batch = torch.as_tensor(samples['returns'], device=self.device)
+        adv_batch = torch.as_tensor(samples['advantages'], device=self.device)
         old_dists = samples['aux_batch']['old_dist']
 
         _, a_dist, v_pred = self.policy(obs_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/sac_learner.py
@@ -37,12 +37,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
 
         # actor update
         log_pi, policy_q_1, policy_q_2 = self.policy.Qpolicy(obs_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/sacdis_learner.py
@@ -37,12 +37,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions'].unsqueeze(-1)
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards'].unsqueeze(-1)
-        ter_batch = sample_Tensor['terminals'].reshape([-1, 1])
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device).unsqueeze(-1)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device).unsqueeze(-1)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device).reshape([-1, 1])
 
         # actor update
         action_prob, log_pi, policy_q_1, policy_q_2 = self.policy.Qpolicy(obs_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/spdqn_learner.py
@@ -28,12 +28,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        hyact_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        next_batch = sample_Tensor['obs_next']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        hyact_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
         disact_batch = hyact_batch[:, 0].long()
         conact_batch = hyact_batch[:, 1:]
 
11 changes: 5 additions & 6 deletions xuance/torch/learners/policy_gradient/td3_learner.py
@@ -30,12 +30,11 @@ def __init__(self,
     def update(self, **samples):
         self.iterations += 1
         info = {}
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
 
         # critic update
         action_q_A, action_q_B = self.policy.Qaction(obs_batch, act_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/qlearning_family/c51_learner.py
@@ -23,12 +23,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
 
         _, _, evalZ = self.policy(obs_batch)
         _, targetA, targetZ = self.policy.target(next_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/qlearning_family/ddqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
 
         _, _, evalQ = self.policy(obs_batch)
         _, targetA, targetQ = self.policy.target(next_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/qlearning_family/dqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        next_batch = sample_Tensor['obs_next']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        next_batch = torch.as_tensor(samples['obs_next'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
 
         _, _, evalQ = self.policy(obs_batch)
         _, _, targetQ = self.policy.target(next_batch)
11 changes: 5 additions & 6 deletions xuance/torch/learners/qlearning_family/drqn_learner.py
@@ -25,12 +25,11 @@ def __init__(self,
 
     def update(self, **samples):
         self.iterations += 1
-        sample_Tensor = self.build_training_data(samples=samples)
-        obs_batch = sample_Tensor['obs']
-        act_batch = sample_Tensor['actions']
-        rew_batch = sample_Tensor['rewards']
-        ter_batch = sample_Tensor['terminals']
-        batch_size = sample_Tensor['batch_size']
+        obs_batch = torch.as_tensor(samples['obs'], device=self.device)
+        act_batch = torch.as_tensor(samples['actions'], device=self.device)
+        rew_batch = torch.as_tensor(samples['rewards'], device=self.device)
+        ter_batch = torch.as_tensor(samples['terminals'], device=self.device)
+        batch_size = samples['batch_size']
 
         rnn_hidden = self.policy.init_hidden(batch_size)
         _, _, evalQ, _ = self.policy(obs_batch[:, 0:-1], *rnn_hidden)