diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 4128a0f14..d9a310ac9 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -3,7 +3,7 @@ import os import pickle import pprint -from test.utils import get_space_info, print_final_stats +from test.utils import print_final_stats import gymnasium as gym import numpy as np @@ -18,6 +18,7 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_pendulum_data import expert_file_name, gather_data @@ -77,14 +78,13 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: buffer = gather_data() env = gym.make(args.task) - action_space_info = get_space_info(env.action_space) - observation_space_info = get_space_info(env.observation_space) + space_info = SpaceInfo.from_env(env.action_space, env.observation_space) - args.state_shape = observation_space_info.state_shape - args.action_shape = action_space_info.action_shape - args.max_action = action_space_info.max_action - args.state_dim = observation_space_info.state_dim - args.action_dim = action_space_info.action_dim + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim if args.reward_threshold is None: # too low? diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 48d03d6c0..791225666 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -3,7 +3,6 @@ import os import pickle import pprint -from test.utils import get_space_info from typing import cast import gymnasium as gym @@ -19,6 +18,7 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_pendulum_data import expert_file_name, gather_data @@ -81,17 +81,16 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: else: buffer = gather_data() env = gym.make(args.task) - action_space = cast(gym.spaces.Box, env.action_space) + env.action_space = cast(gym.spaces.Box, env.action_space) - action_space_info = get_space_info(action_space) - observation_space_info = get_space_info(env.observation_space) + space_info = SpaceInfo.from_env(env.action_space, env.observation_space) - args.state_shape = observation_space_info.state_shape - args.action_shape = action_space_info.action_shape - args.min_action = action_space_info.min_action - args.max_action = action_space_info.max_action - args.state_dim = observation_space_info.state_dim - args.action_dim = action_space_info.action_dim + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.min_action = space_info.action_info.min_action + args.max_action = space_info.action_info.max_action + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim if args.reward_threshold is None: # too low? @@ -150,7 +149,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: # CQL seems to perform better without action scaling # TODO: investigate why action_scaling=False, - action_space=action_space, + action_space=env.action_space, cql_alpha_lr=args.cql_alpha_lr, cql_weight=args.cql_weight, tau=args.tau, diff --git a/test/utils.py b/test/utils.py index 03479f17b..7510b5c20 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,47 +1,6 @@ -from collections.abc import Sequence -from dataclasses import dataclass - -import numpy as np -from gymnasium import spaces - from tianshou.data.collector import CollectStats -@dataclass -class SpaceInfo: - action_shape: int | Sequence[int] - state_shape: int | Sequence[int] - action_dim: int - state_dim: int - min_action: float - max_action: float - - -def get_space_info( - space: spaces.Space, -) -> SpaceInfo: - if isinstance(space, spaces.Box): - return SpaceInfo( - action_shape=space.shape, - state_shape=space.shape, - action_dim=space.shape[0], - state_dim=space.shape[0], - min_action=float(np.min(space.low)), - max_action=float(np.max(space.high)), - ) - elif isinstance(space, spaces.Discrete): - return SpaceInfo( - action_shape=int(space.n), - state_shape=int(space.n), - action_dim=int(space.n), - state_dim=int(space.n), - min_action=float(space.start), - max_action=float(space.start + space.n - 1), - ) - else: - raise NotImplementedError("Unsupported space type") - - def print_final_stats(collect_stats: CollectStats) -> None: if collect_stats.returns_stat is not None and collect_stats.lens_stat is not None: print( diff --git a/tianshou/utils/space_info.py b/tianshou/utils/space_info.py new file mode 100644 index 000000000..cf2fab03a --- /dev/null +++ b/tianshou/utils/space_info.py @@ -0,0 +1,81 @@ +from collections.abc import Sequence +from dataclasses import dataclass, field + +import numpy as np +from gymnasium import spaces + + +@dataclass(kw_only=True) +class ActionSpaceInfo: + action_shape: int | Sequence[int] + action_dim: int = field(init=False) + min_action: float + max_action: float + + def __post_init__(self) -> None: + if isinstance(self.action_shape, int): + self.action_dim = self.action_shape + else: + self.action_dim = int(self.action_shape[0]) + + @classmethod + def from_space(cls, space: spaces.Space) -> "ActionSpaceInfo": + if isinstance(space, spaces.Box): + return cls( + action_shape=space.shape, + min_action=float(np.min(space.low)), + max_action=float(np.max(space.high)), + ) + elif isinstance(space, spaces.Discrete): + return cls( + action_shape=int(space.n), + min_action=float(space.start), + max_action=float(space.start + space.n - 1), + ) + else: + raise ValueError( + f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.", + ) + + +@dataclass(kw_only=True) +class ObservationSpaceInfo: + obs_shape: int | Sequence[int] + obs_dim: int = field(init=False) + + def __post_init__(self) -> None: + if isinstance(self.obs_shape, int): + self.obs_dim = self.obs_shape + else: + self.obs_dim = int(self.obs_shape[0]) + + @classmethod + def from_space(cls, space: spaces.Space) -> "ObservationSpaceInfo": + if isinstance(space, spaces.Box): + return cls( + obs_shape=space.shape, + ) + elif isinstance(space, spaces.Discrete): + return cls( + obs_shape=int(space.n), + ) + else: + raise ValueError( + f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.", + ) + + +@dataclass(kw_only=True) +class SpaceInfo: + action_info: ActionSpaceInfo + observation_info: ObservationSpaceInfo + + @classmethod + def from_env(cls, action_space: spaces.Space, observation_space: spaces.Space) -> "SpaceInfo": + action_info = ActionSpaceInfo.from_space(action_space) + observation_info = ObservationSpaceInfo.from_space(observation_space) + + return cls( + action_info=action_info, + observation_info=observation_info, + )