From 25f8093c70ab101ec6b1b9856547c36953eb882b Mon Sep 17 00:00:00 2001
From: Tianqi Xu
Date: Thu, 14 Nov 2024 14:21:41 +0300
Subject: [PATCH] [WIP] Move files to core

---
 .../core => crab/corev2}/benchmark.py         |  4 +-
 .../core => crab/corev2}/environment.py       |  2 +-
 .../core => crab/corev2}/evaluator.py         |  4 +
 {refactor_demo/core => crab/corev2}/policy.py |  0
 .../task_wrapper.py => crab/corev2/task.py    | 16 +++-
 .../core => crab/corev2}/workflow.py          |  0
 .../envs => crab/envv2}/multi_env.py          | 79 +++++++++++--------
 ...lti_env_test.ipynb => multi_env_test.ipynb |  0
 refactor_demo/core/task.py                    | 27 -------
 9 files changed, 66 insertions(+), 66 deletions(-)
 rename {refactor_demo/core => crab/corev2}/benchmark.py (92%)
 rename {refactor_demo/core => crab/corev2}/environment.py (98%)
 rename {refactor_demo/core => crab/corev2}/evaluator.py (91%)
 rename {refactor_demo/core => crab/corev2}/policy.py (100%)
 rename refactor_demo/core/task_wrapper.py => crab/corev2/task.py (88%)
 rename {refactor_demo/core => crab/corev2}/workflow.py (100%)
 rename {refactor_demo/envs => crab/envv2}/multi_env.py (54%)
 rename refactor_demo/envs/multi_env_test.ipynb => multi_env_test.ipynb (100%)
 delete mode 100644 refactor_demo/core/task.py

diff --git a/refactor_demo/core/benchmark.py b/crab/corev2/benchmark.py
similarity index 92%
rename from refactor_demo/core/benchmark.py
rename to crab/corev2/benchmark.py
index 88aa52d..856e356 100644
--- a/refactor_demo/core/benchmark.py
+++ b/crab/corev2/benchmark.py
@@ -15,7 +15,7 @@
 from typing import Iterable
 
 from .environment import Environment
-from .evaluator import Evaluator
+from .task import Task
 
 
 class Benchmark(ABC):
@@ -28,5 +28,5 @@ def get_task_by_id(self, id: str) -> str:
         pass
 
     @abstractmethod
-    def tasks(self) -> Iterable[tuple[str, Evaluator]]:
+    def get_all_tasks(self) -> Iterable[Task]:
         pass
diff --git a/refactor_demo/core/environment.py b/crab/corev2/environment.py
similarity index 98%
rename from refactor_demo/core/environment.py
rename to crab/corev2/environment.py
index b5c33b8..e70b5d0 100644
--- a/refactor_demo/core/environment.py
+++ b/crab/corev2/environment.py
@@ -18,7 +18,7 @@
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 
 
-class Environment(gym.Env, ABC):
+class CrabEnvironment(gym.Env, ABC):
     """The base environment class for agents to interact with in the CRAB framework.
 
     Crab Environment is a subclass of `gymnasium.Env` and is designed to be a base class
diff --git a/refactor_demo/core/evaluator.py b/crab/corev2/evaluator.py
similarity index 91%
rename from refactor_demo/core/evaluator.py
rename to crab/corev2/evaluator.py
index 82cb0f9..a5739f1 100644
--- a/refactor_demo/core/evaluator.py
+++ b/crab/corev2/evaluator.py
@@ -12,6 +12,12 @@
 # limitations under the License.
 # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
 from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from .environment import CrabEnvironment
+
+if TYPE_CHECKING:
+    from .task import Task
 
 
 class Evaluator(ABC):
diff --git a/refactor_demo/core/policy.py b/crab/corev2/policy.py
similarity index 100%
rename from refactor_demo/core/policy.py
rename to crab/corev2/policy.py
diff --git a/refactor_demo/core/task_wrapper.py b/crab/corev2/task.py
similarity index 88%
rename from refactor_demo/core/task_wrapper.py
rename to crab/corev2/task.py
index 2ae112a..886ec22 100644
--- a/refactor_demo/core/task_wrapper.py
+++ b/crab/corev2/task.py
@@ -16,7 +16,21 @@
 import gymnasium as gym
 from gymnasium import Wrapper
 from gymnasium.core import ActType, ObsType, WrapperObsType
-from gymnasium.spaces import Dict, Space, Text, Tuple
+from collections.abc import Callable
+
+from gymnasium.spaces import Dict, Text, Tuple
+from pydantic import BaseModel, ConfigDict
+
+from .evaluator import Evaluator
+
+
+class Task(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    id: str
+    description: str
+    evaluator: Evaluator
+    setup: Callable[[], None]
+    # extra_action: list[Action] = []
 
 
 class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
@@ -61,7 +75,7 @@ def step(
         self, action: ActType
     ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
-        observation, reward, terminal, truncated, info = self.step(action)
-        reward = self.task.evaluate(self.env)
+        observation, reward, terminal, truncated, info = self.env.step(action)
+        reward = self.task.evaluator.step(self.env)
         return self.observation(observation), reward, terminal, truncated, info
 
     def observation(self, observation: ObsType):
diff --git a/refactor_demo/core/workflow.py b/crab/corev2/workflow.py
similarity index 100%
rename from refactor_demo/core/workflow.py
rename to crab/corev2/workflow.py
diff --git a/refactor_demo/envs/multi_env.py b/crab/envv2/multi_env.py
similarity index 54%
rename from refactor_demo/envs/multi_env.py
rename to crab/envv2/multi_env.py
index 0d13895..1f430dd 100644
--- a/refactor_demo/envs/multi_env.py
+++ b/crab/envv2/multi_env.py
@@ -11,13 +11,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
+from typing import Any
+
 import gymnasium as gym
-import numpy as np
 from gymnasium import spaces
 
+from crab.corev2.environment import CrabEnvironment
+
 
-class MultiEnv(gym.Env):
-    def __init__(self, envs):
+class MultiEnv(gym.Env[dict[str, Any], tuple[int, Any]]):
+    def __init__(self, envs: dict[str, CrabEnvironment]):
         """
         Initialize the MultiEnv environment.
 
@@ -28,15 +31,26 @@ def __init__(self, envs):
         # Store the environments
         self.envs = envs
 
-        # Create action space using OneOf with the action spaces of each environment
-        self.action_space = spaces.OneOf([env.action_space for env in envs])
+        # Create action space using OneOf over each environment's action space
+        self.action_space = spaces.OneOf([env.action_space for env in envs.values()])
-        # Create observation space as a Dict space containing each environment's observation space
+        # Create observation space as a Dict space containing each environment's
+        # observation space
         self.observation_space = spaces.Dict(
-            {f"env_{i}": env.observation_space for i, env in enumerate(envs)}
+            {name: env.observation_space for name, env in envs.items()}
         )
+        self._idx_to_env_name = {
+            idx: env_name for idx, env_name in enumerate(envs.keys())
+        }
+        self._env_name_to_index = {
+            env_name: idx for idx, env_name in enumerate(envs.keys())
+        }
+
+        self._saved_observations = {key: None for key in envs.keys()}
+        self._saved_infos = {key: None for key in envs.keys()}
+        self._saved_dones = {key: False for key in envs.keys()}
+
     def reset(self):
         """
         Reset all environments and return initial observations.
@@ -45,51 +59,46 @@ def reset(self):
-            dict: A dictionary with initial observations from each environment.
+            tuple: Dicts of initial observations and infos keyed by environment name.
         """
         observations = {}
-        for i, env in enumerate(self.envs):
-            observations[f"env_{i}"], _ = env.reset()
-        return observations
+        infos = {}
+        for name, env in self.envs.items():
+            observations[name], infos[name] = env.reset()
+        self._saved_observations = observations
+        self._saved_infos = infos
+        return observations, infos
 
-    def step(self, action):
+    def step(self, action: tuple[int, Any]):
         """
         Take a step in the selected environment based on the action.
 
         Args:
-            action (int): The index of the environment to take a step in.
+            action: A (env_index, action) tuple selecting an environment and action.
 
         Returns:
-            tuple: A tuple containing the observations, rewards, done flags, and info.
+            tuple: (observations, reward, all_done, truncated, infos).
         """
-        assert 0 <= action < len(self.envs), "Invalid action for environment selection."
+        env_idx, actual_action = action
+        assert (
+            0 <= env_idx < len(self.envs)
+        ), "Invalid action for environment selection."
+        env_name = self._idx_to_env_name[env_idx]
+        assert self.action_space[env_idx].contains(
+            actual_action
+        ), f"{actual_action!r} ({type(actual_action)}) invalid in {env_name}"
 
-        # Initialize dictionaries to store results
-        observations = {}
-        rewards = {}
-        dones = {}
-        infos = {}
+        env = self.envs[env_name]
 
         # Perform a step in the selected environment
-        obs, reward, done, truncated, info = self.envs[action].step(action)
+        obs, reward, done, truncated, info = env.step(actual_action)
 
         # Populate results for the selected environment
-        observations[f"env_{action}"] = obs
-        rewards[f"env_{action}"] = reward
-        dones[f"env_{action}"] = done
-        infos[f"env_{action}"] = info
-
-        # For other environments, simply pass their previous observations
-        for i, env in enumerate(self.envs):
-            if i != action:
-                observations[f"env_{i}"] = (
-                    None  # No new observation for non-acting environments
-                )
-                rewards[f"env_{i}"] = 0
-                dones[f"env_{i}"] = False
-                infos[f"env_{i}"] = {}
+        self._saved_observations[env_name] = obs
+        self._saved_dones[env_name] = done
+        self._saved_infos[env_name] = info
 
         # Set done if all environments are done
-        all_done = all(dones.values())
+        all_done = all(self._saved_dones.values())
 
-        return observations, rewards, all_done, infos
+        return self._saved_observations, reward, all_done, truncated, self._saved_infos
 
     def render(self, mode="human"):
         """
diff --git a/refactor_demo/envs/multi_env_test.ipynb b/multi_env_test.ipynb
similarity index 100%
rename from refactor_demo/envs/multi_env_test.ipynb
rename to multi_env_test.ipynb
diff --git a/refactor_demo/core/task.py b/refactor_demo/core/task.py
deleted file mode 100644
index a8f2f62..0000000
--- a/refactor_demo/core/task.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
-# Licensed under the Apache License, Version 2.0 (the “License”);
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an “AS IS” BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
-from abc import ABC, abstractmethod
-from pydantic import BaseModel, ConfigDict
-from typing import Any
-
-from .environment import Environment
-
-
-class Task(BaseModel):
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-    id: str
-    description: str
-    evaluator: Evaluator
-    setup: setup = []
-    extra_action: list[Action] = []
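
Two quick usage sketches for reviewers follow; neither is part of the patch. First, constructing the new Task model from crab/corev2/task.py. DummyEvaluator and its step(env) hook are assumptions inferred from TaskWrapper.step calling task.evaluator.step(self.env); the Evaluator ABC's actual abstract methods are not visible in this diff, so a real subclass may need to implement more.

from crab.corev2.evaluator import Evaluator
from crab.corev2.task import Task


class DummyEvaluator(Evaluator):
    # TaskWrapper.step computes the reward as task.evaluator.step(env),
    # so a concrete evaluator needs at least this hook (assumed here).
    def step(self, env) -> float:
        return 0.0


task = Task(
    id="demo-task",
    description="Placeholder task that always scores 0.",
    evaluator=DummyEvaluator(),
    setup=lambda: None,  # matches the setup: Callable[[], None] field
)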
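Second, a minimal sketch of driving the reworked MultiEnv. It assumes gymnasium >= 1.0 (needed for spaces.OneOf) and uses stock gymnasium environments as stand-ins for CrabEnvironment subclasses, since MultiEnv only relies on the plain gymnasium.Env interface:

import gymnasium as gym

from crab.envv2.multi_env import MultiEnv

# Stock envs stand in for CrabEnvironment subclasses in this sketch.
envs = {
    "cartpole": gym.make("CartPole-v1"),
    "mountaincar": gym.make("MountainCar-v0"),
}
multi = MultiEnv(envs)

# reset() now returns (observations, infos), both keyed by environment name.
observations, infos = multi.reset()

# An action is a (env_index, action) pair; index 0 is "cartpole" because
# dict insertion order determines the index mapping.
action = envs["cartpole"].action_space.sample()
obs, reward, all_done, truncated, infos = multi.step((0, action))

Non-acting environments keep their last saved observation and info, so the returned dicts always contain one entry per environment name.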
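The (env_index, action) convention is the same shape that spaces.OneOf produces when sampled, which makes random rollouts straightforward; a standalone illustration, again assuming gymnasium >= 1.0:

from gymnasium import spaces

# OneOf samples a (subspace_index, value) pair, the same shape as the
# action MultiEnv.step expects.
space = spaces.OneOf([spaces.Discrete(2), spaces.Box(-1.0, 1.0, shape=(3,))])
idx, value = space.sample()
assert space.contains((idx, value))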