From 9856423abd63579fccbf2450acff50a4d3e1a07e Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 4 Nov 2024 13:06:04 +0100 Subject: [PATCH] Add support for gymnasium v1.0 (#261) * Add support for gymnasium v1.0 * Fix for gym v1.0 * Update CI matrix * Update SB3 min version * Fix warning --- .github/workflows/ci.yml | 20 ++++++++++++-------- docs/conda_env.yml | 2 +- docs/misc/changelog.rst | 3 ++- pyproject.toml | 2 ++ sb3_contrib/common/maskable/utils.py | 8 ++++++-- sb3_contrib/common/wrappers/time_feature.py | 1 + sb3_contrib/version.txt | 2 +- setup.py | 2 +- tests/test_lstm.py | 4 ++-- 9 files changed, 28 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8f924ec0..c3a2ec21 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,12 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - + include: + # Default version + - gymnasium-version: "1.0.0" + # Add a new config to test gym<1.0 + - python-version: "3.10" + gymnasium-version: "0.29.1" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -36,19 +41,18 @@ jobs: # See https://github.com/astral-sh/uv/issues/1497 uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu - # Install Atari Roms - uv pip install --system autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz - # Install master version # and dependencies for docs and tests - uv pip install --system "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3" + uv pip install --system "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3" uv pip install --system . # Use headless version uv pip install --system opencv-python-headless + - name: Install specific version of gym + run: | + uv pip install --system gymnasium==${{ matrix.gymnasium-version }} + # Only run for python 3.10, downgrade gym to 0.29.1 + - name: Lint with ruff run: | make lint diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 0ed3efb1..a080a9db 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -8,7 +8,7 @@ dependencies: - python=3.11 - pytorch=2.5.0=py3.11_cpu_0 - pip: - - gymnasium>=0.28.1,<0.30 + - gymnasium>=0.29.1,<1.1.0 - stable-baselines3>=2.0.0,<3.0 - cloudpickle - opencv-python-headless diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d86094c7..f611004f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a10 (WIP) +Release 2.4.0a11 (WIP) -------------------------- **New algorithm: added CrossQ** @@ -16,6 +16,7 @@ New Features: ^^^^^^^^^^^^^ - Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen) - Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen) +- Added support for Gymnasium v1.0 Bug Fixes: ^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 72e01933..bbac8c41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ env = ["PYTHONHASHSEED=0"] filterwarnings = [ # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", + # tqdm warning about rich being experimental + "ignore:rich is experimental", ] markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/sb3_contrib/common/maskable/utils.py b/sb3_contrib/common/maskable/utils.py index b4f09900..c20c06d5 100644 --- a/sb3_contrib/common/maskable/utils.py +++ b/sb3_contrib/common/maskable/utils.py @@ -16,7 +16,7 @@ def get_action_masks(env: GymEnv) -> np.ndarray: if isinstance(env, VecEnv): return np.stack(env.env_method(EXPECTED_METHOD_NAME)) else: - return getattr(env, EXPECTED_METHOD_NAME)() + return env.get_wrapper_attr(EXPECTED_METHOD_NAME)() def is_masking_supported(env: GymEnv) -> bool: @@ -35,4 +35,8 @@ def is_masking_supported(env: GymEnv) -> bool: except AttributeError: return False else: - return hasattr(env, EXPECTED_METHOD_NAME) + try: + env.get_wrapper_attr(EXPECTED_METHOD_NAME) + return True + except AttributeError: + return False diff --git a/sb3_contrib/common/wrappers/time_feature.py b/sb3_contrib/common/wrappers/time_feature.py index 91322641..7eeb9b11 100644 --- a/sb3_contrib/common/wrappers/time_feature.py +++ b/sb3_contrib/common/wrappers/time_feature.py @@ -44,6 +44,7 @@ def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False) low, high = obs_space.low, obs_space.high low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) # type: ignore[arg-type] self.dtype = obs_space.dtype + low, high = low.astype(self.dtype), high.astype(self.dtype) if isinstance(env.observation_space, spaces.Dict): env.observation_space.spaces["observation"] = spaces.Box( diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 852a32b3..d5cafdb5 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.4.0a10 +2.4.0a11 diff --git a/setup.py b/setup.py index c8e00ec7..d3a567d4 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.4.0a6,<3.0", + "stable_baselines3>=2.4.0a11,<3.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_lstm.py b/tests/test_lstm.py index f0243dfc..d1cfa4e4 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -5,7 +5,7 @@ import pytest from gymnasium import spaces from gymnasium.envs.classic_control import CartPoleEnv -from gymnasium.wrappers.time_limit import TimeLimit +from gymnasium.wrappers import TimeLimit from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env @@ -43,7 +43,7 @@ def __init__(self): self.x_threshold * 2, self.theta_threshold_radians * 2, ] - ) + ).astype(np.float32) self.observation_space = spaces.Box(-high, high, dtype=np.float32) @staticmethod