
Commit

[WIP] Move files to core
dandansamax committed Nov 14, 2024
1 parent 4ae656b commit 25f8093
Showing 9 changed files with 66 additions and 66 deletions.
4 changes: 2 additions & 2 deletions refactor_demo/core/benchmark.py → crab/corev2/benchmark.py
@@ -15,7 +15,7 @@
from typing import Iterable

from .environment import Environment
-from .evaluator import Evaluator
+from .task import Task


class Benchmark(ABC):
@@ -28,5 +28,5 @@ def get_task_by_id(self, id: str) -> str:
        pass

    @abstractmethod
-    def tasks(self) -> Iterable[tuple[str, Evaluator]]:
+    def get_tall_tasks(self) -> Iterable[tuple[Task]]:
        pass
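
Note (illustrative, not part of the commit): a minimal sketch of what a concrete Benchmark might look like under the interface above. InMemoryBenchmark and its dict-backed storage are invented for illustration; the sketch follows the committed signatures as-is, including get_task_by_id returning str and get_tall_tasks yielding one-element tuples.

    # Hypothetical subclass of the Benchmark ABC above; not part of the commit.
    from typing import Iterable

    from crab.corev2.benchmark import Benchmark
    from crab.corev2.task import Task


    class InMemoryBenchmark(Benchmark):
        """Toy benchmark backed by a plain dict of Task objects."""

        def __init__(self, tasks: dict[str, Task]):
            self._tasks = tasks

        def get_task_by_id(self, id: str) -> str:
            # The ABC declares a str return; echo the id back after the lookup.
            return self._tasks[id].id

        def get_tall_tasks(self) -> Iterable[tuple[Task]]:
            # One-element tuples, matching the declared Iterable[tuple[Task]].
            return ((task,) for task in self._tasks.values())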
2 changes: 1 addition & 1 deletion refactor_demo/core/environment.py → crab/corev2/environment.py
@@ -18,7 +18,7 @@
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam


-class Environment(gym.Env, ABC):
+class CrabEnvironment(gym.Env, ABC):
    """The base environment class for agents to interact with in the CRAB framework.

    Crab Environment is a subclass of `gymnasium.Env` and is designed to be a base class
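
Note (illustrative, not part of the commit): only the renamed class declaration is visible in this hunk, so the following sketch assumes CrabEnvironment adds no required members beyond the usual gymnasium.Env reset/step contract. EchoEnv and its spaces are invented for illustration.

    # Hypothetical CrabEnvironment subclass; assumes the base class adds no
    # abstract members beyond gymnasium.Env's reset/step contract.
    from gymnasium import spaces

    from crab.corev2.environment import CrabEnvironment


    class EchoEnv(CrabEnvironment):
        """Toy environment whose observation is the last action it received."""

        def __init__(self):
            self.observation_space = spaces.Discrete(4)
            self.action_space = spaces.Discrete(4)
            self._last = 0

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            self._last = 0
            return self._last, {}

        def step(self, action):
            self._last = int(action)
            # Toy dynamics: zero reward, never terminated, never truncated.
            return self._last, 0.0, False, False, {}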
4 changes: 4 additions & 0 deletions refactor_demo/core/evaluator.py → crab/corev2/evaluator.py
@@ -12,6 +12,10 @@
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
+from typing import Any
+
+from .environment import Environment
+from .task import Task


class Evaluator(ABC):
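
Note (illustrative, not part of the commit): the visible hunk only adds imports, so the Evaluator contract has to be inferred from the rest of the commit. TaskWrapper.step below calls task.evaluator.step(self.env) and uses the result as the reward; that is all this sketch assumes. AlwaysDoneEvaluator is invented.

    # Hypothetical concrete Evaluator; the only contract assumed is the
    # step(env) -> float call made by TaskWrapper later in this commit.
    from crab.corev2.environment import CrabEnvironment
    from crab.corev2.evaluator import Evaluator


    class AlwaysDoneEvaluator(Evaluator):
        """Returns a fixed reward regardless of environment state."""

        def __init__(self, reward: float = 1.0):
            self._reward = reward

        def step(self, env: CrabEnvironment) -> float:
            # A real evaluator would inspect env state (screens, files,
            # app status, ...) and score progress toward the task goal.
            return self._reward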
File renamed without changes.
16 changes: 14 additions & 2 deletions refactor_demo/core/task_wrapper.py → crab/corev2/task.py
@@ -16,7 +16,19 @@
import gymnasium as gym
from gymnasium import Wrapper
from gymnasium.core import ActType, ObsType, WrapperObsType
-from gymnasium.spaces import Dict, Space, Text, Tuple
+from gymnasium.spaces import Dict, Text, Tuple
+from pydantic import BaseModel, ConfigDict
+
+from .evaluator import Evaluator
+
+
+class Task(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    id: str
+    description: str
+    evaluator: Evaluator
+    setup: Callable[[], None]
+    # extra_action: list[Action] = []
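
Note (illustrative, not part of the commit): Evaluator is not a pydantic type, so arbitrary_types_allowed is what lets the model carry one. A sketch of constructing a Task, reusing the invented AlwaysDoneEvaluator from the evaluator note above:

    # Hypothetical Task construction; AlwaysDoneEvaluator comes from the
    # evaluator sketch above, not from the commit itself.
    from crab.corev2.task import Task


    def no_op_setup() -> None:
        # Setup hook that does nothing; real tasks might install apps,
        # seed files, or open a starting page here.
        pass


    demo_task = Task(
        id="demo-001",
        description="Placeholder task used only for illustration.",
        evaluator=AlwaysDoneEvaluator(),
        setup=no_op_setup,
    )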


class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
@@ -61,7 +73,7 @@ def step(
        self, action: ActType
    ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
        observation, reward, terminal, truncated, info = self.env.step(action)
-        reward = self.task.evaluate(self.env)
+        reward = self.task.evaluator.step(self.env)
        return self.observation(observation), reward, terminal, truncated, info

    def observation(self, observation: ObsType):
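
Note (illustrative, not part of the commit): end to end, wrapping an environment makes every step's reward come from the task's evaluator rather than the inner env. The constructor shape TaskWrapper(env, task) is an assumption, since this hunk does not show __init__; EchoEnv and demo_task come from the earlier sketches.

    # Hypothetical wiring of the pieces above; TaskWrapper(env, task) is an
    # assumed constructor shape not shown in this hunk.
    from crab.corev2.task import TaskWrapper

    env = TaskWrapper(EchoEnv(), demo_task)
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # reward came from demo_task.evaluator.step(...), not from EchoEnv.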
File renamed without changes.
79 changes: 45 additions & 34 deletions refactor_demo/envs/multi_env.py → crab/envv2/multi_env.py
@@ -11,13 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
+from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

+from crab.corev2.environment import CrabEnvironment


-class MultiEnv(gym.Env):
-    def __init__(self, envs):
+class MultiEnv(gym.Env[dict[str, Any], tuple[int, Any]]):
+    def __init__(self, envs: dict[str, CrabEnvironment]):
"""
Initialize the MultiEnv environment.
@@ -28,15 +31,26 @@ def __init__(self, envs):

        # Store the environments
        self.envs = envs

        # Create action space using OneOf with the action spaces of each environment
        self.action_space = spaces.OneOf([env.action_space for env in envs.values()])

-        # Create observation space as a Dict space containing each environment's observation space
+        # Create observation space as a Dict space containing each environment's
+        # observation space
        self.observation_space = spaces.Dict(
-            {f"env_{i}": env.observation_space for i, env in enumerate(envs)}
+            {name: env.observation_space for name, env in envs.items()}
        )

+        self._idx_to_env_name = {
+            idx: env_name for idx, env_name in enumerate(envs.keys())
+        }
+        self._env_name_to_index = {
+            env_name: idx for idx, env_name in enumerate(envs.keys())
+        }
+
+        self._saved_observations = {key: None for key in envs.keys()}
+        self._saved_infos = {key: None for key in envs.keys()}
+        self._saved_dones = {key: False for key in envs.keys()}
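
Note (illustrative, not part of the commit): the constructor now keys sub-environments by name and exposes a spaces.OneOf action space (available in gymnasium 1.0 and later), whose samples are (index, sub-action) pairs. A construction sketch, reusing the invented EchoEnv:

    # Hypothetical MultiEnv construction; EchoEnv is the invented toy
    # environment from the earlier sketch. Requires gymnasium >= 1.0 for OneOf.
    from crab.envv2.multi_env import MultiEnv

    multi = MultiEnv({"phone": EchoEnv(), "desktop": EchoEnv()})

    print(multi.observation_space)      # Dict keyed by "phone" / "desktop"
    print(multi.action_space)           # OneOf over the two sub action spaces
    print(multi.action_space.sample())  # e.g. (1, 3): env index 1, action 3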

    def reset(self):
        """
        Reset all environments and return initial observations.
@@ -45,51 +59,48 @@ def reset(self):
            dict: A dictionary with initial observations from each environment.
        """
        observations = {}
-        for i, env in enumerate(self.envs):
-            observations[f"env_{i}"], _ = env.reset()
-        return observations
+        infos = {}
+        for name, env in self.envs.items():
+            observations[name], infos[name] = env.reset()
+        self._saved_observations = observations
+        self._saved_infos = infos
+        return observations, infos

-    def step(self, action):
+    def step(self, action: tuple[int, Any]):
        """
        Take a step in the selected environment based on the action.

        Args:
-            action (int): The index of the environment to take a step in.
+            action: A tuple (environment index, action) selecting which
+                environment to step and what action to pass it.

        Returns:
            tuple: A tuple containing the observations, rewards, done flags, and info.
        """
assert 0 <= action < len(self.envs), "Invalid action for environment selection."
env_idx, actual_action = action
env_name = self._idx_to_env_name[env_idx]
assert (
0 <= env_idx < len(self.envs)
), "Invalid action for environment selection."
assert self.action_space[env_idx].contains(
actual_action
), f"{actual_action!r} ({type(actual_action)}) invalid in {env_name}"

-        # Initialize dictionaries to store results
-        observations = {}
-        rewards = {}
-        dones = {}
-        infos = {}
+        env = self.envs[env_name]
+
+        reward = 0  # No reward in bare MultiEnv

        # Perform a step in the selected environment
-        obs, reward, done, truncated, info = self.envs[action].step(action)
+        obs, reward, done, truncated, info = env.step(actual_action)

-        # Populate results for the selected environment
-        observations[f"env_{action}"] = obs
-        rewards[f"env_{action}"] = reward
-        dones[f"env_{action}"] = done
-        infos[f"env_{action}"] = info
-
-        # For other environments, simply pass their previous observations
-        for i, env in enumerate(self.envs):
-            if i != action:
-                observations[f"env_{i}"] = (
-                    None  # No new observation for non-acting environments
-                )
-                rewards[f"env_{i}"] = 0
-                dones[f"env_{i}"] = False
-                infos[f"env_{i}"] = {}
+        self._saved_observations[env_name] = obs
+        self._saved_dones[env_name] = done
+        self._saved_infos[env_name] = info

        # Set done if all environments are done
-        all_done = all(dones.values())
+        all_done = all(self._saved_dones.values())

-        return observations, rewards, all_done, infos
+        return self._saved_observations, reward, all_done, truncated, self._saved_infos

    def render(self, mode="human"):
        """
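
Note (illustrative, not part of the commit): a round trip through the new API. reset returns per-environment observation and info dicts, and step routes one (env_idx, action) pair to the selected environment while the others keep their cached observations. Continuing the sketch above:

    # Continuing the hypothetical MultiEnv above.
    observations, infos = multi.reset()            # dicts keyed by env name

    env_idx = multi._env_name_to_index["phone"]    # mapping built in __init__
    obs, reward, all_done, truncated, infos = multi.step((env_idx, 1))

    # Only "phone" stepped; "desktop" still reports its cached observation.
    print(obs["phone"], obs["desktop"])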
File renamed without changes.
27 changes: 0 additions & 27 deletions refactor_demo/core/task.py

This file was deleted.
