Commit: pdqn.rst APIs (#1)
Ykizi committed Nov 30, 2023
1 parent 535b5b6 commit a33e1f9
Showing 1 changed file with 100 additions and 7 deletions.
107 changes: 100 additions & 7 deletions docs/source/documents/api/learners/drl/pdqn.rst
@@ -7,6 +7,41 @@ PDQN_Learner

**PyTorch:**

.. py:class:: xuance.torch.learners.qlearning_family.pdqn_learner.PDQN_Learner(policy, optimizer, scheduler, summary_writer, device, model_dir, gamma, tau)

  :param policy: the PDQN policy that provides Q-values and continuous action parameters.
  :type policy: nn.Module
  :param optimizer: the optimizers for the actor network and the Q-network.
  :type optimizer: Sequence[torch.optim.Optimizer]
  :param scheduler: the learning-rate schedulers paired with the optimizers.
  :type scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler]
  :param summary_writer: the TensorBoard writer used to log training scalars.
  :type summary_writer: Optional[SummaryWriter]
  :param device: the device on which the model is trained (CPU or GPU).
  :type device: Optional[Union[int, str, torch.device]]
  :param model_dir: the directory where the model is saved.
  :type model_dir: str
  :param gamma: the discount factor.
  :type gamma: float
  :param tau: the soft-update factor for the target networks.
  :type tau: float
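
  A minimal construction sketch is given below. It is not taken from the XuanCe source: the ``policy``, ``optimizers`` and ``schedulers`` objects are assumed to be prepared elsewhere (normally by the PDQN agent), and the hyper-parameter values are illustrative.

  .. code-block:: python

      import torch
      from torch.utils.tensorboard import SummaryWriter
      from xuance.torch.learners.qlearning_family.pdqn_learner import PDQN_Learner

      # Assumed to exist already (built by the PDQN agent in practice):
      #   policy      - a PDQN hybrid-action policy exposing Qeval/Qtarget/Qpolicy/Atarget/soft_update
      #   optimizers  - [actor_optimizer, q_network_optimizer]
      #   schedulers  - learning-rate schedulers matching the two optimizers
      learner = PDQN_Learner(policy,
                             optimizers,
                             schedulers,
                             summary_writer=SummaryWriter("./logs/pdqn"),
                             device="cuda:0",
                             model_dir="./models/pdqn/",
                             gamma=0.99,   # discount factor
                             tau=0.01)     # soft-update rate of the target networks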

.. py:function:: xuance.torch.learners.qlearning_family.pdqn_learner.PDQN_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)

  :param obs_batch: a batch of observations sampled from the replay buffer.
  :type obs_batch: np.ndarray
  :param act_batch: a batch of hybrid actions; column 0 holds the discrete action index and the remaining columns hold the continuous action parameters.
  :type act_batch: np.ndarray
  :param rew_batch: a batch of rewards.
  :type rew_batch: np.ndarray
  :param next_batch: a batch of next observations.
  :type next_batch: np.ndarray
  :param terminal_batch: a batch of terminal flags indicating whether each episode has ended.
  :type terminal_batch: np.ndarray
  :return: None.
  :rtype: None
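
  A sketch of a single call, continuing the constructor example above. The random batch only illustrates the expected layout (column 0 of ``act_batch`` is the discrete action index, the remaining columns are the continuous parameters); in practice the batches come from the agent's replay buffer and their shapes must match the policy's observation and action spaces.

  .. code-block:: python

      import numpy as np

      batch_size, obs_dim, num_disact, conact_dim = 64, 8, 3, 5   # illustrative sizes

      # Hybrid actions: discrete index in column 0, continuous parameters afterwards.
      disact = np.random.randint(0, num_disact, size=(batch_size, 1)).astype(np.float32)
      conact = np.random.uniform(-1.0, 1.0, size=(batch_size, conact_dim)).astype(np.float32)
      act_batch = np.concatenate([disact, conact], axis=1)

      obs_batch = np.random.randn(batch_size, obs_dim).astype(np.float32)
      next_batch = np.random.randn(batch_size, obs_dim).astype(np.float32)
      rew_batch = np.random.randn(batch_size).astype(np.float32)
      terminal_batch = np.zeros(batch_size, dtype=np.float32)

      learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)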

.. raw:: html

@@ -28,18 +63,76 @@ Source Code
-----------------

.. tabs::

  .. group-tab:: PyTorch

    .. code-block:: python

      from xuance.torch.learners import *


      class PDQN_Learner(Learner):
          def __init__(self,
                       policy: nn.Module,
                       optimizers: Sequence[torch.optim.Optimizer],
                       schedulers: Sequence[torch.optim.lr_scheduler._LRScheduler],
                       summary_writer: Optional[SummaryWriter] = None,
                       device: Optional[Union[int, str, torch.device]] = None,
                       model_dir: str = "./",
                       gamma: float = 0.99,
                       tau: float = 0.01):
              self.tau = tau
              self.gamma = gamma
              super(PDQN_Learner, self).__init__(policy, optimizers, schedulers, summary_writer, device, model_dir)

          def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
              self.iterations += 1
              obs_batch = torch.as_tensor(obs_batch, device=self.device)
              hyact_batch = torch.as_tensor(act_batch, device=self.device)
              disact_batch = hyact_batch[:, 0].long()  # discrete action indices
              conact_batch = hyact_batch[:, 1:]        # continuous action parameters
              rew_batch = torch.as_tensor(rew_batch, device=self.device)
              next_batch = torch.as_tensor(next_batch, device=self.device)
              ter_batch = torch.as_tensor(terminal_batch, device=self.device)

              # optimize Q-network
              with torch.no_grad():
                  target_conact = self.policy.Atarget(next_batch)
                  target_q = self.policy.Qtarget(next_batch, target_conact)
                  target_q = torch.max(target_q, 1, keepdim=True)[0].squeeze()
                  target_q = rew_batch + (1 - ter_batch) * self.gamma * target_q

              eval_qs = self.policy.Qeval(obs_batch, conact_batch)
              eval_q = eval_qs.gather(1, disact_batch.view(-1, 1)).squeeze()
              q_loss = F.mse_loss(eval_q, target_q)
              self.optimizer[1].zero_grad()
              q_loss.backward()
              self.optimizer[1].step()

              # optimize actor network
              policy_q = self.policy.Qpolicy(obs_batch)
              p_loss = - policy_q.mean()
              self.optimizer[0].zero_grad()
              p_loss.backward()
              self.optimizer[0].step()

              if self.scheduler is not None:
                  self.scheduler[0].step()
                  self.scheduler[1].step()

              # soft update of the target networks
              self.policy.soft_update(self.tau)

              self.writer.add_scalar("Q_loss", q_loss.item(), self.iterations)
              self.writer.add_scalar("P_loss", p_loss.item(), self.iterations)
              self.writer.add_scalar('Qvalue', eval_q.mean().item(), self.iterations)

  .. group-tab:: TensorFlow

    .. code-block:: python

  .. group-tab:: MindSpore

    .. code-block:: python
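
A minimal, generic sketch of the Polyak (soft) target update performed by ``self.policy.soft_update(self.tau)`` at the end of each call to ``update``. This is not the XuanCe implementation; it only illustrates the rule ``target <- tau * source + (1 - tau) * target`` over a pair of plain ``nn.Module`` objects.

.. code-block:: python

    import torch
    import torch.nn as nn

    @torch.no_grad()
    def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
        """Polyak averaging of the target parameters towards the online ones."""
        for tp, sp in zip(target.parameters(), source.parameters()):
            tp.data.mul_(1.0 - tau).add_(tau * sp.data)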
