From a33e1f9654b75566bf94d5bf100505a9c4be54b5 Mon Sep 17 00:00:00 2001 From: YKizi Date: Thu, 30 Nov 2023 11:46:05 +0800 Subject: [PATCH] pdqn.rst APIs(#01) --- .../documents/api/learners/drl/pdqn.rst | 107 ++++++++++++++++-- 1 file changed, 100 insertions(+), 7 deletions(-) diff --git a/docs/source/documents/api/learners/drl/pdqn.rst b/docs/source/documents/api/learners/drl/pdqn.rst index 0cc73a35..5b3fb210 100644 --- a/docs/source/documents/api/learners/drl/pdqn.rst +++ b/docs/source/documents/api/learners/drl/pdqn.rst @@ -7,6 +7,41 @@ PDQN_Learner **PyTorch:** +.. py:class:: + xuance.torch.learners.qlearning_family.pdqn_learner.PDQN_Learner(policy, optimizer, scheduler, summary_writer, device, model_dir, gamma, tau) + + :param policy: xxxxxx. + :type policy: xxxxxx + :param optimizer: xxxxxx. + :type optimizer: xxxxxx + :param scheduler: xxxxxx. + :type scheduler: xxxxxx + :param summary_writer: xxxxxx. + :type summary_writer: xxxxxx + :param device: xxxxxx. + :type device: xxxxxx + :param model_dir: xxxxxx. + :type model_dir: xxxxxx + :param gamma: xxxxxx. + :type gamma: xxxxxx + :param tau: xxxxxx. + :type tau: xxxxxx + +.. py:function:: + xuance.torch.learners.qlearning_family.pdqn_learner.PDQN_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch) + + :param obs_batch: xxxxxx. + :type obs_batch: xxxxxx + :param act_batch: xxxxxx. + :type act_batch: xxxxxx + :param rew_batch: xxxxxx. + :type rew_batch: xxxxxx + :param next_batch: xxxxxx. + :type next_batch: xxxxxx + :param terminal_batch: xxxxxx. + :type terminal_batch: xxxxxx + :return: None. + :rtype: xxxxxx .. raw:: html @@ -28,18 +63,76 @@ Source Code ----------------- .. tabs:: - - .. group-tab:: PyTorch - .. code-block:: python3 + .. group-tab:: PyTorch + .. code-block:: python + from xuance.torch.learners import * + class PDQN_Learner(Learner): + def __init__(self, + policy: nn.Module, + optimizers: Sequence[torch.optim.Optimizer], + schedulers: Sequence[torch.optim.lr_scheduler._LRScheduler], + summary_writer: Optional[SummaryWriter] = None, + device: Optional[Union[int, str, torch.device]] = None, + model_dir: str = "./", + gamma: float = 0.99, + tau: float = 0.01): + self.tau = tau + self.gamma = gamma + super(PDQN_Learner, self).__init__(policy, optimizers, schedulers, summary_writer, device, model_dir) - .. group-tab:: TensorFlow + def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch): + self.iterations += 1 + obs_batch = torch.as_tensor(obs_batch, device=self.device) + hyact_batch = torch.as_tensor(act_batch, device=self.device) + disact_batch = hyact_batch[:, 0].long() + conact_batch = hyact_batch[:, 1:] + rew_batch = torch.as_tensor(rew_batch, device=self.device) + next_batch = torch.as_tensor(next_batch, device=self.device) + ter_batch = torch.as_tensor(terminal_batch, device=self.device) - .. code-block:: python3 + # optimize Q-network + with torch.no_grad(): + target_conact = self.policy.Atarget(next_batch) + target_q = self.policy.Qtarget(next_batch, target_conact) + target_q = torch.max(target_q, 1, keepdim=True)[0].squeeze() - .. group-tab:: MindSpore + target_q = rew_batch + (1 - ter_batch) * self.gamma * target_q - .. code-block:: python3 \ No newline at end of file + eval_qs = self.policy.Qeval(obs_batch, conact_batch) + eval_q = eval_qs.gather(1, disact_batch.view(-1, 1)).squeeze() + q_loss = F.mse_loss(eval_q, target_q) + + self.optimizer[1].zero_grad() + q_loss.backward() + self.optimizer[1].step() + + # optimize actor network + policy_q = self.policy.Qpolicy(obs_batch) + p_loss = - policy_q.mean() + self.optimizer[0].zero_grad() + p_loss.backward() + self.optimizer[0].step() + + if self.scheduler is not None: + self.scheduler[0].step() + self.scheduler[1].step() + + self.policy.soft_update(self.tau) + + self.writer.add_scalar("Q_loss", q_loss.item(), self.iterations) + self.writer.add_scalar("P_loss", q_loss.item(), self.iterations) + self.writer.add_scalar('Qvalue', eval_q.mean().item(), self.iterations) + + + .. group-tab:: TensorFlow + + .. code-block:: python + + + .. group-tab:: MindSpore + + .. code-block:: python \ No newline at end of file