Skip to content

Commit

Permalink
Refactor get_spaces_info
Browse files Browse the repository at this point in the history
  * Move function to dataclass method
  * Move code to library utils
  * Unify gathering of space info s.t. caller can get them in one call
  * Define `obs_dim` and `action_dim` after init (in __post_init__())
  * Rename state -> obs
  * Update the two test files that currently use this utility
  • Loading branch information
dantp-ai committed Jan 26, 2024
1 parent 7828303 commit 7f81ac7
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 60 deletions.
16 changes: 8 additions & 8 deletions test/offline/test_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pickle
import pprint
from test.utils import get_space_info, print_final_stats
from test.utils import print_final_stats

import gymnasium as gym
import numpy as np
Expand All @@ -18,6 +18,7 @@
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
from tianshou.utils.space_info import SpaceInfo

if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
Expand Down Expand Up @@ -77,14 +78,13 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
buffer = gather_data()
env = gym.make(args.task)

action_space_info = get_space_info(env.action_space)
observation_space_info = get_space_info(env.observation_space)
space_info = SpaceInfo.from_env(env.action_space, env.observation_space)

args.state_shape = observation_space_info.state_shape
args.action_shape = action_space_info.action_shape
args.max_action = action_space_info.max_action
args.state_dim = observation_space_info.state_dim
args.action_dim = action_space_info.action_dim
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
args.max_action = space_info.action_info.max_action
args.state_dim = space_info.observation_info.obs_dim
args.action_dim = space_info.action_info.action_dim

if args.reward_threshold is None:
# too low?
Expand Down
21 changes: 10 additions & 11 deletions test/offline/test_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import pickle
import pprint
from test.utils import get_space_info
from typing import cast

import gymnasium as gym
Expand All @@ -19,6 +18,7 @@
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.space_info import SpaceInfo

if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
Expand Down Expand Up @@ -81,17 +81,16 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
else:
buffer = gather_data()
env = gym.make(args.task)
action_space = cast(gym.spaces.Box, env.action_space)
env.action_space = cast(gym.spaces.Box, env.action_space)

action_space_info = get_space_info(action_space)
observation_space_info = get_space_info(env.observation_space)
space_info = SpaceInfo.from_env(env.action_space, env.observation_space)

args.state_shape = observation_space_info.state_shape
args.action_shape = action_space_info.action_shape
args.min_action = action_space_info.min_action
args.max_action = action_space_info.max_action
args.state_dim = observation_space_info.state_dim
args.action_dim = action_space_info.action_dim
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
args.min_action = space_info.action_info.min_action
args.max_action = space_info.action_info.max_action
args.state_dim = space_info.observation_info.obs_dim
args.action_dim = space_info.action_info.action_dim

if args.reward_threshold is None:
# too low?
Expand Down Expand Up @@ -150,7 +149,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
# CQL seems to perform better without action scaling
# TODO: investigate why
action_scaling=False,
action_space=action_space,
action_space=env.action_space,
cql_alpha_lr=args.cql_alpha_lr,
cql_weight=args.cql_weight,
tau=args.tau,
Expand Down
41 changes: 0 additions & 41 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,6 @@
from collections.abc import Sequence
from dataclasses import dataclass

import numpy as np
from gymnasium import spaces

from tianshou.data.collector import CollectStats


@dataclass
class SpaceInfo:
action_shape: int | Sequence[int]
state_shape: int | Sequence[int]
action_dim: int
state_dim: int
min_action: float
max_action: float


def get_space_info(
space: spaces.Space,
) -> SpaceInfo:
if isinstance(space, spaces.Box):
return SpaceInfo(
action_shape=space.shape,
state_shape=space.shape,
action_dim=space.shape[0],
state_dim=space.shape[0],
min_action=float(np.min(space.low)),
max_action=float(np.max(space.high)),
)
elif isinstance(space, spaces.Discrete):
return SpaceInfo(
action_shape=int(space.n),
state_shape=int(space.n),
action_dim=int(space.n),
state_dim=int(space.n),
min_action=float(space.start),
max_action=float(space.start + space.n - 1),
)
else:
raise NotImplementedError("Unsupported space type")


def print_final_stats(collect_stats: CollectStats) -> None:
if collect_stats.returns_stat is not None and collect_stats.lens_stat is not None:
print(
Expand Down
81 changes: 81 additions & 0 deletions tianshou/utils/space_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from collections.abc import Sequence
from dataclasses import dataclass, field

import numpy as np
from gymnasium import spaces


@dataclass(kw_only=True)
class ActionSpaceInfo:
action_shape: int | Sequence[int]
action_dim: int = field(init=False)
min_action: float
max_action: float

def __post_init__(self) -> None:
if isinstance(self.action_shape, int):
self.action_dim = self.action_shape
else:
self.action_dim = int(self.action_shape[0])

@classmethod
def from_space(cls, space: spaces.Space) -> "ActionSpaceInfo":
if isinstance(space, spaces.Box):
return cls(
action_shape=space.shape,
min_action=float(np.min(space.low)),
max_action=float(np.max(space.high)),
)
elif isinstance(space, spaces.Discrete):
return cls(
action_shape=int(space.n),
min_action=float(space.start),
max_action=float(space.start + space.n - 1),
)
else:
raise ValueError(
f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
)


@dataclass(kw_only=True)
class ObservationSpaceInfo:
obs_shape: int | Sequence[int]
obs_dim: int = field(init=False)

def __post_init__(self) -> None:
if isinstance(self.obs_shape, int):
self.obs_dim = self.obs_shape
else:
self.obs_dim = int(self.obs_shape[0])

@classmethod
def from_space(cls, space: spaces.Space) -> "ObservationSpaceInfo":
if isinstance(space, spaces.Box):
return cls(
obs_shape=space.shape,
)
elif isinstance(space, spaces.Discrete):
return cls(
obs_shape=int(space.n),
)
else:
raise ValueError(
f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
)


@dataclass(kw_only=True)
class SpaceInfo:
action_info: ActionSpaceInfo
observation_info: ObservationSpaceInfo

@classmethod
def from_env(cls, action_space: spaces.Space, observation_space: spaces.Space) -> "SpaceInfo":
action_info = ActionSpaceInfo.from_space(action_space)
observation_info = ObservationSpaceInfo.from_space(observation_space)

return cls(
action_info=action_info,
observation_info=observation_info,
)

0 comments on commit 7f81ac7

Please sign in to comment.