Skip to content

Commit

Permalink
updated standalone ppo
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Mar 24, 2024
1 parent 8127fe9 commit a61adac
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 40 deletions.
52 changes: 33 additions & 19 deletions examples/train_meta_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,40 +143,48 @@
" RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n",
")\n",
"\n",
"class MaxPool2d(nn.Module):\n",
" kernel_size: tuple[int, int]\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n",
"\n",
"class ActorCriticInput(TypedDict):\n",
" observation: jax.Array\n",
" prev_action: jax.Array\n",
" prev_reward: jax.Array\n",
"\n",
"\n",
"class ActorCriticRNN(nn.Module):\n",
" num_actions: int\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 64\n",
" rnn_num_layers: int = 1\n",
" head_hidden_dim: int = 64\n",
" img_obs: bool = False\n",
"\n",
" @nn.compact\n",
" def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n",
" B, S = inputs[\"observation\"].shape[:2]\n",
" # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" MaxPool2d((2, 2)),\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" if self.img_obs:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" ]\n",
" )\n",
" else:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n",
"\n",
" rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n",
Expand Down Expand Up @@ -294,8 +302,6 @@
" value_loss = jnp.square(value - targets)\n",
" value_loss_clipped = jnp.square(value_pred_clipped - targets)\n",
" value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean()\n",
" # TODO: ablate this!\n",
" # value_loss = jnp.square(value - targets).mean()\n",
"\n",
" # CALCULATE ACTOR LOSS\n",
" ratio = jnp.exp(log_prob - transitions.log_prob)\n",
Expand Down Expand Up @@ -391,6 +397,7 @@
"class TrainConfig:\n",
" env_id: str = \"XLand-MiniGrid-R1-8x8\"\n",
" benchmark_id: str = \"trivial-1m\"\n",
" img_obs: bool = False\n",
" # agent\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 64\n",
Expand Down Expand Up @@ -444,6 +451,12 @@
" env, env_params = xminigrid.make(config.env_id)\n",
" env = GymAutoResetWrapper(env)\n",
"\n",
" # enabling image observations if needed\n",
" if config.img_obs:\n",
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
"\n",
" env = RGBImgObservationWrapper(env)\n",
" \n",
" # loading benchmark\n",
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
"\n",
Expand All @@ -457,6 +470,7 @@
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
" rnn_num_layers=config.rnn_num_layers,\n",
" head_hidden_dim=config.head_hidden_dim,\n",
" img_obs=config.img_obs,\n",
" )\n",
" # [batch_size, seq_len, ...]\n",
" init_obs = {\n",
Expand Down
64 changes: 43 additions & 21 deletions examples/train_single_standalone.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"source": [
"import time\n",
"import math\n",
"from typing import TypedDict\n",
"from typing import TypedDict, Optional\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -142,40 +142,48 @@
" RNNModel, variable_axes={\"params\": None}, split_rngs={\"params\": False}, axis_name=\"batch\"\n",
")\n",
"\n",
"class MaxPool2d(nn.Module):\n",
" kernel_size: tuple[int, int]\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" return nn.max_pool(inputs=x, window_shape=self.kernel_size, strides=self.kernel_size, padding=\"VALID\")\n",
"\n",
"class ActorCriticInput(TypedDict):\n",
" observation: jax.Array\n",
" prev_action: jax.Array\n",
" prev_reward: jax.Array\n",
"\n",
"\n",
"class ActorCriticRNN(nn.Module):\n",
" num_actions: int\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 64\n",
" rnn_num_layers: int = 1\n",
" head_hidden_dim: int = 64\n",
" img_obs: bool = False\n",
"\n",
" @nn.compact\n",
" def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax.Categorical, jax.Array, jax.Array]:\n",
" B, S = inputs[\"observation\"].shape[:2]\n",
" # encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" MaxPool2d((2, 2)),\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" if self.img_obs:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (3, 3), strides=2, padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" ]\n",
" )\n",
" else:\n",
" img_encoder = nn.Sequential(\n",
" [\n",
" nn.Conv(16, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(32, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" nn.Conv(64, (2, 2), padding=\"VALID\", kernel_init=orthogonal(math.sqrt(2))),\n",
" nn.relu,\n",
" ]\n",
" )\n",
" action_encoder = nn.Embed(self.num_actions, self.action_emb_dim)\n",
"\n",
" rnn_core = BatchedRNNModel(self.rnn_hidden_dim, self.rnn_num_layers)\n",
Expand Down Expand Up @@ -389,6 +397,9 @@
"@dataclass\n",
"class TrainConfig:\n",
" env_id: str = \"MiniGrid-Empty-6x6\"\n",
" benchmark_id: Optional[str] = None\n",
" ruleset_id: Optional[int] = None\n",
" img_obs: bool = False\n",
" # agent\n",
" action_emb_dim: int = 16\n",
" rnn_hidden_dim: int = 1024\n",
Expand Down Expand Up @@ -428,12 +439,22 @@
" return config.lr * frac\n",
"\n",
" # setup environment\n",
" if \"XLand-MiniGrid\" in config.env_id:\n",
" raise ValueError(\"Only single-task environments are supported.\")\n",
"\n",
" env, env_params = xminigrid.make(config.env_id)\n",
" env = GymAutoResetWrapper(env)\n",
"\n",
" # for single-task XLand environments\n",
" if config.benchmark_id is not None:\n",
" assert \"XLand-MiniGrid\" in config.env_id, \"Benchmarks should be used only with XLand environments.\"\n",
" assert config.ruleset_id is not None, \"Ruleset ID should be specified for benchmarks usage.\"\n",
" benchmark = xminigrid.load_benchmark(config.benchmark_id)\n",
" env_params = env_params.replace(ruleset=benchmark.get_ruleset(config.ruleset_id))\n",
"\n",
" # enabling image observations if needed\n",
" if config.img_obs:\n",
" from xminigrid.experimental.img_obs import RGBImgObservationWrapper\n",
"\n",
" env = RGBImgObservationWrapper(env)\n",
"\n",
" # setup training state\n",
" rng = jax.random.PRNGKey(config.seed)\n",
" rng, _rng = jax.random.split(rng)\n",
Expand All @@ -444,6 +465,7 @@
" rnn_hidden_dim=config.rnn_hidden_dim,\n",
" rnn_num_layers=config.rnn_num_layers,\n",
" head_hidden_dim=config.head_hidden_dim,\n",
" img_obs=config.img_obs,\n",
" )\n",
" # [batch_size, seq_len, ...]\n",
" init_obs = {\n",
Expand Down

0 comments on commit a61adac

Please sign in to comment.