distributed training for tf, off policy
wenzhangliu committed Oct 2, 2024
1 parent 9a3f9e5 commit 7a3419d
Showing 13 changed files with 251 additions and 59 deletions.
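The hunks below all apply a single pattern: the agents forward a use_distributed_training flag into their policy constructors, and the off-policy learners create their optimizers inside policy.mirrored_strategy.scope() and split the old learn method into a per-replica forward_fn plus a learn wrapper that dispatches through strategy.run and strategy.reduce. The following minimal, self-contained sketch shows that TensorFlow idiom on its own; the toy network, loss, and batch shapes are assumptions, not the xuance classes.

```python
import tensorflow as tf
from tensorflow import keras as tk

# Sketch of the MirroredStrategy pattern used in this commit: variables and the
# optimizer are created under strategy.scope(), the gradient step runs per
# replica via strategy.run(), and per-replica losses are combined with
# strategy.reduce().
strategy = tf.distribute.MirroredStrategy()  # stands in for policy.mirrored_strategy

with strategy.scope():
    q_net = tk.Sequential([tk.layers.Dense(64, activation="relu"),
                           tk.layers.Dense(4)])      # toy Q-network
    q_net.build(input_shape=(None, 8))               # create variables under the scope
    optimizer = tk.optimizers.Adam(1e-3)

def forward_fn(obs_batch, target_batch):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(q_net(obs_batch) - target_batch))
    gradients = tape.gradient(loss, q_net.trainable_variables)
    # apply_gradients() inside strategy.run() aggregates gradients across replicas.
    optimizer.apply_gradients(zip(gradients, q_net.trainable_variables))
    return loss

@tf.function
def learn(obs_batch, target_batch):
    per_replica_loss = strategy.run(forward_fn, args=(obs_batch, target_batch))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

print(learn(tf.random.normal((32, 8)), tf.zeros((32, 4))).numpy())
```

The learner diffs below follow the same shape, except that the strategy object lives on the policy and the per-replica outputs can also include Q-values or TD errors.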
3 changes: 2 additions & 1 deletion xuance/tensorflow/agents/qlearning_family/c51_agent.py
@@ -32,7 +32,8 @@ def _build_policy(self) -> Module:
action_space=self.action_space,
atom_num=self.config.atom_num, v_min=self.config.v_min, v_max=self.config.v_max,
representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
else:
raise AttributeError(f"C51 currently does not support the policy named {self.config.policy}.")

3 changes: 2 additions & 1 deletion xuance/tensorflow/agents/qlearning_family/drqn_agent.py
@@ -59,7 +59,8 @@ def _build_policy(self) -> Module:
action_space=self.action_space, representation=representation,
rnn=self.config.rnn, recurrent_hidden_size=self.config.recurrent_hidden_size,
recurrent_layer_N=self.config.recurrent_layer_N, dropout=self.config.dropout,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
else:
raise AttributeError(
f"{self.config.agent} currently does not support the policy named {self.config.policy}.")
3 changes: 2 additions & 1 deletion xuance/tensorflow/agents/qlearning_family/dueldqn_agent.py
@@ -30,7 +30,8 @@ def _build_policy(self) -> Module:
if self.config.policy == "Duel_Q_network":
policy = REGISTRY_Policy["Duel_Q_network"](
action_space=self.action_space, representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
else:
raise AttributeError(f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

3 changes: 2 additions & 1 deletion xuance/tensorflow/agents/qlearning_family/noisydqn_agent.py
@@ -54,7 +54,8 @@ def _build_policy(self) -> Module:
if self.config.policy == "Noisy_Q_network":
policy = REGISTRY_Policy["Noisy_Q_network"](
action_space=self.action_space, representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
else:
raise AttributeError(f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

3 changes: 2 additions & 1 deletion xuance/tensorflow/agents/qlearning_family/qrdqn_agent.py
@@ -31,7 +31,8 @@ def _build_policy(self) -> Module:
policy = REGISTRY_Policy["QR_Q_network"](
action_space=self.action_space, quantile_num=self.config.quantile_num,
representation=representation, hidden_size=self.config.q_hidden_size,
-                normalize=normalize_fn, initialize=initializer, activation=activation)
+                normalize=normalize_fn, initialize=initializer, activation=activation,
+                use_distributed_training=self.distributed_training)
else:
raise AttributeError(f"{self.config.agent} currently does not support the policy named {self.config.policy}.")

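The five agent-side diffs above only add use_distributed_training=self.distributed_training to the policy constructor call; the policy classes themselves are not part of this commit page. The sketch below is a hypothetical illustration of what a policy might do with that flag, based only on the fact that the learners further down read policy.mirrored_strategy; the class name, _build helper, and layer sizes are invented.

```python
import tensorflow as tf
from tensorflow import keras as tk

class QNetworkSketch(tk.Model):
    """Hypothetical stand-in for the registered policy classes (not xuance code)."""

    def __init__(self, action_dim: int, hidden_size: int = 64,
                 use_distributed_training: bool = False):
        super().__init__()
        # The learners in this commit access policy.mirrored_strategy, so the
        # policy presumably owns the strategy when distributed training is enabled.
        self.mirrored_strategy = (tf.distribute.MirroredStrategy()
                                  if use_distributed_training else None)
        if self.mirrored_strategy is not None:
            with self.mirrored_strategy.scope():   # mirror the model variables
                self.q_net = self._build(action_dim, hidden_size)
        else:
            self.q_net = self._build(action_dim, hidden_size)

    @staticmethod
    def _build(action_dim, hidden_size):
        net = tk.Sequential([tk.layers.Dense(hidden_size, activation="relu"),
                             tk.layers.Dense(action_dim)])
        net.build(input_shape=(None, 8))           # toy observation size
        return net

    def call(self, observations):
        return self.q_net(observations)


policy = QNetworkSketch(action_dim=4, use_distributed_training=True)
print(policy(tf.zeros((2, 8))).shape)  # (2, 4)
```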
23 changes: 20 additions & 3 deletions xuance/tensorflow/learners/qlearning_family/c51_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
policy: Module):
super(C51_Learner, self).__init__(config, policy)
if ("macOS" in self.os_name) and ("arm" in self.os_name): # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
        else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
self.gamma = config.gamma
self.sync_frequency = config.sync_frequency

@tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
with tf.GradientTape() as tape:
_, _, evalZ = self.policy(obs_batch)
_, targetA, targetZ = self.policy.target(next_batch)
@@ -59,6 +67,15 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):

return loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
+        else:
+            loss = self.forward_fn(*inputs)
+            return loss

def update(self, **samples):
self.iterations += 1
obs_batch = samples['obs']
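The learner __init__ hunks above and below repeat the same optimizer selection: the legacy Adam on Apple-silicon macOS, the current Adam elsewhere, and either one created under the strategy scope when distributed training is enabled. The condensed sketch below is illustrative only; the helper name and the platform-based OS string are assumptions (the learners use their own self.os_name).

```python
import platform
from tensorflow import keras as tk

def make_adam(learning_rate: float, strategy=None):
    """Condensed version of the optimizer-selection snippet repeated in each learner."""
    os_name = f"{platform.platform()} {platform.machine()}"  # analogue of self.os_name
    apple_silicon = ("macOS" in os_name) and ("arm" in os_name)
    # Mirror the macOS/arm check above: M-series machines get the legacy Keras Adam.
    adam_cls = tk.optimizers.legacy.Adam if apple_silicon else tk.optimizers.Adam
    if strategy is not None:
        # Distributed training: the optimizer is created under the strategy scope.
        with strategy.scope():
            return adam_cls(learning_rate)
    return adam_cls(learning_rate)

optimizer = make_adam(1e-3)  # or make_adam(1e-3, strategy=policy.mirrored_strategy)
```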
24 changes: 21 additions & 3 deletions xuance/tensorflow/learners/qlearning_family/ddqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
policy: Module):
super(DDQN_Learner, self).__init__(config, policy)
if ("macOS" in self.os_name) and ("arm" in self.os_name): # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
        else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
self.gamma = config.gamma
self.sync_frequency = config.sync_frequency

@tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
with tf.GradientTape() as tape:
_, _, evalQ = self.policy(obs_batch)
_, targetA, targetQ = self.policy.target(next_batch)
@@ -49,6 +57,16 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
])
return predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss

def update(self, **samples):
self.iterations += 1
obs_batch = samples['obs']
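DDQN above, and DuelDQN and PerDQN below, return more than one tensor from forward_fn, so their learn wrappers reduce each per-replica output separately; the dqn_learner hunk that follows extends the same treatment to predictQ. The toy demo below shows what strategy.run hands back and what ReduceOp.SUM does to it; the constant values are stand-ins for real Q-values and losses.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn():
    # Stand-ins for the (predictQ, loss) pair returned by forward_fn on each replica.
    predict_q = tf.constant([1.0, 2.0, 3.0])  # per-sample Q-values
    loss = tf.constant(0.5)                   # scalar loss
    return predict_q, loss

@tf.function
def learn_step():
    predict_q, loss = strategy.run(replica_fn)  # one value per replica for each output
    # Reduce each output separately, as the learn() wrappers in this commit do.
    return (strategy.reduce(tf.distribute.ReduceOp.SUM, predict_q, axis=None),
            strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))

q_sum, loss_sum = learn_step()
# With N replicas the results are sums over the N replicas; on a single device
# they are simply the values returned by replica_fn.
print(q_sum.numpy(), loss_sum.numpy())
```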
3 changes: 2 additions & 1 deletion xuance/tensorflow/learners/qlearning_family/dqn_learner.py
@@ -61,7 +61,8 @@ def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
def learn(self, *inputs):
if self.distributed_training:
predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
-            return predictQ, self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
else:
predictQ, loss = self.forward_fn(*inputs)
return predictQ, loss
23 changes: 20 additions & 3 deletions xuance/tensorflow/learners/qlearning_family/drqn_learner.py
@@ -15,15 +15,23 @@ def __init__(self,
policy: Module):
super(DRQN_Learner, self).__init__(config, policy)
if ("macOS" in self.os_name) and ("arm" in self.os_name): # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
        else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
self.gamma = config.gamma
self.sync_frequency = config.sync_frequency
self.n_actions = self.policy.action_dim

@tf.function
-    def learn(self, batch_size, obs_batch, act_batch, rew_batch, ter_batch):
+    def forward_fn(self, batch_size, obs_batch, act_batch, rew_batch, ter_batch):
with tf.GradientTape() as tape:
rnn_hidden = self.policy.init_hidden(batch_size)
_, _, evalQ, _ = self.policy(obs_batch[:, 0:-1], *rnn_hidden)
@@ -56,6 +64,15 @@ def learn(self, batch_size, obs_batch, act_batch, rew_batch, ter_batch):

return predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return predictQ, self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss

def update(self, **samples):
self.iterations += 1
obs_batch = samples['obs']
24 changes: 21 additions & 3 deletions xuance/tensorflow/learners/qlearning_family/dueldqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
policy: Module):
super(DuelDQN_Learner, self).__init__(config, policy)
if ("macOS" in self.os_name) and ("arm" in self.os_name): # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
        else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
self.gamma = config.gamma
self.sync_frequency = config.sync_frequency

@tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
with tf.GradientTape() as tape:
_, _, evalQ = self.policy(obs_batch)
_, _, targetQ = self.policy.target(next_batch)
@@ -47,6 +55,16 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
])
return predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss

def update(self, **samples):
self.iterations += 1
obs_batch = samples['obs']
25 changes: 22 additions & 3 deletions xuance/tensorflow/learners/qlearning_family/perdqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
policy: Module):
super(PerDQN_Learner, self).__init__(config, policy)
if ("macOS" in self.os_name) and ("arm" in self.os_name): # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
        else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
self.gamma = config.gamma
self.sync_frequency = config.sync_frequency

@tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
with tf.GradientTape() as tape:
_, _, evalQ = self.policy(obs_batch)
_, _, targetQ = self.policy.target(next_batch)
@@ -48,6 +56,17 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
])
return td_error, predictQ, loss

+    @tf.function
+    def learn(self, *inputs):
+        if self.distributed_training:
+            td_error, predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, td_error, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            td_error, predictQ, loss = self.forward_fn(*inputs)
+            return td_error, predictQ, loss

def update(self, **samples):
self.iterations += 1
obs_batch = samples['obs']
23 changes: 20 additions & 3 deletions xuance/tensorflow/learners/qlearning_family/qrdqn_learner.py
@@ -15,14 +15,22 @@ def __init__(self,
policy: Module):
super(QRDQN_Learner, self).__init__(config, policy)
if ("macOS" in self.os_name) and ("arm" in self.os_name): # For macOS with Apple's M-series chips.
-            self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.legacy.Adam(config.learning_rate)
        else:
-            self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            if self.distributed_training:
+                with self.policy.mirrored_strategy.scope():
+                    self.optimizer = tk.optimizers.Adam(config.learning_rate)
+            else:
+                self.optimizer = tk.optimizers.Adam(config.learning_rate)
self.gamma = config.gamma
self.sync_frequency = config.sync_frequency

@tf.function
-    def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
+    def forward_fn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
with tf.GradientTape() as tape:
_, _, evalZ = self.policy(obs_batch)
_, targetA, targetZ = self.policy.target(next_batch)
@@ -50,6 +58,15 @@ def learn(self, obs_batch, act_batch, next_batch, rew_batch, ter_batch):
])
return current_quantile, loss

+    def learn(self, *inputs):
+        if self.distributed_training:
+            predictQ, loss = self.policy.mirrored_strategy.run(self.forward_fn, args=inputs)
+            return (self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, predictQ, axis=None),
+                    self.policy.mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
+        else:
+            predictQ, loss = self.forward_fn(*inputs)
+            return predictQ, loss

def update(self, **samples):
self.iterations += 1
obs_batch = samples['obs']
Expand Down