diff --git a/examples/train_meta_standalone.ipynb b/examples/train_meta_standalone.ipynb index ad3f253..ed911ac 100644 --- a/examples/train_meta_standalone.ipynb +++ b/examples/train_meta_standalone.ipynb @@ -143,40 +143,48 @@ " RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n", ")\n", "\n", - "class MaxPool2d(nn.Module):\n", - " kernel_size: tuple[int, int]\n", - "\n", - " @nn.compact\n", - " def __call__(self, x):\n", - " return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n", "\n", "class ActorCriticInput(TypedDict):\n", " observation: jax.Array\n", " prev_action: jax.Array\n", " prev_reward: jax.Array\n", "\n", + "\n", "class ActorCriticRNN(nn.Module):\n", " num_actions: int\n", " action_emb_dim: int = 16\n", " rnn_hidden_dim: int = 64\n", " rnn_num_layers: int = 1\n", " head_hidden_dim: int = 64\n", + " img_obs: bool = False\n", "\n", " @nn.compact\n", " def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n", " B, S = inputs[\"observation\"].shape[:2]\n", " # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n", - " img_encoder = nn.Sequential(\n", - " [\n", - " nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", - " nn.relu,\n", - " MaxPool2d((2, 2)),\n", - " nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", - " nn.relu,\n", - " nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", - " nn.relu,\n", - " ]\n", - " )\n", + " if self.img_obs:\n", + " img_encoder = nn.Sequential(\n", + " [\n", + " nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " ]\n", + " )\n", + " else:\n", + " img_encoder = nn.Sequential(\n", + " [\n", + " nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " ]\n", + " )\n", " action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n", "\n", " rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n", @@ -294,8 +302,6 @@ " value_loss = jnp.square(value - targets)\n", " value_loss_clipped = jnp.square(value_pred_clipped - targets)\n", " value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean()\n", - " # TODO: ablate this!\n", - " # value_loss = jnp.square(value - targets).mean()\n", "\n", " # CALCULATE ACTOR LOSS\n", " ratio = jnp.exp(log_prob - transitions.log_prob)\n", @@ -391,6 +397,7 @@ "class TrainConfig:\n", " env_id: str = \"XLand-MiniGrid-R1-8x8\"\n", " benchmark_id: str = \"trivial-1m\"\n", + " img_obs: bool = False\n", " # agent\n", " action_emb_dim: int = 16\n", " rnn_hidden_dim: int = 64\n", @@ -444,6 +451,12 @@ " env, env_params = xminigrid.make(config.env_id)\n", " env = GymAutoResetWrapper(env)\n", "\n", + " # enabling image observations if needed\n", + " if config.img_obs:\n", + " from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n", + "\n", + " env = RGBImgObservationWrapper(env)\n", + " \n", " # loading benchmark\n", " benchmark = xminigrid.load_benchmark(config.benchmark_id)\n", "\n", @@ -457,6 +470,7 @@ " rnn_hidden_dim=config.rnn_hidden_dim,\n", " rnn_num_layers=config.rnn_num_layers,\n", " head_hidden_dim=config.head_hidden_dim,\n", + " img_obs=config.img_obs,\n", " )\n", " # [batch_size, seq_len, ...]\n", " init_obs = {\n", diff --git a/examples/train_single_standalone.ipynb b/examples/train_single_standalone.ipynb index fc33025..535f79f 100644 --- a/examples/train_single_standalone.ipynb +++ b/examples/train_single_standalone.ipynb @@ -51,7 +51,7 @@ "source": [ "import time\n", "import math\n", - "from typing import TypedDict\n", + "from typing import TypedDict, Optional\n", "\n", "import jax\n", "import jax.numpy as jnp\n", @@ -142,40 +142,48 @@ " RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n", ")\n", "\n", - "class MaxPool2d(nn.Module):\n", - " kernel_size: tuple[int, int]\n", - "\n", - " @nn.compact\n", - " def __call__(self, x):\n", - " return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n", "\n", "class ActorCriticInput(TypedDict):\n", " observation: jax.Array\n", " prev_action: jax.Array\n", " prev_reward: jax.Array\n", "\n", + "\n", "class ActorCriticRNN(nn.Module):\n", " num_actions: int\n", " action_emb_dim: int = 16\n", " rnn_hidden_dim: int = 64\n", " rnn_num_layers: int = 1\n", " head_hidden_dim: int = 64\n", + " img_obs: bool = False\n", "\n", " @nn.compact\n", " def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n", " B, S = inputs[\"observation\"].shape[:2]\n", " # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n", - " img_encoder = nn.Sequential(\n", - " [\n", - " nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", - " nn.relu,\n", - " MaxPool2d((2, 2)),\n", - " nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", - " nn.relu,\n", - " nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", - " nn.relu,\n", - " ]\n", - " )\n", + " if self.img_obs:\n", + " img_encoder = nn.Sequential(\n", + " [\n", + " nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " ]\n", + " )\n", + " else:\n", + " img_encoder = nn.Sequential(\n", + " [\n", + " nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n", + " nn.relu,\n", + " ]\n", + " )\n", " action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n", "\n", " rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n", @@ -389,6 +397,9 @@ "@dataclass\n", "class TrainConfig:\n", " env_id: str = \"MiniGrid-Empty-6x6\"\n", + " benchmark_id: Optional[str] = None\n", + " ruleset_id: Optional[int] = None\n", + " img_obs: bool = False\n", " # agent\n", " action_emb_dim: int = 16\n", " rnn_hidden_dim: int = 1024\n", @@ -428,12 +439,22 @@ " return config.lr * frac\n", "\n", " # setup environment\n", - " if \"XLand-MiniGrid\" in config.env_id:\n", - " raise ValueError(\"Only single-task environments are supported.\")\n", - "\n", " env, env_params = xminigrid.make(config.env_id)\n", " env = GymAutoResetWrapper(env)\n", "\n", + " # for single-task XLand environments\n", + " if config.benchmark_id is not None:\n", + " assert \"XLand-MiniGrid\" in config.env_id, \"Benchmarks should be used only with XLand environments.\"\n", + " assert config.ruleset_id is not None, \"Ruleset ID should be specified for benchmarks usage.\"\n", + " benchmark = xminigrid.load_benchmark(config.benchmark_id)\n", + " env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))\n", + "\n", + " # enabling image observations if needed\n", + " if config.img_obs:\n", + " from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n", + "\n", + " env = RGBImgObservationWrapper(env)\n", + "\n", " # setup training state\n", " rng = jax.random.PRNGKey(config.seed)\n", " rng, _rng = jax.random.split(rng)\n", @@ -444,6 +465,7 @@ " rnn_hidden_dim=config.rnn_hidden_dim,\n", " rnn_num_layers=config.rnn_num_layers,\n", " head_hidden_dim=config.head_hidden_dim,\n", + " img_obs=config.img_obs,\n", " )\n", " # [batch_size, seq_len, ...]\n", " init_obs = {\n", diff --git a/training/train_single_task.py b/training/train_single_task.py index e4064a9..70d1a36 100644 --- a/training/train_single_task.py +++ b/training/train_single_task.py @@ -4,6 +4,7 @@ import time from dataclasses import asdict, dataclass from functools import partial +from typing import Optional import jax import jax.numpy as jnp @@ -29,6 +30,8 @@ class TrainConfig: group: str = "default" name: str = "single-task-ppo" env_id: str = "MiniGrid-Empty-6x6" + benchmark_id: Optional[str] = None + ruleset_id: Optional[int] = None img_obs: bool = False # agent action_emb_dim: int = 16 @@ -69,12 +72,16 @@ def linear_schedule(count): return config.lr * frac # setup environment - if "XLand-MiniGrid" in config.env_id: - raise ValueError("Only single-task environments are supported.") - env, env_params = xminigrid.make(config.env_id) env = GymAutoResetWrapper(env) + # for single-task XLand environments + if config.benchmark_id is not None: + assert "XLand-MiniGrid" in config.env_id, "Benchmarks should be used only with XLand environments." + assert config.ruleset_id is not None, "Ruleset ID should be specified for benchmarks usage." + benchmark = xminigrid.load_benchmark(config.benchmark_id) + env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id)) + # enabling image observations if needed if config.img_obs: from xminigrid.experimental.img_obs import RGBImgObservationWrapper