Skip to content

Commit

Permalink
refine readme, cnn arch
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Mar 24, 2024
1 parent dc26721 commit 251301f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ On the high level, current API combines [dm_env](https://github.com/google-deepm
import jax
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

key = jax.random.PRNGKey(0)
reset_key, ruleset_key = jax.random.split(key)
Expand All @@ -109,6 +110,9 @@ env_params = env_params.replace(ruleset=ruleset)
# auto-reset wrapper
env = GymAutoResetWrapper(env)

# render observations as RGB images if needed (warning: this will greatly reduce speed)
env = RGBImgObservationWrapper(env)

# fully jit-compatible step and reset methods
timestep = jax.jit(env.reset)(env_params, reset_key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)
Expand Down
7 changes: 5 additions & 2 deletions src/xminigrid/manual_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ def close(self) -> None:
if self.video_format == ".mp4":
iio.imwrite(save_path, self.frames, format_hint=".mp4", fps=self.video_fps)
elif self.video_format == ".gif":
iio.imwrite(save_path, self.frames, format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10)
iio.imwrite(
save_path, self.frames[:-1], format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10
)
# iio.imwrite(save_path, self.frames, format_hint=".gif", duration=(1000 * 1 / self.video_fps), loop=10)
else:
raise RuntimeError("Unknown video format! Should be one of ('.mp4', '.gif')")

Expand All @@ -189,7 +192,7 @@ def close(self) -> None:
parser.add_argument("--save-video", action="store_true")
parser.add_argument("--video-path", type=str, default=".")
parser.add_argument("--video-format", type=str, default=".mp4", choices=(".mp4", ".gif"))
parser.add_argument("--video-fps", type=int, default=8)
parser.add_argument("--video-fps", type=int, default=5)

args = parser.parse_args()
env, env_params = xminigrid.make(args.env_id)
Expand Down
9 changes: 4 additions & 5 deletions training/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,15 @@ def __call__(self, inputs: ActorCriticInput, hidden: jax.Array) -> tuple[distrax
B, S = inputs["observation"].shape[:2]
# encoder from https://github.com/lcswillems/rl-starter-files/blob/master/model.py
if self.img_obs:
# slight modification of NatureDQN CNN
img_encoder = nn.Sequential(
[
nn.Conv(32, (8, 8), strides=4, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.Conv(16, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (4, 4), strides=3, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.relu,
nn.Conv(64, (2, 2), strides=1, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
nn.Conv(32, (3, 3), strides=2, padding="VALID", kernel_init=orthogonal(math.sqrt(2))),
]
)
else:
Expand Down

0 comments on commit 251301f

Please sign in to comment.