Commit
5c4acf4 (1 parent: a64c263)
Showing 8 changed files with 429 additions and 277 deletions.
xuance/mindspore/learners/multi_agent_rl/isac_learner.py: 195 changes (121 additions, 74 deletions)
@@ -1,101 +1,148 @@
"""
Independent Soft Actor-critic (ISAC)
Implementation: Pytorch
Creator: Kun Jiang ([email protected])
Implementation: MindSpore
"""
from xuance.mindspore import ms, Module, Tensor, optim
from mindspore.nn import MSELoss
from xuance.mindspore import ms, Module, Tensor, optim, ops
from xuance.mindspore.learners import LearnerMAS
from xuance.mindspore.utils import clip_grads
from xuance.common import List
from argparse import Namespace


class ISAC_Learner(LearnerMAS):
    class ActorNetWithLossCell(Module):
        def __init__(self, backbone, n_agents, alpha):
            super(ISAC_Learner.ActorNetWithLossCell, self).__init__()
            self._backbone = backbone
            self.n_agents = n_agents
            self.alpha = alpha

        def construct(self, bs, o, ids, agt_mask):
            _, actions_dist_mu = self._backbone(o, ids)
            actions_eval = self._backbone.actor_net.sample(actions_dist_mu)
            log_pi_a = self._backbone.actor_net.log_prob(actions_eval, actions_dist_mu)
            log_pi_a = ms.ops.expand_dims(log_pi_a, axis=-1)
            loss_a = -(self._backbone.critic_for_train(o, actions_eval, ids) - self.alpha * log_pi_a * agt_mask).sum() / agt_mask.sum()
            return loss_a

    class CriticNetWithLossCell(Module):
        def __init__(self, backbone):
            super(ISAC_Learner.CriticNetWithLossCell, self).__init__()
            self._backbone = backbone

        def construct(self, o, acts, ids, agt_mask, tar_q):
            q_eval = self._backbone.critic_for_train(o, acts, ids)
            td_error = (q_eval - tar_q) * agt_mask
            loss_c = (td_error ** 2).sum() / agt_mask.sum()
            return loss_c

    def __init__(self,
                 config: Namespace,
                 model_keys: List[str],
                 agent_keys: List[str],
                 policy: Module):
        self.gamma = gamma
        self.tau = config.tau
        self.alpha = config.alpha
        self.sync_frequency = sync_frequency
        self.mse_loss = nn.MSELoss()
        super(ISAC_Learner, self).__init__(config, model_keys, agent_keys, policy)
        self.optimizer = {
            'actor': optimizer[0],
            'critic': optimizer[1]
        }
            key: {
                'actor': optim.Adam(params=self.policy.parameters_actor[key], lr=self.config.learning_rate_actor,
                                    eps=1e-5),
                'critic': optim.Adam(params=self.policy.parameters_critic[key], lr=self.config.learning_rate_critic,
                                     eps=1e-5)}
            for key in self.model_keys}
        self.scheduler = {
            'actor': scheduler[0],
            'critic': scheduler[1]
        }
        # define mindspore trainers
        self.actor_loss_net = self.ActorNetWithLossCell(policy, self.n_agents, self.alpha)
        self.actor_train = nn.TrainOneStepCell(self.actor_loss_net, self.optimizer['actor'])
        self.actor_train.set_train()
        self.critic_loss_net = self.CriticNetWithLossCell(policy)
        self.critic_train = nn.TrainOneStepCell(self.critic_loss_net, self.optimizer['critic'])
        self.critic_train.set_train()
            key: {'actor': optim.lr_scheduler.LinearLR(self.optimizer[key]['actor'], start_factor=1.0,
                                                       end_factor=0.5, total_iters=self.config.running_steps),
                  'critic': optim.lr_scheduler.LinearLR(self.optimizer[key]['critic'], start_factor=1.0,
                                                        end_factor=0.5, total_iters=self.config.running_steps)}
            for key in self.model_keys}
        self.gamma = config.gamma
        self.tau = config.tau
        self.alpha = {key: config.alpha for key in self.model_keys}
        self.mse_loss = MSELoss()
        self._ones = ops.Ones()
        self.use_automatic_entropy_tuning = config.use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            self.target_entropy = {key: -policy.action_space[key].shape[-1] for key in self.model_keys}
            self.log_alpha = {key: ms.Parameter(self._ones(1, ms.float32)) for key in self.model_keys}
            self.alpha = {key: ops.exp(self.log_alpha[key]) for key in self.model_keys}
            self.alpha_optimizer = {key: optim.Adam(params=[self.log_alpha[key]], lr=config.learning_rate_actor)
                                    for key in self.model_keys}
            # Get gradient function
            self.grad_fn_alpha = {key: ms.value_and_grad(self.forward_fn_alpha, None,
                                                         self.alpha_optimizer[key].parameters, has_aux=True)
                                  for key in self.model_keys}
        # Get gradient function
        self.grad_fn_actor = {key: ms.value_and_grad(self.forward_fn_actor, None,
                                                     self.optimizer[key]['actor'].parameters, has_aux=True)
                              for key in self.model_keys}
        self.grad_fn_critic = {key: ms.value_and_grad(self.forward_fn_critic, None,
                                                      self.optimizer[key]['critic'].parameters, has_aux=True)
                               for key in self.model_keys}

    def forward_fn_alpha(self, log_pi_eval_i, key):
        alpha_loss = -(self.log_alpha[key] * ops.stop_gradient((log_pi_eval_i + self.target_entropy[key]))).mean()
        return alpha_loss, self.log_alpha[key]

    def forward_fn_actor(self, obs, ids, mask_values, agent_key):
        _, actions_eval, log_pi_eval = self.policy(observation=obs, agent_ids=ids)
        _, _, policy_q_1, policy_q_2 = self.policy.Qpolicy(observation=obs, actions=actions_eval, agent_ids=ids,
                                                           agent_key=agent_key)
        log_pi_eval_i = log_pi_eval[agent_key].reshape(-1)
        policy_q = ops.minimum(policy_q_1[agent_key], policy_q_2[agent_key]).reshape(-1)
        loss_a = ((self.alpha[agent_key] * log_pi_eval_i - policy_q) * mask_values).sum() / mask_values.sum()
        return loss_a, log_pi_eval[agent_key], policy_q

    def forward_fn_critic(self, obs, actions, ids, mask_values, backup, agent_key):
        _, _, action_q_1, action_q_2 = self.policy.Qaction(observation=obs, actions=actions, agent_ids=ids)
        action_q_1_i, action_q_2_i = action_q_1[agent_key].reshape(-1), action_q_2[agent_key].reshape(-1)
        td_error_1, td_error_2 = action_q_1_i - ops.stop_gradient(backup), action_q_2_i - ops.stop_gradient(backup)
        td_error_1 *= mask_values
        td_error_2 *= mask_values
        loss_c = ((td_error_1 ** 2).sum() + (td_error_2 ** 2).sum()) / mask_values.sum()
        return loss_c, action_q_1_i, action_q_2_i

    def update(self, sample):
        self.iterations += 1
        obs = Tensor(sample['obs'])
        actions = Tensor(sample['actions'])
        obs_next = Tensor(sample['obs_next'])
        rewards = Tensor(sample['rewards'])
        terminals = Tensor(sample['terminals']).view(-1, self.n_agents, 1)
        agent_mask = Tensor(sample['agent_mask']).view(-1, self.n_agents, 1)
        batch_size = obs.shape[0]
        IDs = ops.broadcast_to(self.expand_dims(self.eye(self.n_agents, self.n_agents, ms.float32), 0),
                               (batch_size, -1, -1))
        info = {}

        actions_next_dist_mu = self.policy.target_actor(obs_next, IDs)
        actions_next = self.policy.target_actor_net.sample(actions_next_dist_mu)
        log_pi_a_next = self.policy.target_actor_net.log_prob(actions_next, actions_next_dist_mu)
        q_next = self.policy.target_critic(obs_next, actions_next, IDs)
        log_pi_a_next = ms.ops.expand_dims(log_pi_a_next, axis=-1)
        q_target = rewards + (1 - terminals) * self.args.gamma * (q_next - self.alpha * log_pi_a_next)
        # Prepare training data.
        sample_Tensor = self.build_training_data(sample,
                                                 use_parameter_sharing=self.use_parameter_sharing,
                                                 use_actions_mask=False)
        batch_size = sample_Tensor['batch_size']
        obs = sample_Tensor['obs']
        actions = sample_Tensor['actions']
        obs_next = sample_Tensor['obs_next']
        rewards = sample_Tensor['rewards']
        terminals = sample_Tensor['terminals']
        agent_mask = sample_Tensor['agent_mask']
        IDs = sample_Tensor['agent_ids']
        if self.use_parameter_sharing:
            key = self.model_keys[0]
            bs = batch_size * self.n_agents
            rewards[key] = rewards[key].reshape(batch_size * self.n_agents)
            terminals[key] = terminals[key].reshape(batch_size * self.n_agents)
        else:
            bs = batch_size

        # calculate the loss function
        loss_a = self.actor_train(batch_size, obs, IDs, agent_mask)
        loss_c = self.critic_train(obs, actions, IDs, agent_mask, q_target)
        # feedforward

        self.policy.soft_update(self.tau)
        _, actions_next, log_pi_next = self.policy(observation=obs_next, agent_ids=IDs)
        _, _, next_q = self.policy.Qtarget(next_observation=obs_next, next_actions=actions_next, agent_ids=IDs)

        for key in self.model_keys:
            mask_values = agent_mask[key]
            # update critic
            log_pi_next_eval = log_pi_next[key].reshape(bs)
            next_q_i = next_q[key].reshape(bs)
            target_value = next_q_i - self.alpha[key] * log_pi_next_eval
            backup = rewards[key] + (1 - terminals[key]) * self.gamma * target_value
            (loss_c, _, _), grads_critic = self.grad_fn_critic[key](obs, actions, IDs, mask_values, backup, key)
            if self.use_grad_clip:
                grads_critic = clip_grads(grads_critic, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm))
            self.optimizer[key]['critic'](grads_critic)

        learning_rate_actor = self.scheduler['actor'](self.iterations).asnumpy()
        learning_rate_critic = self.scheduler['critic'](self.iterations).asnumpy()
            # update actor
            (loss_a, log_pi_eval_i, policy_q), grads_actor = self.grad_fn_actor[key](obs, IDs, mask_values, key)
            if self.use_grad_clip:
                grads_actor = clip_grads(grads_actor, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm))
            self.optimizer[key]['actor'](grads_actor)

        info = {
            "learning_rate_actor": learning_rate_actor,
            "learning_rate_critic": learning_rate_critic,
            "loss_actor": loss_a.asnumpy(),
            "loss_critic": loss_c.asnumpy()
        }
            # automatic entropy tuning
            if self.use_automatic_entropy_tuning:
                (alpha_loss, _), grads_alpha = self.grad_fn_alpha[key](log_pi_eval_i, key)
                self.alpha_optimizer[key](grads_alpha)
                self.alpha[key] = ops.exp(self.log_alpha[key])
            else:
                alpha_loss = 0

            learning_rate_actor = self.scheduler[key]['actor'].get_last_lr()[0]
            learning_rate_critic = self.scheduler[key]['critic'].get_last_lr()[0]

            info.update({
                f"{key}/learning_rate_actor": learning_rate_actor.asnumpy(),
                f"{key}/learning_rate_critic": learning_rate_critic.asnumpy(),
                f"{key}/loss_actor": loss_a.asnumpy(),
                f"{key}/loss_critic": loss_c.asnumpy(),
                f"{key}/predictQ": policy_q.mean().asnumpy(),
                f"{key}/alpha_loss": alpha_loss.asnumpy(),
                f"{key}/alpha": self.alpha[key].asnumpy(),
            })

        self.policy.soft_update(self.tau)
        return info
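
For orientation, the hunk above swaps the old loss-cell style (inner ActorNetWithLossCell / CriticNetWithLossCell wrapped in nn.TrainOneStepCell) for MindSpore's functional training pattern: each loss is a pure forward function, ms.value_and_grad returns the loss together with the parameter gradients, and the optimizer cell is called directly on those gradients. The snippet below is a minimal, self-contained sketch of that pattern only; it is not part of the commit, and the tiny critic network, random data, and learning rate are placeholder assumptions.

# Minimal sketch of the functional MindSpore update pattern (illustrative only).
import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

net = nn.Dense(4, 1)                                  # stand-in for a critic network
optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3)

def forward_fn(obs, target):
    # Pure function returning (loss, aux); only the loss is differentiated.
    q = net(obs).reshape(-1)
    td_error = q - ops.stop_gradient(target)          # treat the regression target as a constant
    loss = (td_error ** 2).mean()
    return loss, q

# Differentiate the loss w.r.t. the optimizer's parameters (has_aux=True keeps q as auxiliary output).
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

obs = Tensor(np.random.randn(8, 4), ms.float32)
target = Tensor(np.random.randn(8), ms.float32)
(loss, q_values), grads = grad_fn(obs, target)
optimizer(grads)                                      # apply one gradient step

In the learner itself, forward_fn_critic, forward_fn_actor, and forward_fn_alpha play the role of forward_fn, and the critic's regression target is the soft Bellman backup computed before the per-agent loop: rewards plus (1 - terminals) * gamma * (target Q minus alpha times the log-probability of the next action).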