diff --git a/docs/source/documents/api/learners/drl/a2c.rst b/docs/source/documents/api/learners/drl/a2c.rst
index b5f48dde..4d811e8d 100644
--- a/docs/source/documents/api/learners/drl/a2c.rst
+++ b/docs/source/documents/api/learners/drl/a2c.rst
@@ -7,6 +7,39 @@ A2C_Learner
 
 **PyTorch:**
 
+.. py:class::
+  xuance.torch.learners.policy_gradient.a2c_learner.A2C_Learner(policy, optimizer, scheduler, device, model_dir, vf_coef, ent_coef, clip_grad)
+
+  :param policy: The actor-critic policy to be trained.
+  :type policy: nn.Module
+  :param optimizer: The optimizer that updates the policy parameters.
+  :type optimizer: torch.optim.Optimizer
+  :param scheduler: The learning-rate scheduler.
+  :type scheduler: torch.optim.lr_scheduler._LRScheduler
+  :param device: The device on which the computation is performed (CPU or GPU).
+  :type device: str, int, or torch.device
+  :param model_dir: The directory where the model is saved.
+  :type model_dir: str
+  :param vf_coef: The coefficient of the value-function (critic) loss.
+  :type vf_coef: float
+  :param ent_coef: The coefficient of the entropy regularization term.
+  :type ent_coef: float
+  :param clip_grad: The maximum norm used to clip the gradients.
+  :type clip_grad: float
+
+.. py:function::
+  xuance.torch.learners.policy_gradient.a2c_learner.A2C_Learner.update(obs_batch, act_batch, ret_batch, adv_batch)
+
+  :param obs_batch: A batch of observations.
+  :type obs_batch: np.ndarray
+  :param act_batch: A batch of actions.
+  :type act_batch: np.ndarray
+  :param ret_batch: A batch of discounted returns.
+  :type ret_batch: np.ndarray
+  :param adv_batch: A batch of advantage estimates.
+  :type adv_batch: np.ndarray
+  :return: The training information, including losses, entropy, learning rate, and predicted values.
+  :rtype: dict
 
 .. raw:: html
 
@@ -28,18 +61,70 @@ Source Code
 -----------------
 
 .. tabs::
-
-    .. group-tab:: PyTorch
 
-        .. code-block:: python3
+    .. group-tab:: PyTorch
+
+        .. code-block:: python
+
+            from xuance.torch.learners import *
 
-    .. group-tab:: TensorFlow
+            class A2C_Learner(Learner):
+                def __init__(self,
+                             policy: nn.Module,
+                             optimizer: torch.optim.Optimizer,
+                             scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+                             device: Optional[Union[int, str, torch.device]] = None,
+                             model_dir: str = "./",
+                             vf_coef: float = 0.25,
+                             ent_coef: float = 0.005,
+                             clip_grad: Optional[float] = None):
+                    super(A2C_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)
+                    self.vf_coef = vf_coef
+                    self.ent_coef = ent_coef
+                    self.clip_grad = clip_grad
 
-        .. code-block:: python3
+                def update(self, obs_batch, act_batch, ret_batch, adv_batch):
+                    self.iterations += 1
+                    act_batch = torch.as_tensor(act_batch, device=self.device)
+                    ret_batch = torch.as_tensor(ret_batch, device=self.device)
+                    adv_batch = torch.as_tensor(adv_batch, device=self.device)
+                    outputs, a_dist, v_pred = self.policy(obs_batch)
+                    log_prob = a_dist.log_prob(act_batch)
 
-    .. group-tab:: MindSpore
+                    a_loss = -(adv_batch * log_prob).mean()
+                    c_loss = F.mse_loss(v_pred, ret_batch)
+                    e_loss = a_dist.entropy().mean()
 
-        .. code-block:: python3
\ No newline at end of file
+                    loss = a_loss - self.ent_coef * e_loss + self.vf_coef * c_loss
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.clip_grad)
+                    self.optimizer.step()
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+
+                    # Logger
+                    lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+
+                    info = {
+                        "actor-loss": a_loss.item(),
+                        "critic-loss": c_loss.item(),
+                        "entropy": e_loss.item(),
+                        "learning_rate": lr,
+                        "predict_value": v_pred.mean().item()
+                    }
+
+                    return info
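+
+        The snippet below is an illustrative sketch rather than part of the XuanCe source: it
+        drives ``A2C_Learner.update`` with a hypothetical ``ToyPolicy`` (defined here only for
+        the example) that returns the ``(outputs, distribution, value)`` triple the learner
+        expects, together with randomly generated rollout batches.
+
+        .. code-block:: python
+
+            import numpy as np
+            import torch
+            import torch.nn as nn
+            from torch.distributions import Categorical
+            from xuance.torch.learners.policy_gradient.a2c_learner import A2C_Learner
+
+            class ToyPolicy(nn.Module):
+                """Stand-in actor-critic exposing the interface A2C_Learner.update expects."""
+                def __init__(self, obs_dim=4, n_actions=2):
+                    super().__init__()
+                    self.actor = nn.Linear(obs_dim, n_actions)
+                    self.critic = nn.Linear(obs_dim, 1)
+
+                def forward(self, obs):
+                    obs = torch.as_tensor(obs, dtype=torch.float32)
+                    dist = Categorical(logits=self.actor(obs))   # action distribution
+                    value = self.critic(obs).squeeze(-1)         # state-value estimate
+                    return {}, dist, value
+
+            policy = ToyPolicy()
+            optimizer = torch.optim.Adam(policy.parameters(), lr=7e-4)
+            learner = A2C_Learner(policy, optimizer, device="cpu",
+                                  vf_coef=0.25, ent_coef=0.01, clip_grad=0.5)
+
+            # A fake on-policy batch; in practice these come from the rollout buffer.
+            obs = np.random.randn(32, 4).astype(np.float32)
+            acts = np.random.randint(0, 2, size=(32,))
+            rets = np.random.randn(32).astype(np.float32)
+            advs = np.random.randn(32).astype(np.float32)
+            info = learner.update(obs, acts, rets, advs)
+            print(info["actor-loss"], info["critic-loss"])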
+
+
+    .. group-tab:: TensorFlow
+
+        .. code-block:: python
+
+
+    .. group-tab:: MindSpore
+
+        .. code-block:: python
\ No newline at end of file
diff --git a/docs/source/documents/api/learners/drl/c51.rst b/docs/source/documents/api/learners/drl/c51.rst
index f1797c41..00020430 100644
--- a/docs/source/documents/api/learners/drl/c51.rst
+++ b/docs/source/documents/api/learners/drl/c51.rst
@@ -7,6 +7,39 @@ C51_Learner
 
 **PyTorch:**
 
+.. py:class::
+  xuance.torch.learners.qlearning_family.c51_learner.C51_Learner(policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)
+
+  :param policy: The categorical (C51) Q-network policy to be trained.
+  :type policy: nn.Module
+  :param optimizer: The optimizer that updates the policy parameters.
+  :type optimizer: torch.optim.Optimizer
+  :param scheduler: The learning-rate scheduler.
+  :type scheduler: torch.optim.lr_scheduler._LRScheduler
+  :param device: The device on which the computation is performed (CPU or GPU).
+  :type device: str, int, or torch.device
+  :param model_dir: The directory where the model is saved.
+  :type model_dir: str
+  :param gamma: The discount factor.
+  :type gamma: float
+  :param sync_frequency: The number of updates between two hard synchronizations of the target network.
+  :type sync_frequency: int
+
+.. py:function::
+  xuance.torch.learners.qlearning_family.c51_learner.C51_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)
+
+  :param obs_batch: A batch of observations.
+  :type obs_batch: np.ndarray
+  :param act_batch: A batch of actions.
+  :type act_batch: np.ndarray
+  :param rew_batch: A batch of rewards.
+  :type rew_batch: np.ndarray
+  :param next_batch: A batch of next observations.
+  :type next_batch: np.ndarray
+  :param terminal_batch: A batch of terminal flags marking the ends of episodes.
+  :type terminal_batch: np.ndarray
+  :return: The training information, including the loss and the learning rate.
+  :rtype: dict
 
 .. raw:: html
 
@@ -28,18 +61,72 @@ Source Code
 -----------------
 
 .. tabs::
-
-    .. group-tab:: PyTorch
 
-        .. code-block:: python3
+    .. group-tab:: PyTorch
+
+        .. code-block:: python
+
+            from xuance.torch.learners import *
 
-    .. group-tab:: TensorFlow
+            class C51_Learner(Learner):
+                def __init__(self,
+                             policy: nn.Module,
+                             optimizer: torch.optim.Optimizer,
+                             scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+                             device: Optional[Union[int, str, torch.device]] = None,
+                             model_dir: str = "./",
+                             gamma: float = 0.99,
+                             sync_frequency: int = 100):
+                    self.gamma = gamma
+                    self.sync_frequency = sync_frequency
+                    super(C51_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)
 
-        .. code-block:: python3
+                def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
+                    self.iterations += 1
+                    act_batch = torch.as_tensor(act_batch, device=self.device).long()
+                    rew_batch = torch.as_tensor(rew_batch, device=self.device)
+                    ter_batch = torch.as_tensor(terminal_batch, device=self.device)
+                    _, _, evalZ = self.policy(obs_batch)
+                    _, targetA, targetZ = self.policy.target(next_batch)
 
-    .. group-tab:: MindSpore
+                    current_dist = (evalZ * F.one_hot(act_batch, evalZ.shape[1]).unsqueeze(-1)).sum(1)
+                    target_dist = (targetZ * F.one_hot(targetA.detach(), evalZ.shape[1]).unsqueeze(-1)).sum(1).detach()
 
-        .. code-block:: python3
\ No newline at end of file
+                    current_supports = self.policy.supports
+                    next_supports = rew_batch.unsqueeze(1) + self.gamma * self.policy.supports * (1 - ter_batch.unsqueeze(1))
+                    next_supports = next_supports.clamp(self.policy.vmin, self.policy.vmax)
+
+                    projection = 1 - (next_supports.unsqueeze(-1) - current_supports.unsqueeze(0)).abs() / self.policy.deltaz
+                    target_dist = torch.bmm(target_dist.unsqueeze(1), projection.clamp(0, 1)).squeeze(1)
+                    loss = -(target_dist * torch.log(current_dist + 1e-8)).sum(1).mean()
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    self.optimizer.step()
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+                    # hard update for target network
+                    if self.iterations % self.sync_frequency == 0:
+                        self.policy.copy_target()
+                    lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+
+                    info = {
+                        "Qloss": loss.item(),
+                        "learning_rate": lr
+                    }
+
+                    return info
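+
+        The heart of this update is the categorical projection that maps the Bellman-shifted
+        atoms back onto the fixed support. The standalone sketch below (illustrative values
+        only, not taken from XuanCe) reproduces that projection with small explicit tensors so
+        the broadcasting can be inspected.
+
+        .. code-block:: python
+
+            import torch
+
+            vmin, vmax, n_atoms = -2.0, 2.0, 5
+            supports = torch.linspace(vmin, vmax, n_atoms)   # fixed atom locations z_i
+            deltaz = (vmax - vmin) / (n_atoms - 1)           # atom spacing
+
+            gamma = 0.99
+            rew = torch.tensor([0.5, -1.0])                  # batch of two rewards
+            done = torch.tensor([0.0, 1.0])                  # the second transition is terminal
+            target_dist = torch.tensor([[0.1, 0.2, 0.4, 0.2, 0.1],
+                                        [0.2, 0.2, 0.2, 0.2, 0.2]])
+
+            # Shift and clamp the atoms through the Bellman backup, as in update().
+            next_supports = rew.unsqueeze(1) + gamma * supports * (1 - done.unsqueeze(1))
+            next_supports = next_supports.clamp(vmin, vmax)
+
+            # Project the shifted atoms back onto the fixed support.
+            projection = 1 - (next_supports.unsqueeze(-1) - supports.unsqueeze(0)).abs() / deltaz
+            projected = torch.bmm(target_dist.unsqueeze(1), projection.clamp(0, 1)).squeeze(1)
+            print(projected.sum(-1))   # every row still sums to 1 (up to floating point)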
+
+
+    .. group-tab:: TensorFlow
+
+        .. code-block:: python
+
+
+    .. group-tab:: MindSpore
+
+        .. code-block:: python
\ No newline at end of file
diff --git a/docs/source/documents/api/learners/drl/ddpg.rst b/docs/source/documents/api/learners/drl/ddpg.rst
index e5a55a97..9dc6e734 100644
--- a/docs/source/documents/api/learners/drl/ddpg.rst
+++ b/docs/source/documents/api/learners/drl/ddpg.rst
@@ -7,6 +7,39 @@ DDPG_Learner
 
 **PyTorch:**
 
+.. py:class::
+  xuance.torch.learners.policy_gradient.ddpg_learner.DDPG_Learner(policy, optimizer, scheduler, device, model_dir, gamma, tau)
+
+  :param policy: The deterministic actor-critic policy to be trained.
+  :type policy: nn.Module
+  :param optimizer: The optimizers of the actor and the critic networks.
+  :type optimizer: Sequence[torch.optim.Optimizer]
+  :param scheduler: The learning-rate schedulers of the actor and the critic networks.
+  :type scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler]
+  :param device: The device on which the computation is performed (CPU or GPU).
+  :type device: str, int, or torch.device
+  :param model_dir: The directory where the model is saved.
+  :type model_dir: str
+  :param gamma: The discount factor.
+  :type gamma: float
+  :param tau: The soft-update coefficient for the target networks.
+  :type tau: float
+
+.. py:function::
+  xuance.torch.learners.policy_gradient.ddpg_learner.DDPG_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)
+
+  :param obs_batch: A batch of observations.
+  :type obs_batch: np.ndarray
+  :param act_batch: A batch of actions.
+  :type act_batch: np.ndarray
+  :param rew_batch: A batch of rewards.
+  :type rew_batch: np.ndarray
+  :param next_batch: A batch of next observations.
+  :type next_batch: np.ndarray
+  :param terminal_batch: A batch of terminal flags marking the ends of episodes.
+  :type terminal_batch: np.ndarray
+  :return: The training information, including losses, Q-values, and learning rates.
+  :rtype: dict
 
 .. raw:: html
 
@@ -28,18 +61,77 @@ Source Code
 -----------------
 
 .. tabs::
-
-    .. group-tab:: PyTorch
 
-        .. code-block:: python3
+    .. group-tab:: PyTorch
+
+        .. code-block:: python
+
+            from xuance.torch.learners import *
 
-    .. group-tab:: TensorFlow
+            class DDPG_Learner(Learner):
+                def __init__(self,
+                             policy: nn.Module,
+                             optimizers: Sequence[torch.optim.Optimizer],
+                             schedulers: Sequence[torch.optim.lr_scheduler._LRScheduler],
+                             device: Optional[Union[int, str, torch.device]] = None,
+                             model_dir: str = "./",
+                             gamma: float = 0.99,
+                             tau: float = 0.01):
+                    self.tau = tau
+                    self.gamma = gamma
+                    super(DDPG_Learner, self).__init__(policy, optimizers, schedulers, device, model_dir)
 
-        .. code-block:: python3
+                def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
+                    self.iterations += 1
+                    act_batch = torch.as_tensor(act_batch, device=self.device)
+                    rew_batch = torch.as_tensor(rew_batch, device=self.device)
+                    ter_batch = torch.as_tensor(terminal_batch, device=self.device)
+                    # critic update
+                    action_q = self.policy.Qaction(obs_batch, act_batch)
+                    # with torch.no_grad():
+                    target_q = self.policy.Qtarget(next_batch)
+                    backup = rew_batch + (1 - ter_batch) * self.gamma * target_q
+                    q_loss = F.mse_loss(action_q, backup.detach())
+                    self.optimizer[1].zero_grad()
+                    q_loss.backward()
+                    self.optimizer[1].step()
 
-    .. group-tab:: MindSpore
+                    # actor update
+                    policy_q = self.policy.Qpolicy(obs_batch)
+                    p_loss = -policy_q.mean()
+                    self.optimizer[0].zero_grad()
+                    p_loss.backward()
+                    self.optimizer[0].step()
 
-        .. code-block:: python3
\ No newline at end of file
+                    if self.scheduler is not None:
+                        self.scheduler[0].step()
+                        self.scheduler[1].step()
+
+                    self.policy.soft_update(self.tau)
+
+                    actor_lr = self.optimizer[0].state_dict()['param_groups'][0]['lr']
+                    critic_lr = self.optimizer[1].state_dict()['param_groups'][0]['lr']
+
+                    info = {
+                        "Qloss": q_loss.item(),
+                        "Ploss": p_loss.item(),
+                        "Qvalue": action_q.mean().item(),
+                        "actor_lr": actor_lr,
+                        "critic_lr": critic_lr
+                    }
+
+                    return info
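+
+        ``self.policy.soft_update(self.tau)`` is expected to blend the target networks toward
+        the online networks after every update. A generic Polyak-averaging sketch of that
+        operation is shown below; it is an assumption about the behaviour of ``soft_update``,
+        not XuanCe's exact implementation.
+
+        .. code-block:: python
+
+            import copy
+            import torch
+            import torch.nn as nn
+
+            def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
+                """Polyak averaging: target <- tau * source + (1 - tau) * target."""
+                with torch.no_grad():
+                    for tp, sp in zip(target.parameters(), source.parameters()):
+                        tp.mul_(1.0 - tau).add_(tau * sp)
+
+            actor = nn.Linear(8, 2)
+            target_actor = copy.deepcopy(actor)
+            soft_update(target_actor, actor, tau=0.01)   # target drifts 1% toward the online actor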
+
+
+    .. group-tab:: TensorFlow
+
+        .. code-block:: python
+
+
+    .. group-tab:: MindSpore
+
+        .. code-block:: python
\ No newline at end of file
diff --git a/docs/source/documents/api/learners/drl/ddqn.rst b/docs/source/documents/api/learners/drl/ddqn.rst
index 61da154d..787c8fa0 100644
--- a/docs/source/documents/api/learners/drl/ddqn.rst
+++ b/docs/source/documents/api/learners/drl/ddqn.rst
@@ -7,6 +7,39 @@ DDQN_Learner
 
 **PyTorch:**
 
+.. py:class::
+  xuance.torch.learners.qlearning_family.ddqn_learner.DDQN_Learner(policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)
+
+  :param policy: The Q-network policy to be trained.
+  :type policy: nn.Module
+  :param optimizer: The optimizer that updates the policy parameters.
+  :type optimizer: torch.optim.Optimizer
+  :param scheduler: The learning-rate scheduler.
+  :type scheduler: torch.optim.lr_scheduler._LRScheduler
+  :param device: The device on which the computation is performed (CPU or GPU).
+  :type device: str, int, or torch.device
+  :param model_dir: The directory where the model is saved.
+  :type model_dir: str
+  :param gamma: The discount factor.
+  :type gamma: float
+  :param sync_frequency: The number of updates between two hard synchronizations of the target network.
+  :type sync_frequency: int
+
+.. py:function::
+  xuance.torch.learners.qlearning_family.ddqn_learner.DDQN_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)
+
+  :param obs_batch: A batch of observations.
+  :type obs_batch: np.ndarray
+  :param act_batch: A batch of actions.
+  :type act_batch: np.ndarray
+  :param rew_batch: A batch of rewards.
+  :type rew_batch: np.ndarray
+  :param next_batch: A batch of next observations.
+  :type next_batch: np.ndarray
+  :param terminal_batch: A batch of terminal flags marking the ends of episodes.
+  :type terminal_batch: np.ndarray
+  :return: The training information, including the loss, the learning rate, and the predicted Q-values.
+  :rtype: dict
 
 .. raw:: html
 
@@ -28,18 +61,73 @@ Source Code
 -----------------
 
 .. tabs::
-
-    .. group-tab:: PyTorch
 
-        .. code-block:: python3
+    .. group-tab:: PyTorch
+
+        .. code-block:: python
+
+            from xuance.torch.learners import *
 
-    .. group-tab:: TensorFlow
+            class DDQN_Learner(Learner):
+                def __init__(self,
+                             policy: nn.Module,
+                             optimizer: torch.optim.Optimizer,
+                             scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+                             device: Optional[Union[int, str, torch.device]] = None,
+                             model_dir: str = "./",
+                             gamma: float = 0.99,
+                             sync_frequency: int = 100):
+                    self.gamma = gamma
+                    self.sync_frequency = sync_frequency
+                    super(DDQN_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)
 
-        .. code-block:: python3
+                def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
+                    self.iterations += 1
+                    act_batch = torch.as_tensor(act_batch, device=self.device)
+                    rew_batch = torch.as_tensor(rew_batch, device=self.device)
+                    ter_batch = torch.as_tensor(terminal_batch, device=self.device)
 
-    .. group-tab:: MindSpore
+                    _, _, evalQ = self.policy(obs_batch)
+                    _, targetA, targetQ = self.policy(next_batch)
 
-        .. code-block:: python3
\ No newline at end of file
+                    targetA = F.one_hot(targetA, targetQ.shape[-1])
+                    targetQ = (targetQ * targetA).sum(dim=-1)
+                    targetQ = rew_batch + self.gamma * (1 - ter_batch) * targetQ
+                    predictQ = (evalQ * F.one_hot(act_batch.long(), evalQ.shape[1])).sum(dim=-1)
+
+                    loss = F.mse_loss(predictQ, targetQ)
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    self.optimizer.step()
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+
+                    # hard update for target network
+                    if self.iterations % self.sync_frequency == 0:
+                        self.policy.copy_target()
+
+                    lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+
+                    info = {
+                        "Qloss": loss.item(),
+                        "learning_rate": lr,
+                        "predictQ": predictQ.mean().item()
+                    }
+
+                    return info
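+
+        The target here is assembled by masking the next-state Q-values with a one-hot of the
+        greedy action and then applying the Bellman backup. The toy example below (illustrative
+        numbers only, not taken from XuanCe) reproduces that computation step by step.
+
+        .. code-block:: python
+
+            import torch
+            import torch.nn.functional as F
+
+            gamma = 0.99
+            rew = torch.tensor([1.0, 0.0])
+            done = torch.tensor([0.0, 1.0])                  # the second transition is terminal
+            targetQ = torch.tensor([[0.5, 2.0, 1.0],
+                                    [3.0, 0.1, 0.2]])        # Q-values for the next observations
+            targetA = targetQ.argmax(dim=-1)                 # greedy next actions: [1, 0]
+
+            # Pick Q(s', a*) with a one-hot mask, then form r + gamma * (1 - done) * Q(s', a*).
+            q_next = (targetQ * F.one_hot(targetA, targetQ.shape[-1])).sum(dim=-1)   # [2.0, 3.0]
+            target = rew + gamma * (1 - done) * q_next                               # [2.98, 0.0]
+            print(target)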
+
+
+    .. group-tab:: TensorFlow
+
+        .. code-block:: python
+
+
+    .. group-tab:: MindSpore
+
+        .. code-block:: python
\ No newline at end of file
diff --git a/docs/source/documents/api/learners/drl/dqn.rst b/docs/source/documents/api/learners/drl/dqn.rst
index a4f94ea3..1c05bf78 100644
--- a/docs/source/documents/api/learners/drl/dqn.rst
+++ b/docs/source/documents/api/learners/drl/dqn.rst
@@ -7,6 +7,39 @@ DQN_Learner
 
 **PyTorch:**
 
+.. py:class::
+  xuance.torch.learners.qlearning_family.dqn_learner.DQN_Learner(policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)
+
+  :param policy: The Q-network policy to be trained.
+  :type policy: nn.Module
+  :param optimizer: The optimizer that updates the policy parameters.
+  :type optimizer: torch.optim.Optimizer
+  :param scheduler: The learning-rate scheduler.
+  :type scheduler: torch.optim.lr_scheduler._LRScheduler
+  :param device: The device on which the computation is performed (CPU or GPU).
+  :type device: str, int, or torch.device
+  :param model_dir: The directory where the model is saved.
+  :type model_dir: str
+  :param gamma: The discount factor.
+  :type gamma: float
+  :param sync_frequency: The number of updates between two hard synchronizations of the target network.
+  :type sync_frequency: int
+
+.. py:function::
+  xuance.torch.learners.qlearning_family.dqn_learner.DQN_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)
+
+  :param obs_batch: A batch of observations.
+  :type obs_batch: np.ndarray
+  :param act_batch: A batch of actions.
+  :type act_batch: np.ndarray
+  :param rew_batch: A batch of rewards.
+  :type rew_batch: np.ndarray
+  :param next_batch: A batch of next observations.
+  :type next_batch: np.ndarray
+  :param terminal_batch: A batch of terminal flags marking the ends of episodes.
+  :type terminal_batch: np.ndarray
+  :return: The training information, including the loss, the learning rate, and the predicted Q-values.
+  :rtype: dict
 
 .. raw:: html
 
@@ -28,18 +61,71 @@ Source Code
 -----------------
 
 .. tabs::
-
-    .. group-tab:: PyTorch
 
-        .. code-block:: python3
+    .. group-tab:: PyTorch
+
+        .. code-block:: python
+
+            from xuance.torch.learners import *
 
-    .. group-tab:: TensorFlow
+            class DQN_Learner(Learner):
+                def __init__(self,
+                             policy: nn.Module,
+                             optimizer: torch.optim.Optimizer,
+                             scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+                             device: Optional[Union[int, str, torch.device]] = None,
+                             model_dir: str = "./",
+                             gamma: float = 0.99,
+                             sync_frequency: int = 100):
+                    self.gamma = gamma
+                    self.sync_frequency = sync_frequency
+                    super(DQN_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)
 
-        .. code-block:: python3
+                def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
+                    self.iterations += 1
+                    act_batch = torch.as_tensor(act_batch, device=self.device)
+                    rew_batch = torch.as_tensor(rew_batch, device=self.device)
+                    ter_batch = torch.as_tensor(terminal_batch, device=self.device)
 
-    .. group-tab:: MindSpore
+                    _, _, evalQ = self.policy(obs_batch)
+                    _, _, targetQ = self.policy.target(next_batch)
+                    targetQ = targetQ.max(dim=-1).values
+                    targetQ = rew_batch + self.gamma * (1 - ter_batch) * targetQ
+                    predictQ = (evalQ * F.one_hot(act_batch.long(), evalQ.shape[1])).sum(dim=-1)
 
-        .. code-block:: python3
\ No newline at end of file
+                    loss = F.mse_loss(predictQ, targetQ)
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    self.optimizer.step()
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+
+                    # hard update for target network
+                    if self.iterations % self.sync_frequency == 0:
+                        self.policy.copy_target()
+                    lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+
+                    info = {
+                        "Qloss": loss.item(),
+                        "learning_rate": lr,
+                        "predictQ": predictQ.mean().item()
+                    }
+
+                    return info
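+
+        The snippet below is an illustrative sketch rather than part of the XuanCe source: it
+        builds a hypothetical ``ToyQPolicy`` exposing the interface ``DQN_Learner.update`` relies
+        on (``policy(obs)``, ``policy.target(obs)`` and ``copy_target``) and feeds it a random
+        replay batch.
+
+        .. code-block:: python
+
+            import copy
+            import numpy as np
+            import torch
+            import torch.nn as nn
+            from xuance.torch.learners.qlearning_family.dqn_learner import DQN_Learner
+
+            class ToyQPolicy(nn.Module):
+                """Stand-in Q-network policy with an online net, a target net, and copy_target."""
+                def __init__(self, obs_dim=4, n_actions=3):
+                    super().__init__()
+                    self.eval_net = nn.Linear(obs_dim, n_actions)
+                    self.target_net = copy.deepcopy(self.eval_net)
+
+                def forward(self, obs):
+                    q = self.eval_net(torch.as_tensor(obs, dtype=torch.float32))
+                    return {}, q.argmax(-1), q
+
+                def target(self, obs):
+                    with torch.no_grad():
+                        q = self.target_net(torch.as_tensor(obs, dtype=torch.float32))
+                    return {}, q.argmax(-1), q
+
+                def copy_target(self):
+                    self.target_net.load_state_dict(self.eval_net.state_dict())
+
+            policy = ToyQPolicy()
+            optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
+            learner = DQN_Learner(policy, optimizer, device="cpu", gamma=0.99, sync_frequency=100)
+
+            # A fake replay batch; in practice these come from the experience replay buffer.
+            obs = np.random.randn(8, 4).astype(np.float32)
+            acts = np.random.randint(0, 3, size=(8,))
+            rews = np.random.randn(8).astype(np.float32)
+            next_obs = np.random.randn(8, 4).astype(np.float32)
+            dones = np.zeros(8, dtype=np.float32)
+            info = learner.update(obs, acts, rews, next_obs, dones)
+            print(info["Qloss"], info["predictQ"])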
+
+
+    .. group-tab:: TensorFlow
+
+        .. code-block:: python
+
+
+    .. group-tab:: MindSpore
+
+        .. code-block:: python
\ No newline at end of file