
Commit

Naming and typing improvements in Actor/Critic/Policy forwards (thu-ml#1032)

Closes thu-ml#917 

### Internal Improvements
- Better variable names related to model outputs (logits, dist input
etc.). thu-ml#1032
- Improved typing for actors and critics, using Tianshou classes like
`Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. thu-ml#1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the
presence of `forward` methods. thu-ml#1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see
associated breaking change). thu-ml#1032
- Use `.mode` of distribution instead of relying on knowledge of the
distribution type. thu-ml#1032
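
As an illustration of the `.mode` point above, a minimal sketch (not code from this commit; the helper name is made up) of how deterministic action selection avoids branching on the distribution type, assuming a PyTorch version (1.13+) in which `Distribution.mode` is implemented:

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal


def deterministic_act(dist: Distribution) -> torch.Tensor:
    # One code path for discrete and continuous distributions:
    # Categorical.mode is the argmax over probabilities, Normal.mode is the mean.
    return dist.mode


print(deterministic_act(Categorical(logits=torch.randn(4, 3))).shape)  # torch.Size([4])
print(deterministic_act(Independent(Normal(torch.zeros(4, 2), torch.ones(4, 2)), 1)).shape)  # torch.Size([4, 2])
```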

### Breaking Changes

- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to
take a single argument in both
continuous and discrete cases. thu-ml#1032
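
A minimal sketch of the migration (the function names below are illustrative): `dist_fn` now always receives a single argument, a `(loc, scale)` tuple in the continuous case, mirroring the example scripts updated in this commit, and a logits tensor in the discrete case.

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

# Before: the continuous dist_fn unpacked the actor output into several arguments.
# def dist(*logits: torch.Tensor) -> Distribution:
#     return Independent(Normal(*logits), 1)

# After: a single argument in both cases.
def dist_continuous(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)


def dist_discrete(logits: torch.Tensor) -> Distribution:
    return Categorical(logits=logits)
```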

---------

Co-authored-by: Arnau Jimenez <[email protected]>
Co-authored-by: Michael Panchenko <[email protected]>
3 people authored Apr 1, 2024
1 parent 5bf923c commit bf0d632
Showing 43 changed files with 340 additions and 243 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,12 @@
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063
- Better variable names related to model outputs (logits, dist input etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032

### Breaking Changes

@@ -21,6 +27,8 @@
explicitly or pass `reset_before_collect=True`. #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032

### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081
4 changes: 2 additions & 2 deletions docs/02_notebooks/L4_Policy.ipynb
@@ -69,7 +69,7 @@
"from tianshou.policy import BasePolicy\n",
"from tianshou.policy.modelfree.pg import (\n",
" PGTrainingStats,\n",
" TDistributionFunction,\n",
" TDistFnDiscrOrCont,\n",
" TPGTrainingStats,\n",
")\n",
"from tianshou.utils import RunningMeanStd\n",
@@ -339,7 +339,7 @@
" *,\n",
" actor: torch.nn.Module,\n",
" optim: torch.optim.Optimizer,\n",
" dist_fn: TDistributionFunction,\n",
" dist_fn: TDistFnDiscrOrCont,\n",
" action_space: gym.Space,\n",
" discount_factor: float = 0.99,\n",
" observation_space: gym.Space | None = None,\n",
5 changes: 5 additions & 0 deletions docs/spelling_wordlist.txt
@@ -257,3 +257,8 @@ macOS
joblib
master
Panchenko
BA
BH
BO
BD

5 changes: 3 additions & 2 deletions examples/inverse/irl_gail.py
@@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

# expert replay buffer
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_a2c.py
@@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: A2CPolicy = A2CPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_npg.py
@@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: NPGPolicy = NPGPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_ppo.py
@@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: PPOPolicy = PPOPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_reinforce.py
@@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: PGPolicy = PGPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions examples/mujoco/mujoco_trpo.py
@@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: TRPOPolicy = TRPOPolicy(
actor=actor,
8 changes: 6 additions & 2 deletions test/base/test_policy.py
@@ -2,7 +2,7 @@
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
@@ -25,7 +25,11 @@ def policy(request):
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
action_shape=action_space.shape,
)
dist_fn = lambda *logits: Independent(Normal(*logits), 1)

def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

elif action_type == "discrete":
action_space = gym.spaces.Discrete(3)
actor = Actor(
5 changes: 3 additions & 2 deletions test/continuous/test_npg.py
@@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:

# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: NPGPolicy[NPGTrainingStats] = NPGPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions test/continuous/test_ppo.py
@@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:

# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: PPOPolicy[PPOTrainingStats] = PPOPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions test/continuous/test_trpo.py
@@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:

# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: BasePolicy = TRPOPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions test/offline/test_gail.py
@@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:

# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: BasePolicy = GAILPolicy(
actor=actor,
5 changes: 3 additions & 2 deletions test/pettingzoo/pistonball_continuous.py
@@ -181,8 +181,9 @@ def get_agents(
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

agent: PPOPolicy = PPOPolicy(
actor,
23 changes: 15 additions & 8 deletions tianshou/highlevel/params/dist_fn.py
@@ -1,40 +1,47 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import torch

from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin


class DistributionFunctionFactory(ToStringMixin, ABC):
# True return type defined in subclasses
@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(
self,
envs: Environments,
) -> Callable[[Any], torch.distributions.Distribution]:
pass


class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self)
return self._dist_fn

@staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p)


class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self)
return self._dist_fn

@staticmethod
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)


class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
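
A small, hypothetical usage sketch of the factories above, calling their static `_dist_fn` helpers directly for illustration (the real entry point is `create_dist_fn`, whose `Environments` argument and action-type checks are omitted here):

```python
import torch

from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactoryCategorical,
    DistributionFunctionFactoryIndependentGaussians,
)

# Discrete: a single logits tensor in, a Categorical out.
cat_dist = DistributionFunctionFactoryCategorical._dist_fn(torch.zeros(4, 3))

# Continuous: a single (loc, scale) tuple in, an Independent(Normal) out.
gauss_dist = DistributionFunctionFactoryIndependentGaussians._dist_fn(
    (torch.zeros(4, 2), torch.ones(4, 2)),
)

print(cat_dist.batch_shape, gauss_dist.batch_shape)  # torch.Size([4]) torch.Size([4])
```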
4 changes: 2 additions & 2 deletions tianshou/highlevel/params/policy_params.py
@@ -19,7 +19,7 @@
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin

@@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
whether to use the deterministic action (the dist's mode) instead of a stochastic one during evaluation.
Does not affect training.
"""
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
"""
This can either be a function which maps the model output to a torch distribution or a
factory for the creation of such a function.
9 changes: 7 additions & 2 deletions tianshou/policy/base.py
@@ -213,10 +213,11 @@ def __init__(
super().__init__()
self.observation_space = observation_space
self.action_space = action_space
self._action_type: Literal["discrete", "continuous"]
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
self.action_type = "discrete"
self._action_type = "discrete"
elif isinstance(action_space, Box):
self.action_type = "continuous"
self._action_type = "continuous"
else:
raise ValueError(f"Unsupported action space: {action_space}.")
self.agent_id = 0
@@ -226,6 +227,10 @@ def __init__(
self.lr_scheduler = lr_scheduler
self._compile()

@property
def action_type(self) -> Literal["discrete", "continuous"]:
return self._action_type

def set_agent_id(self, agent_id: int) -> None:
"""Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id
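
For reference, a standalone sketch (not Tianshou code; the function name is made up) of the dispatch that the now read-only `action_type` property encodes, assuming gymnasium and Python 3.10+:

```python
from typing import Literal

from gymnasium import spaces


def infer_action_type(action_space: spaces.Space) -> Literal["discrete", "continuous"]:
    # Mirrors the branching in BasePolicy.__init__: discrete-like spaces vs. Box.
    if isinstance(action_space, spaces.Discrete | spaces.MultiDiscrete | spaces.MultiBinary):
        return "discrete"
    if isinstance(action_space, spaces.Box):
        return "continuous"
    raise ValueError(f"Unsupported action space: {action_space}.")


print(infer_action_type(spaces.Discrete(3)))                           # discrete
print(infer_action_type(spaces.Box(low=-1.0, high=1.0, shape=(2,))))   # continuous
```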
23 changes: 20 additions & 3 deletions tianshou/policy/imitation/base.py
@@ -16,6 +16,12 @@
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats

# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None


@dataclass(kw_only=True)
class ImitationTrainingStats(TrainingStats):
@@ -72,9 +78,20 @@ def forward(
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ModelOutputBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits
result = Batch(logits=logits, act=act, state=hidden)
# TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
if self.action_type == "discrete":
# If it's discrete, the "actor" is usually a critic that maps obs to action_values
# which then could be turned into logits or a Categorical
action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
act_B = action_values_BA.argmax(dim=1)
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
elif self.action_type == "continuous":
# If it's continuous, the actor would usually deliver something like loc, scale determining a
# Gaussian dist
dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
else:
raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
return cast(ModelOutputBatchProtocol, result)

def learn(
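
To make the B/A/D/H naming convention introduced in this file concrete, a purely illustrative shape sketch (the sizes are hypothetical):

```python
import torch

B, A, H = 32, 6, 128  # batch size, action dimension, hidden dimension

# Discrete case: the "actor" returns one value per action plus an optional hidden state.
action_values_BA = torch.rand(B, A)      # shape (B, A)
act_B = action_values_BA.argmax(dim=1)   # shape (B,)
hidden_BH = torch.rand(B, H)             # shape (B, H); may also be None

# Continuous case: the actor returns the distribution input, typically
# D = 2 tensors (loc and scale), each of shape (B, A).
dist_input_BD = (torch.rand(B, A), torch.rand(B, A).exp())
```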
3 changes: 1 addition & 2 deletions tianshou/policy/imitation/discrete_bcq.py
@@ -34,8 +34,7 @@ class DiscreteBCQTrainingStats(DQNTrainingStats):
class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
:param model: a model following the rules (s_B -> action_values_BA)
:param imitator: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
:param optim: a torch.optim for optimizing the model.
3 changes: 1 addition & 2 deletions tianshou/policy/imitation/discrete_cql.py
@@ -25,8 +25,7 @@ class DiscreteCQLTrainingStats(QRDQNTrainingStats):
class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]):
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space.
:param min_q_weight: the weight for the cql loss.
10 changes: 6 additions & 4 deletions tianshou/policy/imitation/discrete_crr.py
@@ -11,6 +11,7 @@
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats
from tianshou.utils.net.discrete import Actor, Critic


@dataclass
@@ -26,8 +27,9 @@ class DiscreteCRRTrainingStats(PGTrainingStats):
class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the action-value critic (i.e., Q function)
network. (s -> Q(s, \*))
:param optim: a torch.optim for optimizing the model.
@@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | Actor,
critic: torch.nn.Module | Critic,
optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete,
discount_factor: float = 0.99,
(The remaining changed files are not shown.)
