From a64c263f555b17c8756c4609270e32755739c211 Mon Sep 17 00:00:00 2001 From: wenzhangliu Date: Wed, 11 Sep 2024 10:16:32 +0800 Subject: [PATCH] mindspore: maddpg --- .../learners/multi_agent_rl/iddpg_learner.py | 8 +- .../learners/multi_agent_rl/maddpg_learner.py | 181 +++++++++++------- .../mindspore/policies/deterministic_marl.py | 24 +-- 3 files changed, 130 insertions(+), 83 deletions(-) diff --git a/xuance/mindspore/learners/multi_agent_rl/iddpg_learner.py b/xuance/mindspore/learners/multi_agent_rl/iddpg_learner.py index 1a5bc34b..5038e7c8 100644 --- a/xuance/mindspore/learners/multi_agent_rl/iddpg_learner.py +++ b/xuance/mindspore/learners/multi_agent_rl/iddpg_learner.py @@ -8,7 +8,6 @@ from xuance.mindspore.utils import clip_grads from xuance.common import List from argparse import Namespace -from operator import itemgetter class IDDPG_Learner(LearnerMAS): @@ -34,6 +33,7 @@ def __init__(self, self.gamma = config.gamma self.tau = config.tau self.mse_loss = MSELoss() + # Get gradient function self.grad_fn_actor = {key: ms.value_and_grad(self.forward_fn_actor, None, self.optimizer[key]['actor'].parameters, has_aux=True) for key in self.model_keys} @@ -109,9 +109,9 @@ def update(self, sample): info.update({ f"{key}/learning_rate_actor": learning_rate_actor.asnumpy(), f"{key}/learning_rate_critic": learning_rate_critic.asnumpy(), - f"{key}/loss_actor": loss_a.numpy(), - f"{key}/loss_critic": loss_c.numpy(), - f"{key}/predictQ": q_eval_a.mean().numpy() + f"{key}/loss_actor": loss_a.asnumpy(), + f"{key}/loss_critic": loss_c.asnumpy(), + f"{key}/predictQ": q_eval_a.mean().asnumpy() }) self.policy.soft_update(self.tau) diff --git a/xuance/mindspore/learners/multi_agent_rl/maddpg_learner.py b/xuance/mindspore/learners/multi_agent_rl/maddpg_learner.py index 08a6342e..f39a59fc 100644 --- a/xuance/mindspore/learners/multi_agent_rl/maddpg_learner.py +++ b/xuance/mindspore/learners/multi_agent_rl/maddpg_learner.py @@ -5,92 +5,139 @@ Implementation: MindSpore Trick: Parameter sharing for all agents, with agents' one-hot IDs as actor-critic's inputs. 
""" -from xuance.mindspore import ms, Module, Tensor, optim +from mindspore.nn import MSELoss +from xuance.mindspore import ms, Module, Tensor, optim, ops from xuance.mindspore.learners import LearnerMAS +from xuance.mindspore.utils import clip_grads from xuance.common import List from argparse import Namespace +from operator import itemgetter class MADDPG_Learner(LearnerMAS): - class ActorNetWithLossCell(Module): - def __init__(self, backbone, n_agents): - super(MADDPG_Learner.ActorNetWithLossCell, self).__init__() - self._backbone = backbone - self._mean = ms.ops.ReduceMean(keep_dims=True) - self.n_agents = n_agents - - def construct(self, bs, o, ids, agt_mask): - _, actions_eval = self._backbone(o, ids) - loss_a = -(self._backbone.critic(o, actions_eval, ids) * agt_mask).sum() / agt_mask.sum() - return loss_a - - class CriticNetWithLossCell(Module): - def __init__(self, backbone): - super(MADDPG_Learner.CriticNetWithLossCell, self).__init__() - self._backbone = backbone - self._loss = nn.MSELoss() - - def construct(self, o, a_n, ids, agt_mask, tar_q): - q_eval = self._backbone.critic(o, a_n, ids) - td_error = (q_eval - tar_q) * agt_mask - loss_c = (td_error ** 2).sum() / agt_mask.sum() - return loss_c - def __init__(self, config: Namespace, model_keys: List[str], agent_keys: List[str], policy: Module): - self.gamma = gamma - self.tau = config.tau - self.sync_frequency = sync_frequency - self.mse_loss = nn.MSELoss() super(MADDPG_Learner, self).__init__(config, model_keys, agent_keys, policy) self.optimizer = { - 'actor': optimizer[0], - 'critic': optimizer[1] - } + key: { + 'actor': optim.Adam(params=self.policy.parameters_actor[key], lr=self.config.learning_rate_actor, + eps=1e-5), + 'critic': optim.Adam(params=self.policy.parameters_critic[key], lr=self.config.learning_rate_critic, + eps=1e-5)} + for key in self.model_keys} self.scheduler = { - 'actor': scheduler[0], - 'critic': scheduler[1] - } - # define mindspore trainers - self.actor_loss_net = self.ActorNetWithLossCell(policy, self.n_agents) - self.actor_train = nn.TrainOneStepCell(self.actor_loss_net, self.optimizer['actor']) - self.actor_train.set_train() - self.critic_loss_net = self.CriticNetWithLossCell(policy) - self.critic_train = nn.TrainOneStepCell(self.critic_loss_net, self.optimizer['critic']) - self.critic_train.set_train() + key: {'actor': optim.lr_scheduler.LinearLR(self.optimizer[key]['actor'], start_factor=1.0, + end_factor=0.5, total_iters=self.config.running_steps), + 'critic': optim.lr_scheduler.LinearLR(self.optimizer[key]['critic'], start_factor=1.0, + end_factor=0.5, total_iters=self.config.running_steps)} + for key in self.model_keys} + self.gamma = config.gamma + self.tau = config.tau + self.mse_loss = MSELoss() + # Get gradient function + self.grad_fn_actor = {key: ms.value_and_grad(self.forward_fn_actor, None, + self.optimizer[key]['actor'].parameters, has_aux=True) + for key in self.model_keys} + self.grad_fn_critic = {key: ms.value_and_grad(self.forward_fn_critic, None, + self.optimizer[key]['critic'].parameters, has_aux=True) + for key in self.model_keys} + self.policy.set_train() + + def forward_fn_actor(self, batch_size, obs, obs_joint, actions, ids, mask_values, agent_key): + _, actions_eval = self.policy(observation=obs, agent_ids=ids) + if self.use_parameter_sharing: + act_eval = actions_eval[agent_key].reshape(batch_size, self.n_agents, -1).reshape(batch_size, -1) + else: + a_joint = {k: actions_eval[k] if k == agent_key else actions[k] for k in self.agent_keys} + act_eval = 
ops.cat(itemgetter(*self.agent_keys)(a_joint), axis=-1).reshape(batch_size, -1) + _, q_policy = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=act_eval, + agent_ids=ids, agent_key=agent_key) + q_policy_i = q_policy[agent_key].reshape(-1) + loss_a = -(q_policy_i * mask_values).sum() / mask_values.sum() + return loss_a, q_policy_i + + def forward_fn_critic(self, obs_joint, actions_joint, ids, mask_values, q_target, agent_key): + _, q_eval = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=actions_joint, agent_ids=ids, + agent_key=agent_key) + q_eval_a = q_eval[agent_key].reshape(-1) + td_error = (q_eval_a - ops.stop_gradient(q_target)) * mask_values + loss_c = (td_error ** 2).sum() / mask_values.sum() + return loss_c, q_eval_a def update(self, sample): self.iterations += 1 - obs = Tensor(sample['obs']) - actions = Tensor(sample['actions']) - obs_next = Tensor(sample['obs_next']) - rewards = Tensor(sample['rewards']) - terminals = Tensor(sample['terminals']).view(-1, self.n_agents, 1) - agent_mask = Tensor(sample['agent_mask']).view(-1, self.n_agents, 1) - batch_size = obs.shape[0] - IDs = ops.broadcast_to(self.expand_dims(self.eye(self.n_agents, self.n_agents, ms.float32), 0), - (batch_size, -1, -1)) - # calculate the loss and train - actions_next = self.policy.target_actor(obs_next, IDs) - q_next = self.policy.target_critic(obs_next, actions_next, IDs) - q_target = rewards + (1 - terminals) * self.args.gamma * q_next + info = {} - # calculate the loss and train - loss_a = self.actor_train(batch_size, obs, IDs, agent_mask) - loss_c = self.critic_train(obs, actions, IDs, agent_mask, q_target) - self.policy.soft_update(self.tau) + # prepare training data + sample_Tensor = self.build_training_data(sample, + use_parameter_sharing=self.use_parameter_sharing, + use_actions_mask=False) + batch_size = sample_Tensor['batch_size'] + obs = sample_Tensor['obs'] + actions = sample_Tensor['actions'] + obs_next = sample_Tensor['obs_next'] + rewards = sample_Tensor['rewards'] + terminals = sample_Tensor['terminals'] + agent_mask = sample_Tensor['agent_mask'] + IDs = sample_Tensor['agent_ids'] + if self.use_parameter_sharing: + key = self.model_keys[0] + bs = batch_size * self.n_agents + obs_joint = obs[key].reshape(batch_size, -1) + next_obs_joint = obs_next[key].reshape(batch_size, -1) + actions_joint = actions[key].reshape(batch_size, -1) + rewards[key] = rewards[key].reshape(batch_size * self.n_agents) + terminals[key] = terminals[key].reshape(batch_size * self.n_agents) + else: + bs = batch_size + obs_joint = ops.cat(itemgetter(*self.agent_keys)(obs), axis=-1).reshape(batch_size, -1) + next_obs_joint = ops.cat(itemgetter(*self.agent_keys)(obs_next), axis=-1).reshape(batch_size, -1) + actions_joint = ops.cat(itemgetter(*self.agent_keys)(actions), axis=-1).reshape(batch_size, -1) + + # get actions + _, actions_next = self.policy.Atarget(next_observation=obs_next, agent_ids=IDs) + # get values + if self.use_parameter_sharing: + key = self.model_keys[0] + actions_next_joint = actions_next[key].reshape(batch_size, self.n_agents, -1).reshape(batch_size, -1) + else: + actions_next_joint = ops.cat(itemgetter(*self.model_keys)(actions_next), -1).reshape(batch_size, -1) + _, q_next = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint, + agent_ids=IDs) - learning_rate_actor = self.scheduler['actor'](self.iterations).asnumpy() - learning_rate_critic = self.scheduler['critic'](self.iterations).asnumpy() + for key in self.model_keys: + mask_values = 
agent_mask[key]
+            # update critic
+            q_next_i = q_next[key].reshape(bs)
+            q_target = rewards[key] + (1 - terminals[key]) * self.gamma * q_next_i
+            (loss_c, q_eval_a), grads_critic = self.grad_fn_critic[key](obs_joint, actions_joint, IDs, mask_values,
+                                                                        q_target, key)
+            if self.use_grad_clip:
+                grads_critic = clip_grads(grads_critic, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm))
+            self.optimizer[key]['critic'](grads_critic)
 
-        info = {
-            "learning_rate_actor": learning_rate_actor,
-            "learning_rate_critic": learning_rate_critic,
-            "loss_actor": loss_a.asnumpy(),
-            "loss_critic": loss_c.asnumpy()
-        }
+            # update actor
+            (loss_a, _), grads_actor = self.grad_fn_actor[key](batch_size, obs, obs_joint, actions, IDs, mask_values,
+                                                               key)
+            if self.use_grad_clip:
+                grads_actor = clip_grads(grads_actor, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm))
+            self.optimizer[key]['actor'](grads_actor)
+            self.scheduler[key]['actor'].step()
+            self.scheduler[key]['critic'].step()
+            learning_rate_actor = self.scheduler[key]['actor'].get_last_lr()[0]
+            learning_rate_critic = self.scheduler[key]['critic'].get_last_lr()[0]
+
+            info.update({
+                f"{key}/learning_rate_actor": learning_rate_actor.asnumpy(),
+                f"{key}/learning_rate_critic": learning_rate_critic.asnumpy(),
+                f"{key}/loss_actor": loss_a.asnumpy(),
+                f"{key}/loss_critic": loss_c.asnumpy(),
+                f"{key}/predictQ": q_eval_a.mean().asnumpy()
+            })
+
+        self.policy.soft_update(self.tau)
         return info
 
diff --git a/xuance/mindspore/policies/deterministic_marl.py b/xuance/mindspore/policies/deterministic_marl.py
index 3d457cb2..1177bcfb 100644
--- a/xuance/mindspore/policies/deterministic_marl.py
+++ b/xuance/mindspore/policies/deterministic_marl.py
@@ -872,10 +872,10 @@ def Qpolicy(self, joint_observation: Tensor, joint_actions: Tensor,
         for key in agent_list:
             if self.use_parameter_sharing:
                 if self.use_rnn:
-                    joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
+                    joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
                     joint_rep_out = joint_rep_out.reshape(bs, seq_len, -1)
                 else:
-                    joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
+                    joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
                     joint_rep_out = joint_rep_out.reshape(bs, -1)
                 critic_in = ops.cat([joint_rep_out, agent_ids], axis=-1)
             else:
@@ -921,10 +921,10 @@ def Qtarget(self, joint_observation: Tensor, joint_actions: Tensor,
         for key in agent_list:
             if self.use_parameter_sharing:
                 if self.use_rnn:
-                    joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
+                    joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
                     joint_rep_out = joint_rep_out.reshape(bs, seq_len, -1)
                 else:
-                    joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
+                    joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
                     joint_rep_out = joint_rep_out.reshape(bs, -1)
                 critic_in = ops.cat([joint_rep_out, agent_ids], axis=-1)
             else:
@@ -1031,13 +1031,13 @@ def Qpolicy(self, joint_observation: Tensor, joint_actions: Tensor,
         for key in agent_list:
             if self.use_parameter_sharing:
                 if self.use_rnn:
-                    joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
-                    joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
+                    joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
+                    joint_rep_out_B = 
outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1)) joint_rep_out_A = joint_rep_out_A.reshape(bs, seq_len, -1) joint_rep_out_B = joint_rep_out_B.reshape(bs, seq_len, -1) else: - joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1) - joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1) + joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1)) + joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1)) joint_rep_out_A = joint_rep_out_A.reshape(bs, -1) joint_rep_out_B = joint_rep_out_B.reshape(bs, -1) critic_in_A = ops.cat([joint_rep_out_A, agent_ids], axis=-1) @@ -1091,13 +1091,13 @@ def Qtarget(self, joint_observation: Tensor, joint_actions: Tensor, for key in agent_list: if self.use_parameter_sharing: if self.use_rnn: - joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1) - joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1) + joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1)) + joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1)) joint_rep_out_A = joint_rep_out_A.reshape(bs, seq_len, -1) joint_rep_out_B = joint_rep_out_B.reshape(bs, seq_len, -1) else: - joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1) - joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1) + joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1)) + joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1)) joint_rep_out_A = joint_rep_out_A.reshape(bs, -1) joint_rep_out_B = joint_rep_out_B.reshape(bs, -1) critic_in_A = ops.cat([joint_rep_out_A, agent_ids], axis=-1)
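
Reviewer note: the refactored MADDPG_Learner.update() above follows the standard MADDPG recipe. The centralized critic is regressed toward r + (1 - done) * gamma * Q_target(s', a'_target), the actor ascends the centralized critic's value of its own action with the other agents' actions held fixed, and both target networks are then soft-updated. A minimal NumPy sketch of the target and soft-update arithmetic follows; the function names and the gamma/tau values are illustrative assumptions, not identifiers from this patch.

import numpy as np

def td_target(rewards, terminals, q_next, gamma=0.95):
    # Centralized critic target: r + (1 - done) * gamma * Q_target(s', a'_target),
    # the same arithmetic as the q_target line inside MADDPG_Learner.update().
    return rewards + (1.0 - terminals) * gamma * q_next

def soft_update(online, target, tau=0.005):
    # Polyak averaging, theta_target <- tau * theta + (1 - tau) * theta_target,
    # i.e. the per-parameter rule a policy.soft_update(tau) call is expected to apply.
    return {name: tau * online[name] + (1.0 - tau) * target[name] for name in online}

# Toy usage: the terminal transition gets no bootstrap term.
rewards = np.array([1.0, 0.0])
terminals = np.array([0.0, 1.0])
q_next = np.array([2.0, 5.0])
print(td_target(rewards, terminals, q_next))                          # [2.9 0. ]
print(soft_update({"w": np.array([1.0])}, {"w": np.array([0.0])}))    # {'w': array([0.005])}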