diff --git a/CHANGELOG.md b/CHANGELOG.md index dd50651c1..fb786988d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,12 @@ - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Improved typing for `exploration_noise` and within Collector. #1063 +- Better variable names related to model outputs (logits, dist input etc.). #1032 +- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc., +instead of just `nn.Module`. #1032 +- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032 +- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032 +- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032 ### Breaking Changes @@ -21,6 +27,8 @@ expicitly or pass `reset_before_collect=True` . #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 +- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both +continuous and discrete cases. #1032 ### Tests - Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081 diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 37a1f933e..00f7f27b9 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -69,7 +69,7 @@ "from tianshou.policy import BasePolicy\n", "from tianshou.policy.modelfree.pg import (\n", " PGTrainingStats,\n", - " TDistributionFunction,\n", + " TDistFnDiscrOrCont,\n", " TPGTrainingStats,\n", ")\n", "from tianshou.utils import RunningMeanStd\n", @@ -339,7 +339,7 @@ " *,\n", " actor: torch.nn.Module,\n", " optim: torch.optim.Optimizer,\n", - " dist_fn: TDistributionFunction,\n", + " dist_fn: TDistFnDiscrOrCont,\n", " action_space: gym.Space,\n", " discount_factor: float = 0.99,\n", " observation_space: gym.Space | None = None,\n", diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 63ee791eb..be730ff0a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -257,3 +257,8 @@ macOS joblib master Panchenko +BA +BH +BO +BD + diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 705acaa00..3ee3709bd 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) # expert replay buffer dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 95b645dc3..6caac9898 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 
1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: A2CPolicy = A2CPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 454565a46..e8ee97cae 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: NPGPolicy = NPGPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index c0d868cf2..218b95d07 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: PPOPolicy = PPOPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index f4a86934a..06e2bc173 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: PGPolicy = PGPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index b001fd04c..c17ba6c14 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: TRPOPolicy = TRPOPolicy( actor=actor, diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 9fe6f8c3a..0c51f847c 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -2,7 +2,7 @@ import numpy as np import pytest import torch -from torch.distributions import Categorical, Independent, Normal +from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.policy import PPOPolicy from tianshou.utils.net.common import ActorCritic, Net @@ -25,7 +25,11 @@ def policy(request): Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), action_shape=action_space.shape, ) - dist_fn = lambda *logits: Independent(Normal(*logits), 1) + + def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) + 
elif action_type == "discrete": action_space = gym.spaces.Discrete(3) actor = Actor( diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index bcfe6b07b..8e0a50d2c 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: NPGPolicy[NPGTrainingStats] = NPGPolicy( actor=actor, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index d092bc67c..38ddbe8f0 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( actor=actor, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 807061231..9de81283c 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: BasePolicy = TRPOPolicy( actor=actor, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 37eb3352f..68fab728f 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: BasePolicy = GAILPolicy( actor=actor, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 0897d73ad..14b5aacca 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -181,8 +181,9 @@ def get_agents( torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) agent: PPOPolicy = PPOPolicy( actor, diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index 9e9c26655..c8d2aca9e 100644 --- a/tianshou/highlevel/params/dist_fn.py 
+++ b/tianshou/highlevel/params/dist_fn.py @@ -1,40 +1,47 @@ from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any import torch from tianshou.highlevel.env import Environments, EnvType -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont from tianshou.utils.string import ToStringMixin class DistributionFunctionFactory(ToStringMixin, ABC): + # True return type defined in subclasses @abstractmethod - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn( + self, + envs: Environments, + ) -> Callable[[Any], torch.distributions.Distribution]: pass class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete: envs.get_type().assert_discrete(self) return self._dist_fn @staticmethod - def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution: + def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical: return torch.distributions.Categorical(logits=p) class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory): - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: envs.get_type().assert_continuous(self) return self._dist_fn @staticmethod - def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution: - return torch.distributions.Independent(torch.distributions.Normal(*p), 1) + def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution: + loc, scale = loc_scale + return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1) class DistributionFunctionFactoryDefault(DistributionFunctionFactory): - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: match envs.get_type(): case EnvType.DISCRETE: return DistributionFunctionFactoryCategorical().create_dist_fn(envs) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 373a413e8..24674bc8c 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -19,7 +19,7 @@ from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.noise import NoiseFactory -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils import MultipleLRSchedulers from tianshou.utils.string import ToStringMixin @@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. Does not affect training. """ - dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default" + dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default" """ This can either be a function which maps the model output to a torch distribution or a factory for the creation of such a function. 
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 7df7ebd2c..77602a02b 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -213,10 +213,11 @@ def __init__(
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
+        self._action_type: Literal["discrete", "continuous"]
         if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
-            self.action_type = "discrete"
+            self._action_type = "discrete"
         elif isinstance(action_space, Box):
-            self.action_type = "continuous"
+            self._action_type = "continuous"
         else:
             raise ValueError(f"Unsupported action space: {action_space}.")
         self.agent_id = 0
@@ -226,6 +227,10 @@ def __init__(
         self.lr_scheduler = lr_scheduler
         self._compile()
 
+    @property
+    def action_type(self) -> Literal["discrete", "continuous"]:
+        return self._action_type
+
     def set_agent_id(self, agent_id: int) -> None:
         """Set self.agent_id = agent_id, for MARL."""
         self.agent_id = agent_id
diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py
index 1daa9ae71..6e21016d9 100644
--- a/tianshou/policy/imitation/base.py
+++ b/tianshou/policy/imitation/base.py
@@ -16,6 +16,12 @@
 from tianshou.policy import BasePolicy
 from tianshou.policy.base import TLearningRateScheduler, TrainingStats
 
+# Dimension Naming Convention
+# B - Batch Size
+# A - Action
+# D - Dist input (usually 2, loc and scale)
+# H - Dimension of hidden, can be None
+
 
 @dataclass(kw_only=True)
 class ImitationTrainingStats(TrainingStats):
@@ -72,9 +78,20 @@ def forward(
         state: dict | BatchProtocol | np.ndarray | None = None,
         **kwargs: Any,
     ) -> ModelOutputBatchProtocol:
-        logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
-        act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits
-        result = Batch(logits=logits, act=act, state=hidden)
+        # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
+        if self.action_type == "discrete":
+            # If it's discrete, the "actor" is usually a critic that maps obs to action_values
+            # which could then be turned into logits or a Categorical
+            action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
+            act_B = action_values_BA.argmax(dim=1)
+            result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
+        elif self.action_type == "continuous":
+            # If it's continuous, the actor would usually deliver something like loc, scale determining a
+            # Gaussian dist
+            dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
+            result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
+        else:
+            raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
         return cast(ModelOutputBatchProtocol, result)
 
     def learn(
diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py
index 8412e0a60..b5258c141 100644
--- a/tianshou/policy/imitation/discrete_bcq.py
+++ b/tianshou/policy/imitation/discrete_bcq.py
@@ -34,8 +34,7 @@ class DiscreteBCQTrainingStats(DQNTrainingStats):
 class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
     """Implementation of discrete BCQ algorithm. arXiv:1910.01708.
 
-    :param model: a model following the rules in
-        :class:`~tianshou.policy.BasePolicy`. (s -> q_value)
+    :param model: a model following the rules (s_B -> action_values_BA)
     :param imitator: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
     :param optim: a torch.optim for optimizing the model.
diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index dc23cb75a..b63f83e11 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -25,8 +25,7 @@ class DiscreteCQLTrainingStats(QRDQNTrainingStats): class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param action_space: Env's action space. :param min_q_weight: the weight for the cql loss. diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 9a3c2db9f..9c54129da 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -11,6 +11,7 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats +from tianshou.utils.net.discrete import Actor, Critic @dataclass @@ -26,8 +27,9 @@ class DiscreteCRRTrainingStats(PGTrainingStats): class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the action-value critic (i.e., Q function) network. (s -> Q(s, \*)) :param optim: a torch.optim for optimizing the model. @@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | Actor, + critic: torch.nn.Module | Critic, optim: torch.optim.Optimizer, action_space: gym.spaces.Discrete, discount_factor: float = 0.99, diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index c98f7afb8..524f04001 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -15,8 +15,11 @@ from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import PPOPolicy from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.policy.modelfree.ppo import PPOTrainingStats +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -32,7 +35,9 @@ class GailTrainingStats(PPOTrainingStats): class GAILPolicy(PPOPolicy[TGailTrainingStats]): r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. 
(s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, expert_buffer: ReplayBuffer, disc_net: torch.nn.Module, diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 7ef700b0c..f4b2bfe91 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]): """Implementation of TD3+BC. arXiv:2106.06860. :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :class:`~tianshou.policy.BasePolicy`. (s -> actions) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 2aad187dd..d41ccb463 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -11,8 +11,11 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import PGPolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -30,7 +33,9 @@ class A2CTrainingStats(TrainingStats): class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, vf_coef: float = 0.5, ent_coef: float = 0.01, diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index ba3747772..d7196a92b 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -31,7 +31,7 @@ class BDQNTrainingStats(DQNTrainingStats): class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): """Implementation of the Branching dual Q network arXiv:1711.08946. 
- :param model: BranchingNet mapping (obs, state, info) -> logits. + :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. :param estimation_step: the number of steps to look ahead. @@ -156,10 +156,10 @@ def forward( model = getattr(self, model) obs = batch.obs # TODO: this is very contrived, see also iqn.py - obs_next = obs.obs if hasattr(obs, "obs") else obs - logits, hidden = model(obs_next, state=state, info=batch.info) - act = to_numpy(logits.max(dim=-1)[1]) - result = Batch(logits=logits, act=act, state=hidden) + obs_next_BO = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) + act_B = to_numpy(action_values_BA.argmax(dim=-1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index bd4491469..5bfdba0c1 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -23,8 +23,7 @@ class C51TrainingStats(DQNTrainingStats): class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. :param num_atoms: the number of atoms in the support set of the diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index b54860e6b..f21744f72 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -19,6 +19,7 @@ from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.net.continuous import Actor, Critic @dataclass(kw_only=True) @@ -33,8 +34,7 @@ class DDPGTrainingStats(TrainingStats): class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. - :param actor: The actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> model_output) + :param actor: The actor network following the rules (s -> actions) :param actor_optim: The optimizer for actor network. :param critic: The critic network. (s, a -> Q(s, a)) :param critic_optim: The optimizer for critic network. 
@@ -60,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | Actor, actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, + critic: torch.nn.Module | Critic, critic_optim: torch.optim.Optimizer, action_space: gym.Space, tau: float = 0.005, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index d1054f9f6..e9f9b3b4a 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -12,6 +12,7 @@ from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.utils.net.discrete import Actor, Critic @dataclass @@ -25,8 +26,7 @@ class DiscreteSACTrainingStats(SACTrainingStats): class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param actor: the actor network following the rules (s_B -> dist_input_BD) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. @@ -54,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | Actor, actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, + critic: torch.nn.Module | Critic, critic_optim: torch.optim.Optimizer, action_space: gym.spaces.Discrete, - critic2: torch.nn.Module | None = None, + critic2: torch.nn.Module | Critic | None = None, critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, gamma: float = 0.99, @@ -105,13 +105,13 @@ def forward( # type: ignore state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> Batch: - logits, hidden = self.actor(batch.obs, state=state, info=batch.info) - dist = Categorical(logits=logits) + logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Categorical(logits=logits_BA) if self.deterministic_eval and not self.training: - act = dist.mode + act_B = dist.mode else: - act = dist.sample() - return Batch(logits=logits, act=act, state=hidden, dist=dist) + act_B = dist.sample() + return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ad5f7dd0d..e0ada0733 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -17,6 +17,7 @@ ) from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.net.common import Net @dataclass(kw_only=True) @@ -35,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is implemented in the network side, not here). - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. 
:param estimation_step: the number of steps to look ahead. @@ -60,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): def __init__( self, *, - model: torch.nn.Module, + model: torch.nn.Module | Net, optim: torch.optim.Optimizer, # TODO: type violates Liskov substitution principle action_space: gym.spaces.Discrete, @@ -201,12 +201,12 @@ def forward( obs = batch.obs # TODO: this is convoluted! See also other places where this is done. obs_next = obs.obs if hasattr(obs, "obs") else obs - logits, hidden = model(obs_next, state=state, info=batch.info) - q = self.compute_q_value(logits, getattr(obs, "mask", None)) + action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) + q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None)) if self.max_action_num is None: self.max_action_num = q.shape[1] - act = to_numpy(q.max(dim=1)[1]) - result = Batch(logits=logits, act=act, state=hidden) + act_B = to_numpy(q.argmax(dim=1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 9f1b083ee..9c87f9cac 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -27,8 +27,7 @@ class FQFTrainingStats(QRDQNTrainingStats): class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param fraction_model: a FractionProposalNetwork for proposing fractions/quantiles given state. diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index f242c146d..75d76a2dd 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -29,8 +29,7 @@ class IQNTrainingStats(QRDQNTrainingStats): class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): """Implementation of Implicit Quantile Network. arXiv:1806.06923. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. :param sample_size: the number of samples for policy evaluation. 
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index f2939450c..9e04d3feb 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -12,7 +12,10 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import A2CPolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, optim_critic_iters: int = 5, actor_step_size: float = 0.5, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index eb6cb5952..9a148feb7 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,7 +1,7 @@ import warnings from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -24,9 +24,22 @@ from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils import RunningMeanStd +from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.discrete import Actor -# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution] -TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution] +# Dimension Naming Convention +# B - Batch Size +# A - Action +# D - Dist input (usually 2, loc and scale) +# H - Dimension of hidden, can be None + +TDistFnContinuous = Callable[ + [tuple[torch.Tensor, torch.Tensor]], + torch.distributions.Distribution, +] +TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical] + +TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete @dataclass(kw_only=True) @@ -40,8 +53,9 @@ class PGTrainingStats(TrainingStats): class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): """Implementation of REINFORCE algorithm. - :param actor: mapping (s->model_output), should follow the rules in - :class:`~tianshou.policy.BasePolicy`. 
+    :param actor: the actor network following the rules:
+        If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
+        If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
     :param optim: optimizer for actor network.
     :param dist_fn: distribution class for computing the action.
         Maps model_output -> distribution. Typically a Gaussian distribution
@@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
     def __init__(
         self,
         *,
-        actor: torch.nn.Module,
+        actor: torch.nn.Module | ActorProb | Actor,
         optim: torch.optim.Optimizer,
-        dist_fn: TDistributionFunction,
+        dist_fn: TDistFnDiscrOrCont,
         action_space: gym.Space,
         discount_factor: float = 0.99,
         # TODO: rename to return_normalization?
@@ -175,20 +189,20 @@ def forward(
         Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed
         explanation.
         """
-        # TODO: rename? It's not really logits and there are particular
-        # assumptions about the order of the output and on distribution type
-        logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
-        if isinstance(logits, tuple):
-            dist = self.dist_fn(*logits)
-        else:
-            dist = self.dist_fn(logits)
+        # TODO - ALGO: marked for algorithm refactoring
+        action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
+        # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
+        # therefore action_dist_input_BD is equivalent to logits_BA
+        # If continuous, dist_fn will typically map loc, scale to a distribution (usually a Gaussian)
+        # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
+        dist = self.dist_fn(action_dist_input_BD)
 
-        # in this case, the dist is unused!
         if self.deterministic_eval and not self.training:
-            act = dist.mode
+            act_B = dist.mode
         else:
-            act = dist.sample()
-        result = Batch(logits=logits, act=act, state=hidden, dist=dist)
+            act_B = dist.sample()
+        # act is of dimension BA in continuous case and of dimension B in discrete
+        result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
 
     # TODO: why does mypy complain?
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index fde9e7c79..196cd72e4 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -10,8 +10,11 @@
 from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
 from tianshou.policy import A2CPolicy
 from tianshou.policy.base import TLearningRateScheduler, TrainingStats
-from tianshou.policy.modelfree.pg import TDistributionFunction
+from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
 from tianshou.utils.net.common import ActorCritic
+from tianshou.utils.net.continuous import ActorProb, Critic
+from tianshou.utils.net.discrete import Actor as DiscreteActor
+from tianshou.utils.net.discrete import Critic as DiscreteCritic
 
 
 @dataclass(kw_only=True)
@@ -29,7 +32,9 @@ class PPOTrainingStats(TrainingStats):
 class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # type: ignore[type-var]
     r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
 
-    :param actor: the actor network following the rules in BasePolicy. (s -> logits)
+    :param actor: the actor network following the rules:
+        If `self.action_type == "discrete"`: (`s` -> `action_values_BA`).
+        If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
     :param critic: the critic network.
(s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, eps_clip: float = 0.2, dual_clip: float | None = None, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index b2f5d1e8c..71c36de0c 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -25,8 +25,7 @@ class QRDQNTrainingStats(DQNTrainingStats): class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param action_space: Env's action space. :param discount_factor: in [0, 1]. diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 100a361f4..f9793f4db 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -12,6 +12,7 @@ from tianshou.policy import DDPGPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.ddpg import DDPGTrainingStats +from tianshou.utils.net.continuous import ActorProb @dataclass @@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | ActorProb, actor_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, @@ -150,23 +151,28 @@ def forward( # type: ignore state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> Batch: - loc_scale, h = self.actor(batch.obs, state=state, info=batch.info) - loc, scale = loc_scale - dist = Independent(Normal(loc, scale), 1) + (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc_B, scale_B), 1) if self.deterministic_eval and not self.training: - act = dist.mode + act_B = dist.mode else: - act = dist.rsample() - log_prob = dist.log_prob(act).unsqueeze(-1) + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. 
- squashed_action = torch.tanh(act) + squashed_action = torch.tanh(act_B) log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( -1, keepdim=True, ) - return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob) + return Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=h_BH, + dist=dist, + log_prob=log_prob, + ) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index a4336247f..3b3975473 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -17,6 +17,7 @@ from tianshou.policy import DDPGPolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.conversion import to_optional_float +from tianshou.utils.net.continuous import ActorProb from tianshou.utils.optim import clone_optimizer @@ -36,8 +37,7 @@ class SACTrainingStats(TrainingStats): class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] """Implementation of Soft Actor-Critic. arXiv:1812.05905. - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param actor: the actor network following the rules (s -> dist_input_BD) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. @@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | ActorProb, actor_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, @@ -173,26 +173,25 @@ def forward( # type: ignore state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> DistLogProbBatchProtocol: - logits, hidden = self.actor(batch.obs, state=state, info=batch.info) - assert isinstance(logits, tuple) - dist = Independent(Normal(*logits), 1) + (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) if self.deterministic_eval and not self.training: - act = dist.mode + act_B = dist.mode else: - act = dist.rsample() - log_prob = dist.log_prob(act).unsqueeze(-1) + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. - squashed_action = torch.tanh(act) + squashed_action = torch.tanh(act_B) log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( -1, keepdim=True, ) result = Batch( - logits=logits, + logits=(loc_B, scale_B), act=squashed_action, - state=hidden, + state=hidden_BH, dist=dist, log_prob=log_prob, ) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index dbf7b6589..8c2ae8c98 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t """Implementation of TD3, arXiv:1802.09477. :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :class:`~tianshou.policy.BasePolicy`. 
(s -> actions) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index babc23bfa..e7aa5cfd5 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -11,7 +11,10 @@ from tianshou.policy import NPGPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -25,7 +28,9 @@ class TRPOTrainingStats(NPGTrainingStats): class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, max_kl: float = 0.01, backtrack_coeff: float = 0.8, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 14ec54a07..dabe24e75 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -610,6 +610,17 @@ def get_preprocess_net(self) -> nn.Module: def get_output_dim(self) -> int: pass + @abstractmethod + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[Any, Any]: + # TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction. + # Return type needs to be more specific + pass + def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: """Gets the given attribute from the given object or takes the alternative value if it is not present. diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index f257f8ab4..6cd4a0f63 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,4 +1,5 @@ import warnings +from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any @@ -9,6 +10,7 @@ from tianshou.utils.net.common import ( MLP, BaseActor, + Net, TActionShape, TLinearLayer, get_output_dim, @@ -19,33 +21,27 @@ class Actor(BaseActor): - """Simple actor network. + """Simple actor network that directly outputs actions for continuous action space. + Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. 
- :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param max_action: the scale for the final action logits. Default to - 1. - :param preprocess_net_output_dim: the output dimension of preprocess_net. + :param max_action: the scale for the final action. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, @@ -77,42 +73,50 @@ def forward( state: Any = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, Any]: - """Mapping: obs -> logits -> action.""" - if info is None: - info = {} - logits, hidden = self.preprocess(obs, state) - logits = self.max_action * torch.tanh(self.last(logits)) - return logits, hidden + """Mapping: s_B -> action_values_BA, hidden_state_BH | None. + + Returns a tensor representing the actions directly, i.e, of shape + `(n_actions, )`, and a hidden state (which may be None). + The hidden state is only not None if a recurrent net is used as part of the + learning algorithm (support for RNNs is currently experimental). + """ + action_BA, hidden_BH = self.preprocess(obs, state) + action_BA = self.max_action * torch.tanh(self.last(action_BA)) + return action_BA, hidden_BH + + +class CriticBase(nn.Module, ABC): + @abstractmethod + def forward( + self, + obs: np.ndarray | torch.Tensor, + act: np.ndarray | torch.Tensor | None = None, + info: dict[str, Any] | None = None, + ) -> torch.Tensor: + """Mapping: (s_B, a_B) -> Q(s, a)_B.""" -class Critic(nn.Module): +class Critic(CriticBase): """Simple critic network. It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param preprocess_net_output_dim: the output dimension of preprocess_net. - :param linear_layer: use this module as linear layer. Default to nn.Linear. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. + :param linear_layer: use this module as linear layer. :param flatten_input: whether to flatten input data for the last layer. - Default to True. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. 
seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, hidden_sizes: Sequence[int] = (), device: str | int | torch.device = "cpu", preprocess_net_output_dim: int | None = None, @@ -139,9 +143,7 @@ def forward( act: np.ndarray | torch.Tensor | None = None, info: dict[str, Any] | None = None, ) -> torch.Tensor: - """Mapping: (s, a) -> logits -> Q(s, a).""" - if info is None: - info = {} + """Mapping: (s_B, a_B) -> Q(s, a)_B.""" obs = torch.as_tensor( obs, device=self.device, @@ -154,41 +156,35 @@ def forward( dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) - logits, hidden = self.preprocess(obs) - return self.last(logits) + values_B, hidden_BH = self.preprocess(obs) + return self.last(values_B) class ActorProb(BaseActor): - """Simple actor network (output with a Gauss distribution). + """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. + + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param max_action: the scale for the final action logits. Default to - 1. + preprocess_net. + :param max_action: the scale for the final action logits. :param unbounded: whether to apply tanh activation on final logits. - Default to False. :param conditioned_sigma: True when sigma is calculated from the - input, False when sigma is an independent parameter. Default to False. + input, False when sigma is an independent parameter. :param preprocess_net_output_dim: the output dimension of - preprocess_net. + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ # TODO: force kwargs, adjust downstream code def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, @@ -402,8 +398,7 @@ class Perturbation(nn.Module): flattened hidden state. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. - Default to cpu. - :param phi: max perturbation parameter for BCQ. Default to 0.05. + :param phi: max perturbation parameter for BCQ. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. @@ -449,7 +444,6 @@ class VAE(nn.Module): :param latent_dim: the size of latent layer. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. - Default to "cpu". For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 8a54a07ac..ab9069801 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,17 +7,14 @@ from torch import nn from tianshou.data import Batch, to_torch -from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim +from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim class Actor(BaseActor): - """Simple actor network. + """Simple actor network for discrete action spaces. - Will create an actor operated in discrete action space with structure of - preprocess_net ---> action_shape. - - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net. Typically, an instance of + :class:`~tianshou.utils.net.common.Net`. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains @@ -25,20 +22,15 @@ class Actor(BaseActor): :param softmax_output: whether to apply a softmax layer over the last layer's output. :param preprocess_net_output_dim: the output dimension of - preprocess_net. + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), softmax_output: bool = True, @@ -71,43 +63,44 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: s -> Q(s, \*).""" - if info is None: - info = {} - logits, hidden = self.preprocess(obs, state) - logits = self.last(logits) + ) -> tuple[torch.Tensor, torch.Tensor | None]: + r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None. + + Returns a tensor representing the values of each action, i.e, of shape + `(n_actions, )`, and + a hidden state (which may be None). If `self.softmax_output` is True, they are the + probabilities for taking each action. Otherwise, they will be action values. + The hidden state is only + not None if a recurrent net is used as part of the learning algorithm. + """ + x, hidden_BH = self.preprocess(obs, state) + x = self.last(x) if self.softmax_output: - logits = F.softmax(logits, dim=-1) - return logits, hidden + x = F.softmax(x, dim=-1) + # If we computed softmax, output is probabilities, otherwise it's the non-normalized action values + output_BA = x + return output_BA, hidden_BH class Critic(nn.Module): - """Simple critic network. + """Simple critic network for discrete action spaces. - It will create an actor operated in discrete action space with structure of preprocess_net ---> 1(q value). - - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net. Typically, an instance of + :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). 
    :param last_size: the output dimension of Critic network. Default to 1.
     :param preprocess_net_output_dim: the output dimension of
-        preprocess_net.
+        `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
 
     For advanced usage (how to customize the network), please refer to
-    :ref:`build_the_network`.
-
-    .. seealso::
-
-        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
-        of how preprocess_net is suggested to be defined.
+    :ref:`build_the_network`.
     """
 
     def __init__(
         self,
-        preprocess_net: nn.Module,
+        preprocess_net: nn.Module | Net,
         hidden_sizes: Sequence[int] = (),
         last_size: int = 1,
         preprocess_net_output_dim: int | None = None,
@@ -120,8 +113,10 @@ def __init__(
         input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
         self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
 
+    # TODO: make a proper interface!
     def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
-        """Mapping: s_B -> V(s)_B."""
+        """Mapping: s_B -> V(s)_B."""
+        # TODO: don't use this mechanism for passing state
         logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
         return self.last(logits)
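
For reference, the unified single-argument `dist_fn` interface introduced in
`tianshou/policy/modelfree/pg.py` (types `TDistFnContinuous` and `TDistFnDiscrete`)
admits functions of the following two shapes. This is a minimal sketch assembled from
the signatures in the diff above; the function names themselves are illustrative and
do not appear in the PR:

    import torch
    from torch.distributions import Categorical, Distribution, Independent, Normal

    # Continuous case: the actor outputs a (loc, scale) tuple, which is passed to
    # dist_fn as a single argument and unpacked inside the function.
    def dist_fn_continuous(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
        loc, scale = loc_scale
        return Independent(Normal(loc, scale), 1)

    # Discrete case: the actor outputs logits of shape (B, A), passed directly.
    def dist_fn_discrete(logits: torch.Tensor) -> Categorical:
        return Categorical(logits=logits)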