Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow XLand benchmarks as a single task envs in single-task ppo #13

Merged
merged 2 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 33 additions & 19 deletions examples/train_meta_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,40 +143,48 @@
" RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n",
")\n",
"\n",
"class MaxPool2d(nn.Module):\n",
" kernel_size: tuple[int, int]\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n",
"\n",
"class ActorCriticInput(TypedDict):\n",
" observation: jax.Array\n",
" prev_action: jax.Array\n",
" prev_reward: jax.Array\n",
"\n",
"\n",
"class ActorCriticRNN(nn.Module):\n",
" num_actions: int\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 64\n",
" rnn_num_layers: int = 1\n",
" head_hidden_dim: int = 64\n",
" img_obs: bool = False\n",
"\n",
" @nn.compact\n",
" def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n",
" B, S = inputs[\"observation\"].shape[:2]\n",
" # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" MaxPool2d((2, 2)),\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" if self.img_obs:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" ]\n",
" )\n",
" else:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n",
"\n",
" rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n",
Expand Down Expand Up @@ -294,8 +302,6 @@
" value_loss = jnp.square(value - targets)\n",
" value_loss_clipped = jnp.square(value_pred_clipped - targets)\n",
" value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean()\n",
" # TODO: ablate this!\n",
" # value_loss = jnp.square(value - targets).mean()\n",
"\n",
" # CALCULATE ACTOR LOSS\n",
" ratio = jnp.exp(log_prob - transitions.log_prob)\n",
Expand Down Expand Up @@ -391,6 +397,7 @@
"class TrainConfig:\n",
" env_id: str = \"XLand-MiniGrid-R1-8x8\"\n",
" benchmark_id: str = \"trivial-1m\"\n",
" img_obs: bool = False\n",
" # agent\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 64\n",
Expand Down Expand Up @@ -444,6 +451,12 @@
" env, env_params = xminigrid.make(config.env_id)\n",
" env = GymAutoResetWrapper(env)\n",
"\n",
" # enabling image observations if needed\n",
" if config.img_obs:\n",
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
"\n",
" env = RGBImgObservationWrapper(env)\n",
" \n",
" # loading benchmark\n",
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
"\n",
Expand All @@ -457,6 +470,7 @@
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
" rnn_num_layers=config.rnn_num_layers,\n",
" head_hidden_dim=config.head_hidden_dim,\n",
" img_obs=config.img_obs,\n",
" )\n",
" # [batch_size, seq_len, ...]\n",
" init_obs = {\n",
Expand Down
64 changes: 43 additions & 21 deletions examples/train_single_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"source": [
"import time\n",
"import math\n",
"from typing import TypedDict\n",
"from typing import TypedDict, Optional\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -142,40 +142,48 @@
" RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n",
")\n",
"\n",
"class MaxPool2d(nn.Module):\n",
" kernel_size: tuple[int, int]\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n",
"\n",
"class ActorCriticInput(TypedDict):\n",
" observation: jax.Array\n",
" prev_action: jax.Array\n",
" prev_reward: jax.Array\n",
"\n",
"\n",
"class ActorCriticRNN(nn.Module):\n",
" num_actions: int\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 64\n",
" rnn_num_layers: int = 1\n",
" head_hidden_dim: int = 64\n",
" img_obs: bool = False\n",
"\n",
" @nn.compact\n",
" def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n",
" B, S = inputs[\"observation\"].shape[:2]\n",
" # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" MaxPool2d((2, 2)),\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" if self.img_obs:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" ]\n",
" )\n",
" else:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n",
"\n",
" rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n",
Expand Down Expand Up @@ -389,6 +397,9 @@
"@dataclass\n",
"class TrainConfig:\n",
" env_id: str = \"MiniGrid-Empty-6x6\"\n",
" benchmark_id: Optional[str] = None\n",
" ruleset_id: Optional[int] = None\n",
" img_obs: bool = False\n",
" # agent\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 1024\n",
Expand Down Expand Up @@ -428,12 +439,22 @@
" return config.lr * frac\n",
"\n",
" # setup environment\n",
" if \"XLand-MiniGrid\" in config.env_id:\n",
" raise ValueError(\"Only single-task environments are supported.\")\n",
"\n",
" env, env_params = xminigrid.make(config.env_id)\n",
" env = GymAutoResetWrapper(env)\n",
"\n",
" # for single-task XLand environments\n",
" if config.benchmark_id is not None:\n",
" assert \"XLand-MiniGrid\" in config.env_id, \"Benchmarks should be used only with XLand environments.\"\n",
" assert config.ruleset_id is not None, \"Ruleset ID should be specified for benchmarks usage.\"\n",
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
" env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))\n",
"\n",
" # enabling image observations if needed\n",
" if config.img_obs:\n",
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
"\n",
" env = RGBImgObservationWrapper(env)\n",
"\n",
" # setup training state\n",
" rng = jax.random.PRNGKey(config.seed)\n",
" rng, _rng = jax.random.split(rng)\n",
Expand All @@ -444,6 +465,7 @@
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
" rnn_num_layers=config.rnn_num_layers,\n",
" head_hidden_dim=config.head_hidden_dim,\n",
" img_obs=config.img_obs,\n",
" )\n",
" # [batch_size, seq_len, ...]\n",
" init_obs = {\n",
Expand Down
13 changes: 10 additions & 3 deletions training/train_single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from dataclasses import asdict, dataclass
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
Expand All @@ -29,6 +30,8 @@ class TrainConfig:
group: str = "default"
name: str = "single-task-ppo"
env_id: str = "MiniGrid-Empty-6x6"
benchmark_id: Optional[str] = None
ruleset_id: Optional[int] = None
img_obs: bool = False
# agent
action_emb_dim: int = 16
Expand Down Expand Up @@ -69,12 +72,16 @@ def linear_schedule(count):
return config.lr * frac

# setup environment
if "XLand-MiniGrid" in config.env_id:
raise ValueError("Only single-task environments are supported.")

env, env_params = xminigrid.make(config.env_id)
env = GymAutoResetWrapper(env)

# for single-task XLand environments
if config.benchmark_id is not None:
assert "XLand-MiniGrid" in config.env_id, "Benchmarks should be used only with XLand environments."
assert config.ruleset_id is not None, "Ruleset ID should be specified for benchmarks usage."
benchmark = xminigrid.load_benchmark(config.benchmark_id)
env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))

# enabling image observations if needed
if config.img_obs:
from xminigrid.experimental.img_obs import RGBImgObservationWrapper
Expand Down
Loading