
Commit

distributed training for tensorflow
wenzhangliu committed Sep 29, 2024
1 parent bed0104 commit b9b6ce1
Showing 6 changed files with 60 additions and 38 deletions.
25 changes: 14 additions & 11 deletions xuance/common/common_tools.py
@@ -163,17 +163,20 @@ def get_runner(method,
     else:
         device = args.device
         distributed_training = True if args.distributed_training else False
-    if distributed_training:
-        rank = int(os.environ['RANK'])
-        num_gpus = int(os.environ['WORLD_SIZE'])
-        if rank == 0:
-            if num_gpus > 1:
-                print(f"Calculating devices: {num_gpus} visible GPUs for distributed training.")
-            else:
-                print(f"Calculating device: {num_gpus} visible GPU for distributed training.")
-    else:
-        rank = 0
-        print(f"Calculating device: {device}")
+    # if distributed_training:
+    #     rank = int(os.environ['RANK'])
+    #     num_gpus = int(os.environ['WORLD_SIZE'])
+    #     if rank == 0:
+    #         if num_gpus > 1:
+    #             print(f"Calculating devices: {num_gpus} visible GPUs for distributed training.")
+    #         else:
+    #             print(f"Calculating device: {num_gpus} visible GPU for distributed training.")
+    # else:
+    #     rank = 0
+    #     print(f"Calculating device: {device}")
+
+    rank = 0
+    print(f"Calculating device: {device}")
 
     dl_toolbox = args[0].dl_toolbox if type(args) == list else args.dl_toolbox
     if dl_toolbox == "torch":
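
The deleted branch read the torch-style RANK and WORLD_SIZE environment variables, which tf.distribute does not set. A minimal sketch (illustrative only, not part of xuance) of how TensorFlow's MirroredStrategy discovers the visible GPUs on its own:

import tensorflow as tf

# MirroredStrategy replicates over all visible GPUs by default; no RANK or
# WORLD_SIZE environment variables are required on a single machine.
strategy = tf.distribute.MirroredStrategy()
print(f"Calculating devices: {strategy.num_replicas_in_sync} replicas for distributed training.")
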
3 changes: 2 additions & 1 deletion xuance/environment/__init__.py
@@ -23,7 +23,8 @@ def _thunk(env_seed: int = None):
         raise AttributeError(f"The environment named {config.env_name} cannot be created.")
 
     if config.distributed_training:
-        rank = int(os.environ['RANK'])
+        # rank = int(os.environ['RANK'])  # for torch.nn.parallel.DistributedDataParallel
+        rank = 1
         config.env_seed += rank * config.parallels
 
     if config.vectorize in REGISTRY_VEC_ENV.keys():
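
The seed offset keeps parallel environments from sharing seeds across workers: worker `rank` draws seeds from env_seed + rank * parallels up to env_seed + rank * parallels + parallels - 1. A small illustration with made-up numbers:

# Illustration only: each worker (rank) gets a disjoint block of `parallels` seeds.
env_seed, parallels = 42, 4
for rank in range(3):
    seeds = [env_seed + rank * parallels + i for i in range(parallels)]
    print(f"rank {rank}: seeds {seeds}")
# rank 0: seeds [42, 43, 44, 45]
# rank 1: seeds [46, 47, 48, 49]
# rank 2: seeds [50, 51, 52, 53]
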
1 change: 1 addition & 0 deletions xuance/tensorflow/agents/base/agent.py
@@ -28,6 +28,7 @@ def __init__(self,
         self.config = config
         self.use_rnn = config.use_rnn if hasattr(config, "use_rnn") else False
         self.use_actions_mask = config.use_actions_mask if hasattr(config, "use_actions_mask") else False
+        self.distributed_training = config.distributed_training
 
         self.gamma = config.gamma
         self.start_training = config.start_training if hasattr(config, "start_training") else 1
2 changes: 1 addition & 1 deletion xuance/tensorflow/agents/qlearning_family/dqn_agent.py
@@ -38,7 +38,7 @@ def _build_policy(self) -> Module:
             policy = REGISTRY_Policy["Basic_Q_network"](
                 action_space=self.action_space, representation=representation, hidden_size=self.config.q_hidden_size,
                 normalize=normalize_fn, initialize=initializer, activation=activation,
-                use_distributed_training=self.config.use_distributed_training, gamma=self.config)
+                use_distributed_training=self.distributed_training)
         else:
             raise AttributeError(f"{self.config.agent} does not support the policy named {self.config.policy}.")
 
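
The corrected call passes the agent's distributed_training flag to the policy constructor (and drops the stray gamma=self.config argument). How the policy consumes the flag is not shown in this diff; the following is a hypothetical sketch of the usual pattern, where a policy built with use_distributed_training=True creates its variables under a tf.distribute.MirroredStrategy scope (the names ToyQNetwork and q_net are made up for illustration and are not xuance's Basic_Q_network):

import tensorflow as tf
from tensorflow import keras as tk

class ToyQNetwork(tk.Model):
    """Hypothetical example, not xuance's actual policy class."""
    def __init__(self, n_actions: int, use_distributed_training: bool = False):
        super().__init__()
        self.mirrored_strategy = tf.distribute.MirroredStrategy() if use_distributed_training else None
        if self.mirrored_strategy is not None:
            # Variables created inside the scope are mirrored across all replicas.
            with self.mirrored_strategy.scope():
                self.q_net = tk.Sequential([tk.layers.Dense(64, activation="relu"),
                                            tk.layers.Dense(n_actions)])
        else:
            self.q_net = tk.Sequential([tk.layers.Dense(64, activation="relu"),
                                        tk.layers.Dense(n_actions)])

    def call(self, observations):
        return self.q_net(observations)
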
1 change: 1 addition & 0 deletions xuance/tensorflow/learners/learner.py
@@ -15,6 +15,7 @@ def __init__(self,
         self.os_name = platform.platform()
         self.value_normalizer = None
         self.config = config
+        self.distributed_training = config.distributed_training
 
         self.episode_length = config.episode_length
         self.use_rnn = config.use_rnn if hasattr(config, 'use_rnn') else False
66 changes: 41 additions & 25 deletions xuance/tensorflow/learners/qlearning_family/dqn_learner.py
@@ -15,41 +15,57 @@ def __init__(self,
                  policy: Module):
         super(DQN_Learner, self).__init__(config, policy)
         if ("macOS" in self.os_name) and ("arm" in self.os_name):  # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
         else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
         self.gamma = config.gamma
         self.sync_frequency = config.sync_frequency
         self.n_actions = self.policy.action_dim
 
-    @tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
-        with self.policy.mirrored_strategy.scope():
-            with tf.GradientTape() as tape:
-                _, _, evalQ = self.policy(obs_batch)
-                _, _, targetQ = self.policy.target(next_batch)
-                targetQ = tf.math.reduce_max(targetQ, axis=-1)
-                targetQ = rew_batch + self.gamma * (1 - ter_batch) * targetQ
-                targetQ = tf.stop_gradient(targetQ)
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+        with tf.GradientTape() as tape:
+            _, _, evalQ = self.policy(obs_batch)
+            _, _, targetQ = self.policy.target(next_batch)
+            targetQ = tf.math.reduce_max(targetQ, axis=-1)
+            targetQ = rew_batch + self.gamma * (1 - ter_batch) * targetQ
+            targetQ = tf.stop_gradient(targetQ)
 
-                predictQ = tf.math.reduce_sum(evalQ * tf.one_hot(act_batch, evalQ.shape[1]), axis=-1)
+            predictQ = tf.math.reduce_sum(evalQ * tf.one_hot(act_batch, evalQ.shape[1]), axis=-1)
 
-                loss = tk.losses.mean_squared_error(targetQ, predictQ)
-            gradients = tape.gradient(loss, self.policy.trainable_variables)
-            if self.use_grad_clip:
-                self.optimizer.apply_gradients([
-                    (tf.clip_by_norm(grad, self.grad_clip_norm), var)
-                    for (grad, var) in zip(gradients, self.policy.trainable_variables)
-                    if grad is not None
-                ])
-            else:
-                self.optimizer.apply_gradients([
-                    (grad, var)
-                    for (grad, var) in zip(gradients, self.policy.trainable_variables)
-                    if grad is not None
-                ])
+            loss = tk.losses.mean_squared_error(targetQ, predictQ)
+        gradients = tape.gradient(loss, self.policy.trainable_variables)
+        if self.use_grad_clip:
+            self.optimizer.apply_gradients([
+                (tf.clip_by_norm(grad, self.grad_clip_norm), var)
+                for (grad, var) in zip(gradients, self.policy.trainable_variables)
+                if grad is not None
+            ])
+        else:
+            self.optimizer.apply_gradients([
+                (grad, var)
+                for (grad, var) in zip(gradients, self.policy.trainable_variables)
+                if grad is not None
+            ])
+        return predictQ, loss
 
+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return predictQ, self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss
 
     def update(self, **samples):
         self.iterations += 1
         obs_batch = samples['obs']
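
The rewritten learner splits the per-replica step (forward_fn) from the @tf.function entry point (learn), which dispatches through strategy.run and reduces the per-replica losses into one scalar. A self-contained sketch of that same pattern with a toy model and random data (illustrative only, not xuance code):

import tensorflow as tf
from tensorflow import keras as tk

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Model variables and optimizer slots are created under the strategy scope.
    model = tk.Sequential([tk.layers.Dense(4)])
    optimizer = tk.optimizers.Adam(1e-3)

def forward_fn(x, y):
    # Per-replica training step: forward pass, loss, and gradient update.
    with tf.GradientTape() as tape:
        pred = model(x)
        loss = tf.reduce_mean(tk.losses.mean_squared_error(y, pred))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return pred, loss

@tf.function
def learn(*inputs):
    pred, loss = strategy.run(forward_fn, args=inputs)
    # Combine the per-replica losses into a single scalar for logging.
    return pred, strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)

x = tf.random.normal([8, 16])
y = tf.random.normal([8, 4])
pred, loss = learn(x, y)
print(float(loss))
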
