diff --git a/docs/source/documents/api/learners/marl/isac.rst b/docs/source/documents/api/learners/marl/isac.rst
index 62e2839a..74e6af24 100644
--- a/docs/source/documents/api/learners/marl/isac.rst
+++ b/docs/source/documents/api/learners/marl/isac.rst
@@ -1,12 +1,43 @@
ISAC_Learner
=====================================
+xxxxxx.
+
.. raw:: html
**PyTorch:**
+.. py:class::
+ xuance.torch.learners.multi_agent_rl.isac_learner.ISAC_Learner(config, policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)
+
+ :param config: xxxxxx.
+ :type config: xxxxxx
+ :param policy: xxxxxx.
+ :type policy: xxxxxx
+ :param optimizer: xxxxxx.
+ :type optimizer: xxxxxx
+ :param scheduler: xxxxxx.
+ :type scheduler: xxxxxx
+ :param device: xxxxxx.
+ :type device: xxxxxx
+ :param model_dir: xxxxxx.
+ :type model_dir: xxxxxx
+ :param gamma: xxxxxx.
+ :type gamma: xxxxxx
+ :param sync_frequency: xxxxxx.
+ :type sync_frequency: xxxxxx
+
+.. py:function::
+ xuance.torch.learners.multi_agent_rl.isac_learner.ISAC_Learner.update(sample)
+
+ xxxxxx.
+
+ :param sample: xxxxxx.
+ :type sample: xxxxxx
+ :return: xxxxxx.
+ :rtype: xxxxxx
.. raw:: html
@@ -14,7 +45,6 @@ ISAC_Learner
**TensorFlow:**
-
.. raw:: html
@@ -29,20 +59,112 @@ Source Code
-----------------
.. tabs::
-
- .. group-tab:: PyTorch
-
- .. code-block:: python3
+
+ .. group-tab:: PyTorch
+
+ .. code-block:: python
+
+ """
+ Independent Soft Actor-critic (ISAC)
+ Implementation: Pytorch
+ """
+ from xuance.torch.learners import *
+
+
+ class ISAC_Learner(LearnerMAS):
+ def __init__(self,
+ config: Namespace,
+ policy: nn.Module,
+ optimizer: Sequence[torch.optim.Optimizer],
+ scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler] = None,
+ device: Optional[Union[int, str, torch.device]] = None,
+ model_dir: str = "./",
+ gamma: float = 0.99,
+ sync_frequency: int = 100
+ ):
+ self.gamma = gamma
+ self.tau = config.tau
+ self.alpha = config.alpha
+ self.sync_frequency = sync_frequency
+ self.mse_loss = nn.MSELoss()
+ super(ISAC_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
+ self.optimizer = {
+ 'actor': optimizer[0],
+ 'critic': optimizer[1]
+ }
+ self.scheduler = {
+ 'actor': scheduler[0],
+ 'critic': scheduler[1]
+ }
+
+ def update(self, sample):
+ self.iterations += 1
+ obs = torch.Tensor(sample['obs']).to(self.device)
+ actions = torch.Tensor(sample['actions']).to(self.device)
+ obs_next = torch.Tensor(sample['obs_next']).to(self.device)
+ rewards = torch.Tensor(sample['rewards']).to(self.device)
+ terminals = torch.Tensor(sample['terminals']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ IDs = torch.eye(self.n_agents).unsqueeze(0).expand(self.args.batch_size, -1, -1).to(self.device)
+
+ q_eval = self.policy.critic(obs, actions, IDs)
+ actions_next_dist = self.policy.target_actor(obs_next, IDs)
+ actions_next = actions_next_dist.rsample()
+ log_pi_a_next = actions_next_dist.log_prob(actions_next)
+ q_next = self.policy.target_critic(obs_next, actions_next, IDs)
+ q_target = rewards + (1-terminals) * self.args.gamma * (q_next - self.alpha * log_pi_a_next.unsqueeze(dim=-1))
+
+ # calculate the loss function
+ _, actions_dist = self.policy(obs, IDs)
+ actions_eval = actions_dist.rsample()
+ log_pi_a = actions_dist.log_prob(actions_eval)
+ loss_a = -(self.policy.critic(obs, actions_eval, IDs) - self.alpha * log_pi_a.unsqueeze(dim=-1) * agent_mask).sum() / agent_mask.sum()
+ # loss_a = (- self.policy.critic(obs, actions_eval, IDs)) * agent_mask.sum() / agent_mask.sum()
+ self.optimizer['actor'].zero_grad()
+ loss_a.backward()
+ torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor, self.args.grad_clip_norm)
+ self.optimizer['actor'].step()
+ if self.scheduler['actor'] is not None:
+ self.scheduler['actor'].step()
+
+ td_error = (q_eval - q_target.detach()) * agent_mask
+ loss_c = (td_error ** 2).sum() / agent_mask.sum()
+ self.optimizer['critic'].zero_grad()
+ loss_c.backward()
+ torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic, self.args.grad_clip_norm)
+ self.optimizer['critic'].step()
+ if self.scheduler['critic'] is not None:
+ self.scheduler['critic'].step()
+
+ self.policy.soft_update(self.tau)
+
+ lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
+ lr_c = self.optimizer['critic'].state_dict()['param_groups'][0]['lr']
+
+ info = {
+ "learning_rate_actor": lr_a,
+ "learning_rate_critic": lr_c,
+ "loss_actor": loss_a.item(),
+ "loss_critic": loss_c.item(),
+ "predictQ": q_eval.mean().item()
+ }
+
+ return info
+
+
+
+
- .. group-tab:: TensorFlow
-
- .. code-block:: python3
- .. group-tab:: MindSpore
+ .. group-tab:: TensorFlow
- .. code-block:: python3
-
\ No newline at end of file
+ .. code-block:: python
+
+
+ .. group-tab:: MindSpore
+
+ .. code-block:: python
diff --git a/docs/source/documents/api/learners/marl/maddpg.rst b/docs/source/documents/api/learners/marl/maddpg.rst
index 7362aafb..0e11976a 100644
--- a/docs/source/documents/api/learners/marl/maddpg.rst
+++ b/docs/source/documents/api/learners/marl/maddpg.rst
@@ -1,12 +1,43 @@
MADDPG_Learner
=====================================
+xxxxxx.
+
.. raw:: html
**PyTorch:**
+.. py:class::
+ xuance.torch.learners.multi_agent_rl.maddpg_learner.MADDPG_Learner(config, policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)
+
+ :param config: xxxxxx.
+ :type config: xxxxxx
+ :param policy: xxxxxx.
+ :type policy: xxxxxx
+ :param optimizer: xxxxxx.
+ :type optimizer: xxxxxx
+ :param scheduler: xxxxxx.
+ :type scheduler: xxxxxx
+ :param device: xxxxxx.
+ :type device: xxxxxx
+ :param model_dir: xxxxxx.
+ :type model_dir: xxxxxx
+ :param gamma: xxxxxx.
+ :type gamma: xxxxxx
+ :param sync_frequency: xxxxxx.
+ :type sync_frequency: xxxxxx
+
+.. py:function::
+ xuance.torch.learners.multi_agent_rl.isac_learner.ISAC_Learner.update(sample)
+
+ xxxxxx.
+
+ :param sample: xxxxxx.
+ :type sample: xxxxxx
+ :return: xxxxxx.
+ :rtype: xxxxxx
.. raw:: html
@@ -14,7 +45,6 @@ MADDPG_Learner
**TensorFlow:**
-
.. raw:: html
@@ -30,21 +60,114 @@ Source Code
.. tabs::
- .. group-tab:: PyTorch
-
- .. code-block:: python3
+ .. group-tab:: PyTorch
+
+ .. code-block:: python
+
+ """
+ Multi-Agent Deep Deterministic Policy Gradient
+ Paper link:
+ https://proceedings.neurips.cc/paper/2017/file/68a9750337a418a86fe06c1991a1d64c-Paper.pdf
+ Implementation: Pytorch
+ Trick: Parameter sharing for all agents, with agents' one-hot IDs as actor-critic's inputs.
+ """
+ from xuance.torch.learners import *
+
+
+ class MADDPG_Learner(LearnerMAS):
+ def __init__(self,
+ config: Namespace,
+ policy: nn.Module,
+ optimizer: Sequence[torch.optim.Optimizer],
+ scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler] = None,
+ device: Optional[Union[int, str, torch.device]] = None,
+ model_dir: str = "./",
+ gamma: float = 0.99,
+ sync_frequency: int = 100
+ ):
+ self.gamma = gamma
+ self.tau = config.tau
+ self.sync_frequency = sync_frequency
+ self.mse_loss = nn.MSELoss()
+ super(MADDPG_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
+ self.optimizer = {
+ 'actor': optimizer[0],
+ 'critic': optimizer[1]
+ }
+ self.scheduler = {
+ 'actor': scheduler[0],
+ 'critic': scheduler[1]
+ }
+
+ def update(self, sample):
+ self.iterations += 1
+ obs = torch.Tensor(sample['obs']).to(self.device)
+ actions = torch.Tensor(sample['actions']).to(self.device)
+ obs_next = torch.Tensor(sample['obs_next']).to(self.device)
+ rewards = torch.Tensor(sample['rewards']).to(self.device)
+ terminals = torch.Tensor(sample['terminals']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ IDs = torch.eye(self.n_agents).unsqueeze(0).expand(self.args.batch_size, -1, -1).to(self.device)
+
+ # train actor
+ _, actions_eval = self.policy(obs, IDs)
+ loss_a = -(self.policy.critic(obs, actions_eval, IDs) * agent_mask).sum() / agent_mask.sum()
+ self.optimizer['actor'].zero_grad()
+ loss_a.backward()
+ if self.args.use_grad_clip:
+ torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor, self.args.grad_clip_norm)
+ self.optimizer['actor'].step()
+ if self.scheduler['actor'] is not None:
+ self.scheduler['actor'].step()
+
+ # train critic
+ actions_next = self.policy.target_actor(obs_next, IDs)
+ q_eval = self.policy.critic(obs, actions, IDs)
+ q_next = self.policy.target_critic(obs_next, actions_next, IDs)
+ q_target = rewards + (1 - terminals) * self.args.gamma * q_next
+ td_error = (q_eval - q_target.detach()) * agent_mask
+ loss_c = (td_error ** 2).sum() / agent_mask.sum()
+ self.optimizer['critic'].zero_grad()
+ loss_c.backward()
+ if self.args.use_grad_clip:
+ torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic, self.args.grad_clip_norm)
+ self.optimizer['critic'].step()
+ if self.scheduler['critic'] is not None:
+ self.scheduler['critic'].step()
+
+ self.policy.soft_update(self.tau)
+
+ lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
+ lr_c = self.optimizer['critic'].state_dict()['param_groups'][0]['lr']
+
+ info = {
+ "learning_rate_actor": lr_a,
+ "learning_rate_critic": lr_c,
+ "loss_actor": loss_a.item(),
+ "loss_critic": loss_c.item(),
+ "predictQ": q_eval.mean().item()
+ }
+
+ return info
+
+
+
+
+
+
+
+
- .. group-tab:: TensorFlow
-
- .. code-block:: python3
+ .. group-tab:: TensorFlow
+ .. code-block:: python
- .. group-tab:: MindSpore
+ .. group-tab:: MindSpore
+
+ .. code-block:: python
- .. code-block:: python3
-
\ No newline at end of file