Skip to content

Commit

Permalink
Improvements in High-Level API and Poe Tasks (thu-ml#1055)
Browse files Browse the repository at this point in the history
* Add an option to SamplingConfig which allows to configure number of
test episodes
* Make OptimizerFactory more flexible, adding method
`create_optimizer_for_params`
* Fix AutoAlphaFactoryDefault using hard-coded Adam optimizer
* Fix mypy issues that were platform/installation-dependent
* Limit scope of nbqa, resolving issues with files generated by old
versions of the build

Fixes thu-ml#1054
  • Loading branch information
MischaPanch authored Feb 15, 2024
2 parents 8742e36 + 26e210a commit 9b6cb69
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ celerybeat.pid
.env
.venv
venv/
ENV/
/ENV/
env.bak/
venv.bak/

Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ line-length = 100
target-version = ["py311"]

[tool.nbqa.exclude]
ruff = ".jupyter_cache"
mypy = ".jupyter_cache"
ruff = "\\.jupyter_cache|jupyter_execute"
mypy = "\\.jupyter_cache|jupyter_execute"

[tool.ruff]
select = [
Expand Down Expand Up @@ -203,10 +203,10 @@ test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --
test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes"
_black_check = "black --check ."
_ruff_check = "ruff check ."
_ruff_check_nb = "nbqa ruff ."
_ruff_check_nb = "nbqa ruff docs"
_black_format = "black ."
_ruff_format = "ruff --fix ."
_ruff_format_nb = "nbqa ruff --fix ."
_ruff_format_nb = "nbqa ruff --fix docs"
lint = ["_black_check", "_ruff_check", "_ruff_check_nb"]
_poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
_poetry_sort = "poetry sort"
Expand All @@ -221,5 +221,5 @@ doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"]
doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build"
doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"]
_mypy = "mypy tianshou"
_mypy_nb = "nbqa mypy ."
_mypy_nb = "nbqa mypy docs"
type-check = ["_mypy", "_mypy_nb"]
3 changes: 3 additions & 0 deletions tianshou/env/worker/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import ray


# mypy: disable-error-code="unused-ignore"


class _SetAttrWrapper(gym.Wrapper):
def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env.unwrapped, key, value)
Expand Down
9 changes: 6 additions & 3 deletions tianshou/env/worker/subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
from tianshou.env.worker import EnvWorker

# mypy: disable-error-code="unused-ignore"


_NP_TO_CT = {
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
Expand Down Expand Up @@ -179,10 +182,10 @@ def wait( # type: ignore
if remain_time <= 0:
break
# connection.wait hangs if the list is empty
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
new_ready_conns = connection.wait(remain_conns, timeout=remain_time) # type: ignore
ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]
remain_conns = [conn for conn in remain_conns if conn not in ready_conns] # type: ignore
return [workers[conns.index(con)] for con in ready_conns] # type: ignore

def send(self, action: np.ndarray | None, **kwargs: Any) -> None:
if action is None:
Expand Down
4 changes: 2 additions & 2 deletions tianshou/highlevel/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def create_trainer(
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_envs,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world),
Expand Down Expand Up @@ -228,7 +228,7 @@ def create_trainer(
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_envs,
episode_per_test=sampling_config.num_test_episodes_per_test_env,
batch_size=sampling_config.batch_size,
save_best_fn=policy_persistence.get_save_best_fn(world),
logger=world.logger,
Expand Down
17 changes: 16 additions & 1 deletion tianshou/highlevel/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import multiprocessing
from dataclasses import dataclass

Expand All @@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
* collects environment steps/transitions (collection step), adding them to the (replay)
buffer (see :attr:`step_per_collect`)
* performs one or more gradient updates (see :attr:`update_per_step`).
* performs one or more gradient updates (see :attr:`update_per_step`),
and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate
agent performance.
The number of training steps in each epoch is indirectly determined by
:attr:`step_per_epoch`: As many training steps will be performed as are required in
Expand Down Expand Up @@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
num_test_envs: int = 1
"""the number of test environments to use"""

num_test_episodes: int = 1
"""the total number of episodes to collect in each test step (across all test environments).
This should be a multiple of the number of test environments; if it is not, the effective
number of episodes collected will be the nearest multiple (rounded up).
"""

buffer_size: int = 4096
"""the total size of the sample/replay buffer, in which environment steps (transitions) are
stored"""
Expand Down Expand Up @@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin):
def __post_init__(self) -> None:
if self.num_train_envs == -1:
self.num_train_envs = multiprocessing.cpu_count()

@property
def num_test_episodes_per_test_env(self) -> int:
""":return: the number of episodes to collect per test environment in every test step"""
return math.ceil(self.num_test_episodes / self.num_test_envs)
26 changes: 18 additions & 8 deletions tianshou/highlevel/optim.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
from abc import ABC, abstractmethod
from typing import Any, Protocol
from collections.abc import Iterable
from typing import Any, Protocol, TypeAlias

import torch
from torch.optim import Adam, RMSprop

from tianshou.utils.string import ToStringMixin

TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]


class OptimizerWithLearningRateProtocol(Protocol):
def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
pass


class OptimizerFactory(ABC, ToStringMixin):
def create_optimizer(
self,
module: torch.nn.Module,
lr: float,
) -> torch.optim.Optimizer:
return self.create_optimizer_for_params(module.parameters(), lr)

@abstractmethod
def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
pass


Expand All @@ -30,8 +40,8 @@ def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any
self.optim_class = optim_class
self.kwargs = kwargs

def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
return self.optim_class(params, lr=lr, **self.kwargs)


class OptimizerFactoryAdam(OptimizerFactory):
Expand All @@ -45,9 +55,9 @@ def __init__(
self.eps = eps
self.betas = betas

def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
return Adam(
module.parameters(),
params,
lr=lr,
betas=self.betas,
eps=self.eps,
Expand All @@ -70,9 +80,9 @@ def __init__(
self.weight_decay = weight_decay
self.eps = eps

def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
return RMSprop(
module.parameters(),
params,
lr=lr,
alpha=self.alpha,
eps=self.eps,
Expand Down
4 changes: 2 additions & 2 deletions tianshou/highlevel/params/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def create_auto_alpha(
pass


class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name?
class AutoAlphaFactoryDefault(AutoAlphaFactory):
def __init__(self, lr: float = 3e-4):
self.lr = lr

Expand All @@ -32,5 +32,5 @@ def create_auto_alpha(
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
target_entropy = float(-np.prod(envs.get_action_shape()))
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
return target_entropy, log_alpha, alpha_optim

0 comments on commit 9b6cb69

Please sign in to comment.