diff --git a/README.md b/README.md index f5f9c05e..9d6683d0 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,21 @@ The algorithms sheeped by sheeprl out-of-the-box are: and more are coming soon! [Open a PR](https://github.com/Eclectic-Sheep/sheeprl/pulls) if you have any particular request :sheep: + +The environments supported by sheeprl are: + +| Environment | Installation command | More info | Status | +| ------------------ | ---------------------------- | ----------------------------------------------- | ------------------ | +| Classic Control | `pip install -e .` | | :heavy_check_mark: | +| Box2D | `pip install -e .` | | :heavy_check_mark: | +| Mujoco (Gymnasium) | `pip install -e .` | [how_to/mujoco](./howto/learn_in_dmc.md) | :heavy_check_mark: | +| Atari | `pip install -e .[atari]` | [how_to/atari](./howto/learn_in_atari.md) | :heavy_check_mark: | +| DeepMind Control | `pip install -e .[dmc]` | [how_to/dmc](./howto/learn_in_dmc.md) | :heavy_check_mark: | +| MineRL | `pip install -e .[minerl]` | [how_to/minerl](./howto/learn_in_minerl.md) | :heavy_check_mark: | +| MineDojo | `pip install -e .[minedojo]` | [how_to/minedojo](./howto/learn_in_minedojo.md) | :heavy_check_mark: | +| DIAMBRA | | | :construction: | + + ## Why We want to provide a framework for RL algorithms that is at the same time simple and scalable thanks to Lightning Fabric. @@ -75,6 +90,8 @@ pip install "sheeprl @ git+https://github.com/Eclectic-Sheep/sheeprl.git" pip install "sheeprl[atari,mujoco,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git" # or, to install with minedojo environment support, do pip install "sheeprl[minedojo,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git" +# or, to install with minerl environment support, do +pip install "sheeprl[minerl,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git" # or, to install all extras, do pip install "sheeprl[atari,mujoco,miedojo,dev,test] @ git+https://github.com/Eclectic-Sheep/sheeprl.git" ``` @@ -86,6 +103,8 @@ pip install "sheeprl[atari,mujoco,miedojo,dev,test] @ git+https://github.com/Ec > if you are on an M-series mac and encounter an error attributed box2dpy during install, you need to install SWIG using the instructions shown below. > > if you want to install the minedojo environment support, Java JDK 8 is required: you can install it by following the instructions at this [link](https://docs.minedojo.org/sections/getting_started/install.html#on-ubuntu-20-04). +> +> **MineRL** and **MineDojo** environments have **conflicting requirements**, so **DO NOT install them together** with the `pip install -e .[minerl,minedojo]` command, but instead **install them individually** with either the command `pip install -e .[minerl]` or `pip install -e .[minedojo]` before running an experiment with the MineRL or MineDojo environment, respectively.
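+
+For example, one possible way to respect this constraint (just a sketch, assuming you use Python's built-in `venv`; any other environment manager works equally well) is to keep one virtual environment per Minecraft backend:
+
+```bash
+# Hypothetical setup: one isolated virtual environment per Minecraft backend,
+# so the conflicting MineRL and MineDojo requirements never meet.
+python -m venv .venv-minerl && source .venv-minerl/bin/activate
+pip install -e .[minerl]
+deactivate
+
+python -m venv .venv-minedojo && source .venv-minedojo/bin/activate
+pip install -e .[minedojo]
+deactivate
+```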
Installing SWIG diff --git a/howto/learn_in_atari.md b/howto/learn_in_atari.md index e3ce5a94..bee5f032 100644 --- a/howto/learn_in_atari.md +++ b/howto/learn_in_atari.md @@ -9,9 +9,7 @@ The code for this section is available in `algos/ppo_pixel/ppo_atari.py`. First we should install the Atari environments with: ```bash -pip install gymnasium[other] -pip install gymnasium[atari] -pip install gymnasium[accept-rom-license] +pip install .[atari] ``` For more information: https://gymnasium.farama.org/environments/atari/ @@ -154,14 +152,19 @@ Options: --sheeprl_help Show this message and exit. Commands: + dreamer_v1 + dreamer_v2 droq + p2e_dv1 ppo ppo_atari ppo_continuous ppo_decoupled + ppo_pixel_continuous ppo_recurrent sac sac_decoupled + sac_pixel_continuous ``` Once this is done, we are all set. We can now train the model by running: diff --git a/howto/learn_in_dmc.md b/howto/learn_in_dmc.md new file mode 100644 index 00000000..4f97d1c2 --- /dev/null +++ b/howto/learn_in_dmc.md @@ -0,0 +1,22 @@ +## Install Gymnasium MuJoCo/DMC environments +First you should install the proper environments: + +- MuJoCo (Gymnasium): you do not need to install extra packages; the `pip install -e .` command is enough to make available all the MuJoCo environments provided by Gymnasium +- DMC: you have to install extra packages with the following command: `pip install -e .[dmc]`. + +## Install OpenGL rendering backend packages + +MuJoCo supports three different OpenGL rendering backends: EGL (headless), GLFW (windowed), OSMesa (headless). +For each of them, you need to install some packages: +- GLFW: `sudo apt-get install libglfw3 libglew2.0` +- EGL: `sudo apt-get install libglew2.0` +- OSMesa: `sudo apt-get install libgl1-mesa-glx libosmesa6` +In order to use one of these rendering backends, you need to set the `MUJOCO_GL` environment variable to `"glfw"`, `"egl"`, `"osmesa"`, respectively. + +For more information: [https://github.com/deepmind/dm_control](https://github.com/deepmind/dm_control) and [https://mujoco.readthedocs.io/en/stable/programming/index.html#using-opengl](https://mujoco.readthedocs.io/en/stable/programming/index.html#using-opengl) + +## MuJoCo Gymnasium +In order to train your agents on the [MuJoCo environments](https://gymnasium.farama.org/environments/mujoco/) provided by Gymnasium, it is sufficient to set the `env_id` argument to the name of the environment you want to use. For instance, set it to `"Walker2d-v4"` if you want to train your agent on the *walker* environment. + +## DeepMind Control +In order to train your agents on the [DeepMind control suite](https://github.com/deepmind/dm_control/blob/main/dm_control/suite/README.md), you have to prefix `"dmc_"` to the environment you want to use. A list of the available environments can be found [here](https://arxiv.org/abs/1801.00690). For instance, if you want to train your agent on the *walker walk* environment, you need to set the `env_id` to `"dmc_walker_walk"`.
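+
+For instance, a minimal training command for the *walker walk* DMC task could look like the following (a sketch: `dreamer_v2` is used here only as an example algorithm, and the `lightning run model --devices=1 sheeprl.py ...` pattern mirrors the commands shown in the other how-to guides):
+
+```bash
+lightning run model --devices=1 sheeprl.py dreamer_v2 --env_id=dmc_walker_walk
+```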
\ No newline at end of file diff --git a/howto/learn_in_minedojo.md b/howto/learn_in_minedojo.md new file mode 100644 index 00000000..c43d33ed --- /dev/null +++ b/howto/learn_in_minedojo.md @@ -0,0 +1,54 @@ +## Install MineDojo environment +First you need to install JDK 1.8; on Debian-based systems you can run the following: + +```bash +sudo apt update -y +sudo apt install -y software-properties-common +sudo add-apt-repository ppa:openjdk-r/ppa +sudo apt update -y +sudo apt install -y openjdk-8-jdk +sudo update-alternatives --config java +``` + +> **Note** +> +> If you work on another OS, you can follow the instructions [here](https://docs.minedojo.org/sections/getting_started/install.html#on-macos) to install JDK 1.8. + +Now, you can install the MineDojo environment: + +```bash +pip install -e .[minedojo] +``` + +## MineDojo environments +It is possible to train your agents on all the tasks provided by MineDojo: you need to prefix `"minedojo_"` to the `task_id` of the task on which you want to train your agent, and pass it to the `env_id` argument. +For instance, you have to set the `env_id` argument to `"minedojo_open-ended"` to select the MineDojo open-ended environment. + +### Observation Space +We slightly modified the observation space by reshaping it (based on the idea proposed by Hafner in [DreamerV3](https://arxiv.org/abs/2301.04104)): +- An inventory vector with one entry for each item of the game, which gives the quantity of the corresponding item in the inventory. +- A max inventory vector with one entry for each item which contains the maximum number of items obtained by the agent so far in the episode. +- A delta inventory vector with one entry for each item which contains the difference in the number of items in the inventory after the performed action. +- The RGB first-person camera image. +- A vector of three elements representing the life, the food and the oxygen levels of the agent. +- A one-hot vector indicating the equipped item. +- A mask for the action type indicating which actions can be executed. +- A mask for the equip/place arguments indicating which elements can be equipped or placed. +- A mask for the destroy arguments indicating which items can be destroyed. +- A mask for the *craft smelt* arguments indicating which items can be crafted. + +### Action Space +We decided to convert the original 8-dimensional multi-discrete action space into a 3-dimensional multi-discrete action space: the first dimension maps all the functional actions (movement, craft, jump, camera, attack, ...); the second one maps the argument for the *craft* action; the third one maps the argument for the *equip*, *place*, and *destroy* actions. Moreover, we restrict the look up/down actions between `min_pitch` and `max_pitch` degrees. +In addition, we added the forward action when the agent selects one of the following actions: `jump`, `sprint`, and `sneak`. +Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `mine_sticky_jump` and `mine_sticky_attack` arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps. + +> **Note** +> Since the MineDojo environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously.
+> +> The action repeat in the Minecraft environments is set to 1; indeed, it makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action). + +## Headless machines + +If you work on a headless machine, you need a software renderer. We recommend adopting one of the following solutions: +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. For instance, to train DreamerV2 on the open-ended task on a headless machine, you need to run the following command: `xvfb-run lightning run model --devices=1 sheeprl.py dreamer_v2 --env_id=minedojo_open-ended`, or `MINEDOJO_HEADLESS=1 lightning run model --devices=1 sheeprl.py dreamer_v2 --env_id=minedojo_open-ended`. +2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. \ No newline at end of file diff --git a/howto/learn_in_minerl.md b/howto/learn_in_minerl.md new file mode 100644 index 00000000..bb81c386 --- /dev/null +++ b/howto/learn_in_minerl.md @@ -0,0 +1,56 @@ +## Install MineRL environment +First you need to install JDK 1.8; on Debian-based systems you can run the following: + +```bash +sudo add-apt-repository ppa:openjdk-r/ppa +sudo apt-get update +sudo apt-get install openjdk-8-jdk +``` + +> **Note** +> +> If you work on another OS, you can follow the instructions [here](https://minerl.readthedocs.io/en/v0.4.4/tutorials/index.html) to install JDK 1.8. + +Now, you can install the MineRL environment: + +```bash +pip install -e .[minerl] +``` + +## MineRL environments +We modified the MineRL environments to have a custom action and observation space. We provide three different tasks: +1. Navigate: you need to set the `env_id` argument to `"minerl_custom_navigate"`. +2. Obtain Iron Pickaxe: you need to set the `env_id` argument to `"minerl_custom_obtain_iron_pickaxe"`. +3. Obtain Diamond: you need to set the `env_id` argument to `"minerl_custom_obtain_diamond"`. + +> **Note** +> In all these environments, it is possible to choose between a dense and a sparse reward: set the `minerl_dense` argument to `True` if you want a dense reward, to `False` otherwise. +> +> In the Navigate task, there is also the possibility to train the agent on an extreme environment (for more info, check [here](https://minerl.readthedocs.io/en/v0.4.4/environments/index.html#minerlnavigateextreme-v0)): to choose whether or not to do so, you need to set the `minerl_extreme` argument to `True` or `False`. +> +> In addition, in all the environments, it is possible to set the break speed multiplier through the `mine_break_speed` argument. + +### Observation Space +We slightly modified the observation space by adding the *life stats* (life, food and oxygen) and reshaping those already present (based on the idea proposed by Hafner in [DreamerV3](https://arxiv.org/abs/2301.04104)): +- An inventory vector with one entry for each item of the game, which gives the quantity of the corresponding item in the inventory. +- A max inventory vector with one entry for each item which contains the maximum number of items obtained by the agent so far in the episode. +- The RGB first-person camera image. +- A vector of three elements representing the life, the food and the oxygen levels of the agent. +- A one-hot vector indicating the equipped item, only for the *obtain* tasks.
+- A scalar indicating the compass angle to the goal location, only for the *navigate* tasks. + +### Action Space +We decided to convert the multi-discrete action space into a discrete action space. Moreover, we restrict the look up/down actions between `min_pitch` and `max_pitch` degrees. +In addition, we added the forward action when the agent selects one of the following actions: `jump`, `sprint`, and `sneak`. +Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `mine_sticky_jump` and `mine_sticky_attack` arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps. + +> **Note** +> Since the MineRL environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously. +> +> The action repeat in the Minecraft environments is set to 1; indeed, it makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action). + +## Headless machines + +If you work on a headless machine, you need a software renderer. We recommend adopting one of the following solutions: +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the train command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on a headless machine, you need to run the following command: `xvfb-run lightning run model --devices=1 sheeprl.py dreamer_v2 --env_id=minerl_custom_navigate`. +2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 8fbde727..338c80aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ atari = [ "gymnasium[other]==0.28.*", ] minedojo = ["minedojo==0.1"] +minerl = ["minerl==0.4.4"] [tool.ruff] line-length = 120 diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index e3002003..0a898470 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -23,5 +23,11 @@ except ModuleNotFoundError: pass +# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy +import numpy as np + +np.float = np.float32 +np.int = np.int64 +np.bool = bool __version__ = "0.1.0" diff --git a/sheeprl/algos/dreamer_v2/args.py b/sheeprl/algos/dreamer_v2/args.py index 89c7d026..4e012cad 100644 --- a/sheeprl/algos/dreamer_v2/args.py +++ b/sheeprl/algos/dreamer_v2/args.py @@ -108,3 +108,8 @@ class DreamerV2Args(StandardArgs): mine_start_position: Optional[List[str]] = Arg( default=None, help="The starting position of the agent in Minecraft environment. 
(x, y, z, pitch, yaw)" ) + minerl_dense: bool = Arg(default=False, help="whether or not the task has dense reward") + minerl_extreme: bool = Arg(default=False, help="whether or not the task is extreme") + mine_break_speed: int = Arg(default=100, help="the break speed multiplier of Minecraft environments") + mine_sticky_attack: int = Arg(default=30, help="the sticky value for the attack action") + mine_sticky_jump: int = Arg(default=10, help="the sticky value for the jump action") diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index 9f38ac0e..872df04d 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -85,6 +85,22 @@ def make_env( start_position=start_position, ) args.action_repeat = 1 + elif "minerl" in _env_id: + from sheeprl.envs.minerl import MineRLWrapper + + task_id = "_".join(env_id.split("_")[1:]) + env = MineRLWrapper( + task_id, + height=64, + width=64, + pitch_limits=(args.mine_min_pitch, args.mine_max_pitch), + seed=args.seed, + break_speed_multiplier=args.mine_break_speed, + sticky_attack=args.mine_sticky_attack, + sticky_jump=args.mine_sticky_jump, + dense=args.minerl_dense, + extreme=args.minerl_extreme, + ) else: env_spec = gym.spec(env_id).entry_point env = gym.make(env_id, render_mode="rgb_array") diff --git a/sheeprl/envs/minerl.py b/sheeprl/envs/minerl.py new file mode 100644 index 00000000..000ed2e2 --- /dev/null +++ b/sheeprl/envs/minerl.py @@ -0,0 +1,204 @@ +import copy +from typing import Any, Dict, Optional, SupportsFloat, Tuple + +import gymnasium +import minerl +import numpy as np +from gymnasium import core +from minerl.herobraine.hero import mc + +from sheeprl.envs.minerl_envs.navigate import CustomNavigate +from sheeprl.envs.minerl_envs.obtain import CustomObtainDiamond, CustomObtainIronPickaxe + +# In order to use the environment as a gym you need to register it with gym +CUSTOM_ENVS = { + "custom_navigate": CustomNavigate, + "custom_obtain_diamond": CustomObtainDiamond, + "custom_obtain_iron_pickaxe": CustomObtainIronPickaxe, +} + + +N_ALL_ITEMS = len(mc.ALL_ITEMS) +NOOP = { + "camera": (0, 0), + "forward": 0, + "back": 0, + "left": 0, + "right": 0, + "attack": 0, + "sprint": 0, + "jump": 0, + "sneak": 0, + "craft": "none", + "nearbyCraft": "none", + "nearbySmelt": "none", + "place": "none", + "equip": "none", +} +ITEM_ID_TO_NAME = dict(enumerate(mc.ALL_ITEMS)) +ITEM_NAME_TO_ID = dict(zip(mc.ALL_ITEMS, range(N_ALL_ITEMS))) + + +class MineRLWrapper(core.Env): + def __init__( + self, + task_id: str, + height: int = 64, + width: int = 64, + pitch_limits: Tuple[int, int] = (-60, 60), + seed: Optional[int] = None, + sticky_attack: Optional[int] = 30, + sticky_jump: Optional[int] = 10, + break_speed_multiplier: Optional[int] = 100, + **kwargs: Optional[Dict[Any, Any]], + ): + self._height = height + self._width = width + self._pitch_limits = pitch_limits + self._sticky_attack = sticky_attack + self._sticky_jump = sticky_jump + self._sticky_attack_counter = 0 + self._sticky_jump_counter = 0 + self._break_speed_multiplier = break_speed_multiplier + if "navigate" not in task_id.lower(): + kwargs.pop("extreme", None) + + self._env = CUSTOM_ENVS[task_id.lower()](break_speed=break_speed_multiplier, **kwargs).make() + self.ACTIONS_MAP = {0: {}} + act_idx = 1 + for act in self._env.action_space: + if isinstance(self._env.action_space[act], minerl.herobraine.hero.spaces.Enum): + act_val = set(self._env.action_space[act].values.tolist()) - {"none"} + act_len = len(act_val) + elif act != "camera": + 
act_len = 1 + act_val = [1] + else: + act_len = 4 + act_val = [ + np.array([-15, 0]), + np.array([15, 0]), + np.array([0, -15]), + np.array([0, 15]), + ] + action = dict(zip((np.arange(act_len) + act_idx).tolist(), [{act: v} for v in act_val])) + if act in {"jump", "sneak", "sprint"}: + action[act_idx]["forward"] = 1 + self.ACTIONS_MAP.update(action) + act_idx += act_len + + # action and observations space + self.action_space = gymnasium.spaces.Discrete(len(self.ACTIONS_MAP)) + + obs_space = { + "rgb": gymnasium.spaces.Box(0, 255, (3, 64, 64), np.uint8), + "life_stats": gymnasium.spaces.Box(0.0, np.array([20.0, 20.0, 300.0]), (3,), np.float32), + "inventory": gymnasium.spaces.Box(0.0, np.inf, (N_ALL_ITEMS,), np.float32), + "max_inventory": gymnasium.spaces.Box(0.0, np.inf, (N_ALL_ITEMS,), np.float32), + } + if "compass" in self._env.observation_space.spaces: + obs_space["compass"] = gymnasium.spaces.Box(-180, 180, (1,), np.float32) + if "equipped_items" in self._env.observation_space.spaces: + obs_space["equipment"] = gymnasium.spaces.Box(0.0, 1.0, (N_ALL_ITEMS,), np.int32) + self.observation_space = gymnasium.spaces.Dict(obs_space) + self._pos = { + "pitch": 0.0, + "yaw": 0.0, + } + self._max_inventory = np.zeros(N_ALL_ITEMS) + self.render_mode: str = "rgb_array" + self.seed(seed=seed) + + def __getattr__(self, name): + return getattr(self._env, name) + + def _convert_actions(self, action: np.ndarray) -> Dict[str, Any]: + converted_actions = copy.deepcopy(NOOP) + converted_actions.update(self.ACTIONS_MAP[action.item()]) + if self._sticky_attack: + if converted_actions["attack"]: + self._sticky_attack_counter = self._sticky_attack + if self._sticky_attack_counter > 0: + converted_actions["attack"] = 1 + converted_actions["jump"] = 0 + self._sticky_attack_counter -= 1 + if self._sticky_jump: + if converted_actions["jump"]: + self._sticky_jump_counter = self._sticky_jump + if self._sticky_jump_counter > 0: + converted_actions["jump"] = 1 + converted_actions["forward"] = 1 + self._sticky_jump_counter -= 1 + return converted_actions + + def _convert_equipment(self, equipment: Dict[str, Any]) -> np.ndarray: + equip = np.zeros(N_ALL_ITEMS, dtype=np.int32) + equip[ITEM_NAME_TO_ID[equipment["mainhand"]["type"]]] = 1 + return equip + + def _convert_inventory(self, inventory: Dict[str, Any]) -> Dict[str, np.ndarray]: + # the inventory counts, as a vector with one entry for each Minecraft item + converted_inventory = {"inventory": np.zeros(N_ALL_ITEMS)} + for i, (item, quantity) in enumerate(inventory.items()): + # count the items in the inventory + if item == "air": + converted_inventory["inventory"][ITEM_NAME_TO_ID[item]] += 1 + else: + converted_inventory["inventory"][ITEM_NAME_TO_ID[item]] += quantity + converted_inventory["max_inventory"] = np.maximum(converted_inventory["inventory"], self._max_inventory) + self._max_inventory = converted_inventory["max_inventory"].copy() + return converted_inventory + + def _convert_obs(self, obs: Dict[str, Any]) -> Dict[str, np.ndarray]: + converted_obs = { + "rgb": obs["pov"].copy().transpose(2, 0, 1), + "life_stats": np.array( + [obs["life_stats"]["life"], obs["life_stats"]["food"], obs["life_stats"]["air"]], dtype=np.float32 + ), + **self._convert_inventory(obs["inventory"]), + } + if "equipment" in self.observation_space.spaces: + converted_obs["equipment"] = self._convert_equipment(obs["equipped_items"]) + if "compass" in self.observation_space.spaces: + converted_obs["compass"] = obs["compass"]["angle"].reshape(-1) + return converted_obs + + def 
seed(self, seed: Optional[int] = None) -> None: + self.observation_space.seed(seed) + self.action_space.seed(seed) + + def step(self, actions: np.ndarray) -> Tuple[Dict[str, Any], SupportsFloat, bool, bool, Dict[str, Any]]: + converted_actions = self._convert_actions(actions) + next_pitch = self._pos["pitch"] + converted_actions["camera"][0] + next_yaw = ((self._pos["yaw"] + converted_actions["camera"][1]) + 180) % 360 - 180 + if not (self._pitch_limits[0] <= next_pitch <= self._pitch_limits[1]): + converted_actions["camera"] = np.array([0, converted_actions["camera"][1]]) + next_pitch = self._pos["pitch"] + + obs, reward, done, info = self._env.step(converted_actions) + self._pos = { + "pitch": next_pitch, + "yaw": next_yaw, + } + info = {} + return self._convert_obs(obs), reward, done, False, info + + def reset( + self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[np.ndarray, Dict[str, Any]]: + obs = self._env.reset() + self._max_inventory = np.zeros(N_ALL_ITEMS) + self._sticky_attack_counter = 0 + self._sticky_jump_counter = 0 + self._pos = { + "pitch": 0.0, + "yaw": 0.0, + } + return self._convert_obs(obs), {} + + def render(self, mode: Optional[str] = "rgb_array"): + return self._env.render(self.render_mode) + + def close(self): + self._env.close() + return super().close() diff --git a/sheeprl/envs/minerl_envs/__init__.py b/sheeprl/envs/minerl_envs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sheeprl/envs/minerl_envs/backend.py b/sheeprl/envs/minerl_envs/backend.py new file mode 100644 index 00000000..03bbb563 --- /dev/null +++ b/sheeprl/envs/minerl_envs/backend.py @@ -0,0 +1,56 @@ +# adapted from https://github.com/minerllabs/minerl + +from abc import ABC +from typing import List + +from minerl.herobraine.env_spec import EnvSpec +from minerl.herobraine.hero import handler, handlers +from minerl.herobraine.hero.handlers.translation import TranslationHandler +from minerl.herobraine.hero.mc import INVERSE_KEYMAP + +SIMPLE_KEYBOARD_ACTION = ["forward", "back", "left", "right", "jump", "sneak", "sprint", "attack"] + + +class CustomSimpleEmbodimentEnvSpec(EnvSpec, ABC): + """ + A simple base environment from which all other simple envs inherit. + """ + + def __init__(self, name, *args, resolution=(64, 64), break_speed: int = 100, **kwargs): + self.resolution = resolution + self.break_speed = break_speed + super().__init__(name, *args, **kwargs) + + def create_agent_start(self): + return [BreakSpeedMultiplier(self.break_speed)] + + def create_observables(self) -> List[TranslationHandler]: + return [ + handlers.POVObservation(self.resolution), + handlers.ObservationFromCurrentLocation(), + handlers.ObservationFromLifeStats(), + ] + + def create_actionables(self) -> List[TranslationHandler]: + """ + Simple envs have some basic keyboard control functionality, but + not all. 
+ """ + return [ + handlers.KeybasedCommandAction(k, v) for k, v in INVERSE_KEYMAP.items() if k in SIMPLE_KEYBOARD_ACTION + ] + [handlers.CameraAction()] + + def create_monitors(self) -> List[TranslationHandler]: + return [] # No base monitor needed + + +# adapted from https://github.com/danijar/diamond_env +class BreakSpeedMultiplier(handler.Handler): + def __init__(self, multiplier=1.0): + self.multiplier = multiplier + + def to_string(self): + return f"break_speed({self.multiplier})" + + def xml_template(self): + return "{{multiplier}}" diff --git a/sheeprl/envs/minerl_envs/navigate.py b/sheeprl/envs/minerl_envs/navigate.py new file mode 100644 index 00000000..379b058b --- /dev/null +++ b/sheeprl/envs/minerl_envs/navigate.py @@ -0,0 +1,124 @@ +# adapted from https://github.com/minerllabs/minerl + +from typing import List + +import minerl.herobraine.hero.handlers as handlers +from minerl.herobraine.hero.handler import Handler +from minerl.herobraine.hero.mc import MS_PER_STEP + +from sheeprl.envs.minerl_envs.backend import CustomSimpleEmbodimentEnvSpec + +NAVIGATE_STEPS = 6000 + + +class CustomNavigate(CustomSimpleEmbodimentEnvSpec): + def __init__(self, dense, extreme, *args, **kwargs): + suffix = "Extreme" if extreme else "" + suffix += "Dense" if dense else "" + name = "CustomMineRLNavigate{}-v0".format(suffix) + self.dense, self.extreme = dense, extreme + super().__init__(name, *args, max_episode_steps=6000, **kwargs) + + def is_from_folder(self, folder: str) -> bool: + return folder == "navigateextreme" if self.extreme else folder == "navigate" + + def create_observables(self) -> List[Handler]: + return super().create_observables() + [ + handlers.CompassObservation(angle=True, distance=False), + handlers.FlatInventoryObservation(["dirt"]), + ] + + def create_actionables(self) -> List[Handler]: + return super().create_actionables() + [handlers.PlaceBlock(["none", "dirt"], _other="none", _default="none")] + + # john rl nyu microsfot van roy and ian osband + + def create_rewardables(self) -> List[Handler]: + return [ + handlers.RewardForTouchingBlockType( + [ + {"type": "diamond_block", "behaviour": "onceOnly", "reward": 100.0}, + ] + ) + ] + ([handlers.RewardForDistanceTraveledToCompassTarget(reward_per_block=1.0)] if self.dense else []) + + def create_agent_start(self) -> List[Handler]: + return super().create_agent_start() + [handlers.SimpleInventoryAgentStart([dict(type="compass", quantity="1")])] + + def create_agent_handlers(self) -> List[Handler]: + return [handlers.AgentQuitFromTouchingBlockType(["diamond_block"])] + + def create_server_world_generators(self) -> List[Handler]: + if self.extreme: + return [handlers.BiomeGenerator(biome=3, force_reset=True)] + else: + return [handlers.DefaultWorldGenerator(force_reset=True)] + + def create_server_quit_producers(self) -> List[Handler]: + return [handlers.ServerQuitFromTimeUp(NAVIGATE_STEPS * MS_PER_STEP), handlers.ServerQuitWhenAnyAgentFinishes()] + + def create_server_decorators(self) -> List[Handler]: + return [ + handlers.NavigationDecorator( + max_randomized_radius=64, + min_randomized_radius=64, + block="diamond_block", + placement="surface", + max_radius=8, + min_radius=0, + max_randomized_distance=8, + min_randomized_distance=0, + randomize_compass_location=True, + ) + ] + + def create_server_initial_conditions(self) -> List[Handler]: + return [ + handlers.TimeInitialCondition(allow_passage_of_time=False, start_time=6000), + handlers.WeatherInitialCondition("clear"), + handlers.SpawningInitialCondition("false"), + ] + + def 
get_docstring(self): + return make_navigate_text(top="normal" if not self.extreme else "extreme", dense=self.dense) + + def determine_success_from_rewards(self, rewards: list) -> bool: + reward_threshold = 100.0 + if self.dense: + reward_threshold += 60 + return sum(rewards) >= reward_threshold + + +def make_navigate_text(top, dense): + navigate_text = """ +.. image:: ../assets/navigate{}1.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/navigate{}2.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/navigate{}3.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/navigate{}4.mp4.gif + :scale: 100 % + :alt: + +In this task, the agent must move to a goal location denoted by a diamond block. This represents a basic primitive used in many tasks throughout Minecraft. In addition to standard observations, the agent has access to a “compass” observation, which points near the goal location, 64 meters from the start location. The goal has a small random horizontal offset from the compass location and may be slightly below surface level. On the goal location is a unique block, so the agent must find the final goal by searching based on local visual features. + +The agent is given a sparse reward (+100 upon reaching the goal, at which point the episode terminates). """ + if dense: + navigate_text += "**This variant of the environment is dense reward-shaped where the agent is given a reward every tick for how much closer (or negative reward for farther) the agent gets to the target.**\n" + else: + navigate_text += "**This variant of the environment is sparse.**\n" + + if top == "normal": + navigate_text += "\nIn this environment, the agent spawns on a random survival map.\n" + navigate_text = navigate_text.format(*["" for _ in range(4)]) + else: + navigate_text += "\nIn this environment, the agent spawns in an extreme hills biome.\n" + navigate_text = navigate_text.format(*["extreme" for _ in range(4)]) + return navigate_text diff --git a/sheeprl/envs/minerl_envs/obtain.py b/sheeprl/envs/minerl_envs/obtain.py new file mode 100644 index 00000000..868340f0 --- /dev/null +++ b/sheeprl/envs/minerl_envs/obtain.py @@ -0,0 +1,301 @@ +# adapted from https://github.com/minerllabs/minerl + +from typing import Dict, List, Union + +from minerl.herobraine.hero import handlers, mc +from minerl.herobraine.hero.handler import Handler +from minerl.herobraine.hero.mc import MS_PER_STEP + +from sheeprl.envs.minerl_envs.backend import CustomSimpleEmbodimentEnvSpec + +none = "none" +other = "other" + + +def snake_to_camel(word): + return "".join(x.capitalize() or "_" for x in word.split("_")) + + +class CustomObtain(CustomSimpleEmbodimentEnvSpec): + def __init__( + self, + target_item, + dense, + reward_schedule: List[Dict[str, Union[str, int, float]]], + *args, + max_episode_steps=6000, + **kwargs, + ): + # 6000 for obtain iron (5 mins) + # 18000 for obtain diamond (15 mins) + self.target_item = target_item + self.dense = dense + suffix = snake_to_camel(self.target_item) + dense_suffix = "Dense" if self.dense else "" + if self.dense: + self.reward_text = "every time it obtains an item" + else: + self.reward_text = "only once per item the first time it obtains that item" + self.reward_schedule = reward_schedule + + super().__init__( + *args, + name="CustomMineRLObtain{}{}-v0".format(suffix, dense_suffix), + max_episode_steps=max_episode_steps, + **kwargs, + ) + + def create_observables(self) -> List[Handler]: + # TODO: Parameterize these observations. 
+ return super().create_observables() + [ + handlers.FlatInventoryObservation( + [ + "dirt", + "coal", + "torch", + "log", + "planks", + "stick", + "crafting_table", + "wooden_axe", + "wooden_pickaxe", + "stone", + "cobblestone", + "furnace", + "stone_axe", + "stone_pickaxe", + "iron_ore", + "iron_ingot", + "iron_axe", + "iron_pickaxe", + ] + ), + handlers.EquippedItemObservation(items=mc.ALL_ITEMS, _default="air", _other=other), + ] + + def create_actionables(self): + # TODO (R): MineRL-v1 use invalid (for data) + return super().create_actionables() + [ + handlers.PlaceBlock( + [none, "dirt", "stone", "cobblestone", "crafting_table", "furnace", "torch"], _other=none, _default=none + ), + handlers.EquipAction( + [none, "air", "wooden_axe", "wooden_pickaxe", "stone_axe", "stone_pickaxe", "iron_axe", "iron_pickaxe"], + _other=none, + _default=none, + ), + handlers.CraftAction([none, "torch", "stick", "planks", "crafting_table"], _other=none, _default=none), + handlers.CraftNearbyAction( + [ + none, + "wooden_axe", + "wooden_pickaxe", + "stone_axe", + "stone_pickaxe", + "iron_axe", + "iron_pickaxe", + "furnace", + ], + _other=none, + _default=none, + ), + handlers.SmeltItemNearby([none, "iron_ingot", "coal"], _other=none, _default=none), + # As apart of pervious todo + # this should be handlers.SmeltItem([none, 'iron_ingot', 'coal', other]), but this is not supported by mineRL-v0 + ] + + def create_rewardables(self) -> List[Handler]: + reward_handler = handlers.RewardForCollectingItems if self.dense else handlers.RewardForCollectingItemsOnce + + return [reward_handler(self.reward_schedule if self.reward_schedule else {self.target_item: 1})] + + def create_agent_start(self) -> List[Handler]: + return super().create_agent_start() + + def create_agent_handlers(self) -> List[Handler]: + return [handlers.AgentQuitFromPossessingItem([dict(type="diamond", amount=1)])] + + def create_server_world_generators(self) -> List[Handler]: + return [handlers.DefaultWorldGenerator(force_reset=True)] + + def create_server_quit_producers(self) -> List[Handler]: + return [ + handlers.ServerQuitFromTimeUp(time_limit_ms=self.max_episode_steps * MS_PER_STEP), + handlers.ServerQuitWhenAnyAgentFinishes(), + ] + + def create_server_decorators(self) -> List[Handler]: + return [] + + def create_server_initial_conditions(self) -> List[Handler]: + return [ + handlers.TimeInitialCondition( + start_time=6000, + allow_passage_of_time=True, + ), + handlers.SpawningInitialCondition(allow_spawning=True), + ] + + def is_from_folder(self, folder: str): + return folder == "o_{}".format(self.target_item) + + def get_docstring(self): + return "" + + def determine_success_from_rewards(self, rewards: list) -> bool: + # TODO: Convert this to finish handlers. + rewards = set(rewards) + allow_missing_ratio = 0.1 + max_missing = round(len(self.reward_schedule) * allow_missing_ratio) + + # Get a list of the rewards from the reward_schedule. 
+ reward_values = [s["reward"] for s in self.reward_schedule] + + return len(rewards.intersection(reward_values)) >= len(reward_values) - max_missing + + +class CustomObtainDiamond(CustomObtain): + def __init__(self, dense, *args, **kwargs): + super(CustomObtainDiamond, self).__init__( + *args, + target_item="diamond", + dense=dense, + reward_schedule=[ + dict(type="log", amount=1, reward=1), + dict(type="planks", amount=1, reward=2), + dict(type="stick", amount=1, reward=4), + dict(type="crafting_table", amount=1, reward=4), + dict(type="wooden_pickaxe", amount=1, reward=8), + dict(type="cobblestone", amount=1, reward=16), + dict(type="furnace", amount=1, reward=32), + dict(type="stone_pickaxe", amount=1, reward=32), + dict(type="iron_ore", amount=1, reward=64), + dict(type="iron_ingot", amount=1, reward=128), + dict(type="iron_pickaxe", amount=1, reward=256), + dict(type="diamond", amount=1, reward=1024), + ], + max_episode_steps=18000, + **kwargs, + ) + + def is_from_folder(self, folder: str) -> bool: + return folder == "o_dia" + + def get_docstring(self): + return ( + """ +.. image:: ../assets/odia1.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/odia2.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/odia3.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/odia4.mp4.gif + :scale: 100 % + :alt: + +In this environment the agent is required to obtain a diamond. +The agent begins in a random starting location on a random survival map without any items, matching the normal starting conditions for human players in Minecraft. +The agent is given access to a selected summary of its inventory and GUI free +crafting, smelting, and inventory management actions. + + +During an episode **the agent is rewarded """ + + self.reward_text + + """** +in the requisite item hierarchy to obtaining a diamond. The rewards for each +item are given here:: + + + + + + + + + + + + + + +\n""" + ) + + +class CustomObtainIronPickaxe(CustomObtain): + def __init__(self, dense, *args, **kwargs): + super(CustomObtainIronPickaxe, self).__init__( + *args, + target_item="iron_pickaxe", + dense=dense, + reward_schedule=[ + dict(type="log", amount=1, reward=1), + dict(type="planks", amount=1, reward=2), + dict(type="stick", amount=1, reward=4), + dict(type="crafting_table", amount=1, reward=4), + dict(type="wooden_pickaxe", amount=1, reward=8), + dict(type="cobblestone", amount=1, reward=16), + dict(type="furnace", amount=1, reward=32), + dict(type="stone_pickaxe", amount=1, reward=32), + dict(type="iron_ore", amount=1, reward=64), + dict(type="iron_ingot", amount=1, reward=128), + dict(type="iron_pickaxe", amount=1, reward=256), + ], + **kwargs, + ) + + def create_agent_handlers(self): + return [handlers.AgentQuitFromCraftingItem([dict(type="iron_pickaxe", amount=1)])] + + def is_from_folder(self, folder: str) -> bool: + return folder == "o_iron" + + def get_docstring(self): + return ( + """ +.. image:: ../assets/orion1.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/orion2.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/orion3.mp4.gif + :scale: 100 % + :alt: + +.. image:: ../assets/orion4.mp4.gif + :scale: 100 % + :alt: +In this environment the agent is required to obtain an iron pickaxe. The agent begins in a random starting location, on a random survival map, without any items, matching the normal starting conditions for human players in Minecraft. +The agent is given access to a selected view of its inventory and GUI free +crafting, smelting, and inventory management actions. 
+ + +During an episode **the agent is rewarded """ + + self.reward_text + + """** +in the requisite item hierarchy for obtaining an iron pickaxe. The reward for each +item is given here:: + + + + + + + + + + + + +\n""" + )