mindspore: maddpg
wenzhangliu committed Sep 11, 2024
1 parent db3cf69 commit a64c263
Showing 3 changed files with 130 additions and 83 deletions.
8 changes: 4 additions & 4 deletions xuance/mindspore/learners/multi_agent_rl/iddpg_learner.py
@@ -8,7 +8,6 @@
from xuance.mindspore.utils import clip_grads
from xuance.common import List
from argparse import Namespace
from operator import itemgetter


class IDDPG_Learner(LearnerMAS):
@@ -34,6 +33,7 @@ def __init__(self,
self.gamma = config.gamma
self.tau = config.tau
self.mse_loss = MSELoss()
# Get gradient function
self.grad_fn_actor = {key: ms.value_and_grad(self.forward_fn_actor, None,
self.optimizer[key]['actor'].parameters, has_aux=True)
for key in self.model_keys}
@@ -109,9 +109,9 @@ def update(self, sample):
info.update({
f"{key}/learning_rate_actor": learning_rate_actor.asnumpy(),
f"{key}/learning_rate_critic": learning_rate_critic.asnumpy(),
f"{key}/loss_actor": loss_a.numpy(),
f"{key}/loss_critic": loss_c.numpy(),
f"{key}/predictQ": q_eval_a.mean().numpy()
f"{key}/loss_actor": loss_a.asnumpy(),
f"{key}/loss_critic": loss_c.asnumpy(),
f"{key}/predictQ": q_eval_a.mean().asnumpy()
})

self.policy.soft_update(self.tau)
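
A note on the logging change above, outside the diff itself: MindSpore tensors expose .asnumpy() rather than the PyTorch-style .numpy() for pulling values back to host NumPy, which is why the logging dict entries switch from loss_a.numpy() to loss_a.asnumpy(). A minimal sketch (the scalar below is a hypothetical stand-in for a loss tensor, not code from this repository):

import mindspore as ms
import numpy as np

# Hypothetical scalar loss; in the learner this would be loss_a, loss_c, etc.
loss = ms.Tensor(np.array(0.25, dtype=np.float32))
print(loss.asnumpy())  # NumPy value, safe to store in a Python logging dict
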
181 changes: 114 additions & 67 deletions xuance/mindspore/learners/multi_agent_rl/maddpg_learner.py
@@ -5,92 +5,139 @@
Implementation: MindSpore
Trick: Parameter sharing for all agents, with agents' one-hot IDs as actor-critic's inputs.
"""
from xuance.mindspore import ms, Module, Tensor, optim
from mindspore.nn import MSELoss
from xuance.mindspore import ms, Module, Tensor, optim, ops
from xuance.mindspore.learners import LearnerMAS
from xuance.mindspore.utils import clip_grads
from xuance.common import List
from argparse import Namespace
from operator import itemgetter


class MADDPG_Learner(LearnerMAS):
class ActorNetWithLossCell(Module):
def __init__(self, backbone, n_agents):
super(MADDPG_Learner.ActorNetWithLossCell, self).__init__()
self._backbone = backbone
self._mean = ms.ops.ReduceMean(keep_dims=True)
self.n_agents = n_agents

def construct(self, bs, o, ids, agt_mask):
_, actions_eval = self._backbone(o, ids)
loss_a = -(self._backbone.critic(o, actions_eval, ids) * agt_mask).sum() / agt_mask.sum()
return loss_a

class CriticNetWithLossCell(Module):
def __init__(self, backbone):
super(MADDPG_Learner.CriticNetWithLossCell, self).__init__()
self._backbone = backbone
self._loss = nn.MSELoss()

def construct(self, o, a_n, ids, agt_mask, tar_q):
q_eval = self._backbone.critic(o, a_n, ids)
td_error = (q_eval - tar_q) * agt_mask
loss_c = (td_error ** 2).sum() / agt_mask.sum()
return loss_c

def __init__(self,
config: Namespace,
model_keys: List[str],
agent_keys: List[str],
policy: Module):
self.gamma = gamma
self.tau = config.tau
self.sync_frequency = sync_frequency
self.mse_loss = nn.MSELoss()
super(MADDPG_Learner, self).__init__(config, model_keys, agent_keys, policy)
self.optimizer = {
'actor': optimizer[0],
'critic': optimizer[1]
}
key: {
'actor': optim.Adam(params=self.policy.parameters_actor[key], lr=self.config.learning_rate_actor,
eps=1e-5),
'critic': optim.Adam(params=self.policy.parameters_critic[key], lr=self.config.learning_rate_critic,
eps=1e-5)}
for key in self.model_keys}
self.scheduler = {
'actor': scheduler[0],
'critic': scheduler[1]
}
# define mindspore trainers
self.actor_loss_net = self.ActorNetWithLossCell(policy, self.n_agents)
self.actor_train = nn.TrainOneStepCell(self.actor_loss_net, self.optimizer['actor'])
self.actor_train.set_train()
self.critic_loss_net = self.CriticNetWithLossCell(policy)
self.critic_train = nn.TrainOneStepCell(self.critic_loss_net, self.optimizer['critic'])
self.critic_train.set_train()
key: {'actor': optim.lr_scheduler.LinearLR(self.optimizer[key]['actor'], start_factor=1.0,
end_factor=0.5, total_iters=self.config.running_steps),
'critic': optim.lr_scheduler.LinearLR(self.optimizer[key]['critic'], start_factor=1.0,
end_factor=0.5, total_iters=self.config.running_steps)}
for key in self.model_keys}
self.gamma = config.gamma
self.tau = config.tau
self.mse_loss = MSELoss()
# Get gradient function
self.grad_fn_actor = {key: ms.value_and_grad(self.forward_fn_actor, None,
self.optimizer[key]['actor'].parameters, has_aux=True)
for key in self.model_keys}
self.grad_fn_critic = {key: ms.value_and_grad(self.forward_fn_critic, None,
self.optimizer[key]['critic'].parameters, has_aux=True)
for key in self.model_keys}
self.policy.set_train()

def forward_fn_actor(self, batch_size, obs, obs_joint, actions, ids, mask_values, agent_key):
_, actions_eval = self.policy(observation=obs, agent_ids=ids)
if self.use_parameter_sharing:
act_eval = actions_eval[agent_key].reshape(batch_size, self.n_agents, -1).reshape(batch_size, -1)
else:
a_joint = {k: actions_eval[k] if k == agent_key else actions[k] for k in self.agent_keys}
act_eval = ops.cat(itemgetter(*self.agent_keys)(a_joint), axis=-1).reshape(batch_size, -1)
_, q_policy = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=act_eval,
agent_ids=ids, agent_key=agent_key)
q_policy_i = q_policy[agent_key].reshape(-1)
loss_a = -(q_policy_i * mask_values).sum() / mask_values.sum()
return loss_a, q_policy_i

def forward_fn_critic(self, obs_joint, actions_joint, ids, mask_values, q_target, agent_key):
_, q_eval = self.policy.Qpolicy(joint_observation=obs_joint, joint_actions=actions_joint, agent_ids=ids,
agent_key=agent_key)
q_eval_a = q_eval[agent_key].reshape(-1)
td_error = (q_eval_a - ops.stop_gradient(q_target)) * mask_values
loss_c = (td_error ** 2).sum() / mask_values.sum()
return loss_c, q_eval_a

def update(self, sample):
self.iterations += 1
obs = Tensor(sample['obs'])
actions = Tensor(sample['actions'])
obs_next = Tensor(sample['obs_next'])
rewards = Tensor(sample['rewards'])
terminals = Tensor(sample['terminals']).view(-1, self.n_agents, 1)
agent_mask = Tensor(sample['agent_mask']).view(-1, self.n_agents, 1)
batch_size = obs.shape[0]
IDs = ops.broadcast_to(self.expand_dims(self.eye(self.n_agents, self.n_agents, ms.float32), 0),
(batch_size, -1, -1))
# calculate the loss and train
actions_next = self.policy.target_actor(obs_next, IDs)
q_next = self.policy.target_critic(obs_next, actions_next, IDs)
q_target = rewards + (1 - terminals) * self.args.gamma * q_next
info = {}

# calculate the loss and train
loss_a = self.actor_train(batch_size, obs, IDs, agent_mask)
loss_c = self.critic_train(obs, actions, IDs, agent_mask, q_target)
self.policy.soft_update(self.tau)
# prepare training data
sample_Tensor = self.build_training_data(sample,
use_parameter_sharing=self.use_parameter_sharing,
use_actions_mask=False)
batch_size = sample_Tensor['batch_size']
obs = sample_Tensor['obs']
actions = sample_Tensor['actions']
obs_next = sample_Tensor['obs_next']
rewards = sample_Tensor['rewards']
terminals = sample_Tensor['terminals']
agent_mask = sample_Tensor['agent_mask']
IDs = sample_Tensor['agent_ids']
if self.use_parameter_sharing:
key = self.model_keys[0]
bs = batch_size * self.n_agents
obs_joint = obs[key].reshape(batch_size, -1)
next_obs_joint = obs_next[key].reshape(batch_size, -1)
actions_joint = actions[key].reshape(batch_size, -1)
rewards[key] = rewards[key].reshape(batch_size * self.n_agents)
terminals[key] = terminals[key].reshape(batch_size * self.n_agents)
else:
bs = batch_size
obs_joint = ops.cat(itemgetter(*self.agent_keys)(obs), axis=-1).reshape(batch_size, -1)
next_obs_joint = ops.cat(itemgetter(*self.agent_keys)(obs_next), axis=-1).reshape(batch_size, -1)
actions_joint = ops.cat(itemgetter(*self.agent_keys)(actions), axis=-1).reshape(batch_size, -1)

# get actions
_, actions_next = self.policy.Atarget(next_observation=obs_next, agent_ids=IDs)
# get values
if self.use_parameter_sharing:
key = self.model_keys[0]
actions_next_joint = actions_next[key].reshape(batch_size, self.n_agents, -1).reshape(batch_size, -1)
else:
actions_next_joint = ops.cat(itemgetter(*self.model_keys)(actions_next), -1).reshape(batch_size, -1)
_, q_next = self.policy.Qtarget(joint_observation=next_obs_joint, joint_actions=actions_next_joint,
agent_ids=IDs)

learning_rate_actor = self.scheduler['actor'](self.iterations).asnumpy()
learning_rate_critic = self.scheduler['critic'](self.iterations).asnumpy()
for key in self.model_keys:
mask_values = agent_mask[key]
# updata critic
q_next_i = q_next[key].reshape(bs)
q_target = rewards[key] + (1 - terminals[key]) * self.gamma * q_next_i
(loss_c, q_eval_a), grads_critic = self.grad_fn_critic[key](obs_joint, actions_joint, IDs, mask_values,
q_target, key)
if self.use_grad_clip:
grads_critic = clip_grads(grads_critic, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm))
self.optimizer[key]['critic'](grads_critic)

info = {
"learning_rate_actor": learning_rate_actor,
"learning_rate_critic": learning_rate_critic,
"loss_actor": loss_a.asnumpy(),
"loss_critic": loss_c.asnumpy()
}
# update actor
(loss_a, _), grads_actor = self.grad_fn_actor[key](batch_size, obs, obs_joint, actions, IDs, mask_values,
key)
if self.use_grad_clip:
grads_actor = clip_grads(grads_actor, Tensor(-self.grad_clip_norm), Tensor(self.grad_clip_norm))
self.optimizer[key]['actor'](grads_actor)

self.scheduler[key]['actor'].step()
self.scheduler[key]['critic'].step()
learning_rate_actor = self.scheduler[key]['actor'].get_last_lr()[0]
learning_rate_critic = self.scheduler[key]['critic'].get_last_lr()[0]

info.update({
f"{key}/learning_rate_actor": learning_rate_actor.asnumpy(),
f"{key}/learning_rate_critic": learning_rate_critic.asnumpy(),
f"{key}/loss_actor": loss_a.asnumpy(),
f"{key}/loss_critic": loss_c.asnumpy(),
f"{key}/predictQ": q_eval_a.mean().asnumpy()
})

self.policy.soft_update(self.tau)
return info
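
The core of this rewrite drops the ActorNetWithLossCell / CriticNetWithLossCell wrappers and nn.TrainOneStepCell trainers in favour of MindSpore's functional ms.value_and_grad API, which is what forward_fn_actor and forward_fn_critic above plug into. A minimal, self-contained sketch of that pattern under simplified assumptions (the tiny nn.Dense "critic", the nn.Adam optimizer, and the shapes are illustrative, not the XuanCe code):

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

net = nn.Dense(4, 1)                                    # stand-in critic network
optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3)

def forward_fn(x, target):
    q = net(x).reshape(-1)
    loss = ops.square(q - target).mean()
    return loss, q                                      # aux output, like q_eval_a

# grad_position=None: differentiate w.r.t. the optimizer's parameters only.
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

x = Tensor(np.random.randn(8, 4).astype(np.float32))
target = Tensor(np.zeros(8, dtype=np.float32))
(loss, q), grads = grad_fn(x, target)                   # one call: loss, aux, grads
optimizer(grads)                                        # apply the update

With has_aux=True only the first return value contributes to the gradients, so auxiliary outputs such as q_eval_a can be logged without affecting the update, matching how the learner unpacks (loss_c, q_eval_a), grads_critic above.
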
24 changes: 12 additions & 12 deletions xuance/mindspore/policies/deterministic_marl.py
@@ -872,10 +872,10 @@ def Qpolicy(self, joint_observation: Tensor, joint_actions: Tensor,
for key in agent_list:
if self.use_parameter_sharing:
if self.use_rnn:
joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
joint_rep_out = joint_rep_out.reshape(bs, seq_len, -1)
else:
joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
joint_rep_out = joint_rep_out.reshape(bs, -1)
critic_in = ops.cat([joint_rep_out, agent_ids], axis=-1)
else:
@@ -921,10 +921,10 @@ def Qtarget(self, joint_observation: Tensor, joint_actions: Tensor,
for key in agent_list:
if self.use_parameter_sharing:
if self.use_rnn:
joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
joint_rep_out = joint_rep_out.reshape(bs, seq_len, -1)
else:
joint_rep_out = outputs[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
joint_rep_out = outputs[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
joint_rep_out = joint_rep_out.reshape(bs, -1)
critic_in = ops.cat([joint_rep_out, agent_ids], axis=-1)
else:
@@ -1031,13 +1031,13 @@ def Qpolicy(self, joint_observation: Tensor, joint_actions: Tensor,
for key in agent_list:
if self.use_parameter_sharing:
if self.use_rnn:
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
joint_rep_out_A = joint_rep_out_A.reshape(bs, seq_len, -1)
joint_rep_out_B = joint_rep_out_B.reshape(bs, seq_len, -1)
else:
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
joint_rep_out_A = joint_rep_out_A.reshape(bs, -1)
joint_rep_out_B = joint_rep_out_B.reshape(bs, -1)
critic_in_A = ops.cat([joint_rep_out_A, agent_ids], axis=-1)
Expand Down Expand Up @@ -1091,13 +1091,13 @@ def Qtarget(self, joint_observation: Tensor, joint_actions: Tensor,
for key in agent_list:
if self.use_parameter_sharing:
if self.use_rnn:
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1, -1)
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1, -1))
joint_rep_out_A = joint_rep_out_A.reshape(bs, seq_len, -1)
joint_rep_out_B = joint_rep_out_B.reshape(bs, seq_len, -1)
else:
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).expand(-1, self.n_agents, -1)
joint_rep_out_A = outputs_A[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
joint_rep_out_B = outputs_B[key]['state'].unsqueeze(1).broadcast_to((-1, self.n_agents, -1))
joint_rep_out_A = joint_rep_out_A.reshape(bs, -1)
joint_rep_out_B = joint_rep_out_B.reshape(bs, -1)
critic_in_A = ops.cat([joint_rep_out_A, agent_ids], axis=-1)
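
The policy changes above replace PyTorch-style .expand(...) calls, which MindSpore tensors do not provide, with .broadcast_to((...)), which takes a single shape tuple and accepts -1 for dimensions whose size should be kept. A standalone sketch with hypothetical shapes:

import numpy as np
import mindspore as ms

n_agents = 3
state = ms.Tensor(np.random.randn(8, 16).astype(np.float32))   # (batch, dim)
joint = state.unsqueeze(1).broadcast_to((-1, n_agents, -1))    # (batch, n_agents, dim)
print(joint.shape)  # (8, 3, 16)
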
