From bf391853dccdca1d0ae02e68203014dcb26966f6 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Wed, 14 Feb 2024 19:06:01 +0100
Subject: [PATCH 1/6] Allow configuring the number of test episodes in the
 high-level API

---
 tianshou/highlevel/agent.py  |  4 ++--
 tianshou/highlevel/config.py | 17 ++++++++++++++++-
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index b72ab5e96..1a1a0bf76 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -184,7 +184,7 @@ def create_trainer(
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
@@ -228,7 +228,7 @@ def create_trainer(
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
             logger=world.logger,
diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py
index 80e04769d..498247214 100644
--- a/tianshou/highlevel/config.py
+++ b/tianshou/highlevel/config.py
@@ -1,3 +1,4 @@
+import math
 import multiprocessing
 from dataclasses import dataclass
 
@@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
 
     * collects environment steps/transitions (collection step), adding them to the (replay)
       buffer (see :attr:`step_per_collect`)
-    * performs one or more gradient updates (see :attr:`update_per_step`).
+    * performs one or more gradient updates (see :attr:`update_per_step`),
+
+    and the test step collects :attr:`num_test_episodes` test episodes in order to evaluate
+    agent performance.
 
     The number of training steps in each epoch is indirectly determined by
     :attr:`step_per_epoch`: As many training steps will be performed as are required in
@@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
     num_test_envs: int = 1
     """the number of test environments to use"""
 
+    num_test_episodes: int = 1
+    """the total number of episodes to collect in each test step (across all test environments).
+    This should be a multiple of the number of test environments; if it is not, the effective
+    number of episodes collected will be the nearest multiple (rounded up).
+    """
+
     buffer_size: int = 4096
     """the total size of the sample/replay buffer, in which environment steps (transitions) are
     stored"""
@@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin):
     def __post_init__(self) -> None:
         if self.num_train_envs == -1:
             self.num_train_envs = multiprocessing.cpu_count()
+
+    @property
+    def num_test_episodes_per_test_env(self) -> int:
+        """:return: the number of episodes to collect per test environment in every test step"""
+        return math.ceil(self.num_test_episodes / self.num_test_envs)

From 76cbd7efc2c54655a0797f7c4fe30f9e4ae2cbe2 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Wed, 14 Feb 2024 20:42:06 +0100
Subject: [PATCH 2/6] Make OptimizerFactory more flexible by adding a second
 method which allows the creation of an optimizer given arbitrary parameters
 (rather than a module)

---
 tianshou/highlevel/optim.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index 0e754b111..db5fd90ff 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -1,11 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Any, Protocol
+from collections.abc import Iterable
+from typing import Any, Protocol, TypeAlias
 
 import torch
 from torch.optim import Adam, RMSprop
 
 from tianshou.utils.string import ToStringMixin
 
+TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
+
 
 class OptimizerWithLearningRateProtocol(Protocol):
     def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
@@ -13,8 +16,15 @@ def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Opt
 
 
 class OptimizerFactory(ABC, ToStringMixin):
+    def create_optimizer(
+        self,
+        module: torch.nn.Module,
+        lr: float,
+    ) -> torch.optim.Optimizer:
+        return self.create_optimizer_for_params(module.parameters(), lr)
+
     @abstractmethod
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         pass
 
 
@@ -30,8 +40,8 @@ def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any
         self.optim_class = optim_class
         self.kwargs = kwargs
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
-        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
+        return self.optim_class(params, lr=lr, **self.kwargs)
 
 
 class OptimizerFactoryAdam(OptimizerFactory):
@@ -45,9 +55,9 @@ def __init__(
         self.eps = eps
         self.betas = betas
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return Adam(
-            module.parameters(),
+            params,
             lr=lr,
             betas=self.betas,
             eps=self.eps,
@@ -70,9 +80,9 @@ def __init__(
         self.weight_decay = weight_decay
         self.eps = eps
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return RMSprop(
-            module.parameters(),
+            params,
            lr=lr,
             alpha=self.alpha,
             eps=self.eps,

From eeb2081ca68f7894ba5a8ab31de32e360ab6ded5 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Wed, 14 Feb 2024 20:43:38 +0100
Subject: [PATCH 3/6] Fix AutoAlphaFactoryDefault using hard-coded Adam
 optimizer instead of passed factory

---
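Note: this fix relies on the create_optimizer_for_params method introduced in
PATCH 2/6. SAC's auto-alpha coefficient is a bare tensor rather than a module,
so it previously could not be routed through the user-configured optimizer
factory. The sketch below is purely illustrative and not part of the patch;
it assumes OptimizerFactoryAdam can be constructed with default arguments.

    import torch

    from tianshou.highlevel.optim import OptimizerFactoryAdam

    factory = OptimizerFactoryAdam()  # constructor defaults assumed

    # Module case: create_optimizer keeps its public behavior; it now simply
    # delegates to create_optimizer_for_params(module.parameters(), lr).
    net = torch.nn.Linear(4, 2)
    net_optim = factory.create_optimizer(net, lr=3e-4)

    # Parameter-list case: what AutoAlphaFactoryDefault needs, since
    # log_alpha is a bare tensor with requires_grad=True, not an nn.Module.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = factory.create_optimizer_for_params([log_alpha], 3e-4)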
 tianshou/highlevel/params/alpha.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py
index 878ae4b76..4e8490de8 100644
--- a/tianshou/highlevel/params/alpha.py
+++ b/tianshou/highlevel/params/alpha.py
@@ -20,7 +20,7 @@ def create_auto_alpha(
         pass
 
 
-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr
 
@@ -32,5 +32,5 @@ def create_auto_alpha(
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim

From f2e0fd165d9797cf3523340d5040d8260b97e8e5 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Thu, 15 Feb 2024 11:26:39 +0100
Subject: [PATCH 4/6] Fix gitignore applying to tianshou/env on platforms with
 case-insensitive file systems

---
 .gitignore | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 98acab57b..e63e24b00 100644
--- a/.gitignore
+++ b/.gitignore
@@ -111,7 +111,7 @@ celerybeat.pid
 .env
 .venv
 venv/
-ENV/
+/ENV/
 env.bak/
 venv.bak/
 

From 08728ad35e9841354a1a6db3c84e747095f1cf6d Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Thu, 15 Feb 2024 11:26:54 +0100
Subject: [PATCH 5/6] Resolve platform-specific/installation-specific mypy
 issues by adding ignores and ignoring unused ignores locally

---
 tianshou/env/worker/ray.py     | 3 +++
 tianshou/env/worker/subproc.py | 9 ++++++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py
index f465eae9c..76b842220 100644
--- a/tianshou/env/worker/ray.py
+++ b/tianshou/env/worker/ray.py
@@ -12,6 +12,9 @@
     import ray
 
 
+# mypy: disable-error-code="unused-ignore"
+
+
 class _SetAttrWrapper(gym.Wrapper):
     def set_env_attr(self, key: str, value: Any) -> None:
         setattr(self.env.unwrapped, key, value)
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 2ca60c2d8..af5ec4e9f 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -12,6 +12,9 @@
 from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
 from tianshou.env.worker import EnvWorker
 
+# mypy: disable-error-code="unused-ignore"
+
+
 _NP_TO_CT = {
     np.bool_: ctypes.c_bool,
     np.uint8: ctypes.c_uint8,
@@ -179,10 +182,10 @@ def wait(  # type: ignore
             if remain_time <= 0:
                 break
             # connection.wait hangs if the list is empty
-            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
+            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)  # type: ignore
             ready_conns.extend(new_ready_conns)  # type: ignore
-            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
-        return [workers[conns.index(con)] for con in ready_conns]
+            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]  # type: ignore
+        return [workers[conns.index(con)] for con in ready_conns]  # type: ignore
 
     def send(self, action: np.ndarray | None, **kwargs: Any) -> None:
         if action is None:

From 26e210a6ae9093c9b30ede2b6e87310faa229f80 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Thu, 15 Feb 2024 11:38:23 +0100
Subject: [PATCH 6/6] Apply nbqa only to the docs/ folder and exclude the
 (old) jupyter_execute folder

---
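Note: the values under [tool.nbqa.exclude] are interpreted by nbqa as regular
expressions, which is why the dot is escaped: the TOML string
"\\.jupyter_cache|jupyter_execute" denotes the regex
\.jupyter_cache|jupyter_execute. A quick sanity check of that pattern, purely
illustrative and not part of the patch:

    import re

    # Same pattern as in pyproject.toml after TOML unescaping.
    pattern = re.compile(r"\.jupyter_cache|jupyter_execute")

    assert pattern.search("docs/.jupyter_cache/notebook.ipynb")
    assert pattern.search("docs/jupyter_execute/old_notebook.ipynb")
    assert not pattern.search("docs/02_notebooks/notebook.ipynb")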
 pyproject.toml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index aa66eb54f..c47b14b13 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -143,8 +143,8 @@ line-length = 100
 target-version = ["py311"]
 
 [tool.nbqa.exclude]
-ruff = "\\.jupyter_cache"
-mypy = "\\.jupyter_cache"
+ruff = "\\.jupyter_cache|jupyter_execute"
+mypy = "\\.jupyter_cache|jupyter_execute"
 
 [tool.ruff]
 select = [
@@ -203,10 +203,10 @@ test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --
 test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes"
 _black_check = "black --check ."
 _ruff_check = "ruff check ."
-_ruff_check_nb = "nbqa ruff ."
+_ruff_check_nb = "nbqa ruff docs"
 _black_format = "black ."
 _ruff_format = "ruff --fix ."
-_ruff_format_nb = "nbqa ruff --fix ."
+_ruff_format_nb = "nbqa ruff --fix docs"
 lint = ["_black_check", "_ruff_check", "_ruff_check_nb"]
 _poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
 _poetry_sort = "poetry sort"
@@ -221,5 +221,5 @@ doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"]
 doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build"
 doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"]
 _mypy = "mypy tianshou"
-_mypy_nb = "nbqa mypy ."
+_mypy_nb = "nbqa mypy docs"
 type-check = ["_mypy", "_mypy_nb"]
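
Closing note: a brief usage sketch for the num_test_episodes option from
PATCH 1/6. This is illustrative only and not part of the patches; the values
are chosen for the example, and all remaining SamplingConfig fields are
assumed to keep their defaults.

    from tianshou.highlevel.config import SamplingConfig

    config = SamplingConfig(num_test_envs=4, num_test_episodes=10)

    # 10 episodes cannot be split evenly across 4 test environments, so each
    # environment is assigned ceil(10 / 4) = 3 episodes; the effective total
    # is 12, the nearest multiple of num_test_envs (rounded up).
    assert config.num_test_episodes_per_test_env == 3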