diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 000000000..7fdcfaab6 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,51 @@ +cff-version: 1.2.0 +title: Minigrid +message: >- + If you use this software, please cite it using the + metadata from this file. +authors: +- family-names: Chevalier-Boisvert + given-names: Maxime +- family-names: Dai + given-names: Bolun +- family-names: Towers + given-names: Mark +- family-names: de Lazcano + given-names: Rodrigo +- family-names: Willems + given-names: Lucas +- family-names: Lahlou + given-names: Salem +- family-names: Pal + given-names: Suman +- family-names: Castro + given-names: Pablo Samuel +- family-names: Terry + given-names: Jordan +url: "https://github.com/Farama-Foundation/Minigrid" + +preferred-citation: + type: article + authors: + - family-names: Chevalier-Boisvert + given-names: Maxime + - family-names: Dai + given-names: Bolun + - family-names: Towers + given-names: Mark + - family-names: de Lazcano + given-names: Rodrigo + - family-names: Willems + given-names: Lucas + - family-names: Lahlou + given-names: Salem + - family-names: Pal + given-names: Suman + - family-names: Castro + given-names: Pablo Samuel + - family-names: Terry + given-names: Jordan + journal: CoRR + title: Minigrid + volume: abs/2306.13831 + year: 2023 diff --git a/docs/content/training.md b/docs/content/training.md new file mode 100644 index 000000000..5817b0849 --- /dev/null +++ b/docs/content/training.md @@ -0,0 +1,67 @@ +--- +layout: "contents" +title: Training Minigrid Environments +firstpage: +--- + +## Training Minigrid Environments + +The environments in the Minigrid library can be trained easily using [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/). In this tutorial we show how a PPO agent can be trained on the `MiniGrid-Empty-16x16-v0` environment. + +## Create Custom Feature Extractor + +Although `StableBaselines3` is fully compatible with `Gymnasium`-based environments, including Minigrid, the default CNN architecture does not directly support the Minigrid observation space. Thus, to train an agent on Minigrid environments, we therefore need to create a custom feature extractor. This can be done by creating a feature extractor class that inherits from `stable_baselines3.common.torch_layers.BaseFeaturesExtractor` + +```python +class MinigridFeaturesExtractor(BaseFeaturesExtractor): + def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None: + super().__init__(observation_space, features_dim) + n_input_channels = observation_space.shape[0] + self.cnn = nn.Sequential( + nn.Conv2d(n_input_channels, 16, (2, 2)), + nn.ReLU(), + nn.Conv2d(16, 32, (2, 2)), + nn.ReLU(), + nn.Conv2d(32, 64, (2, 2)), + nn.ReLU(), + nn.Flatten(), + ) + + # Compute shape by doing one forward pass + with torch.no_grad(): + n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1] + + self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) + + def forward(self, observations: torch.Tensor) -> torch.Tensor: + return self.linear(self.cnn(observations)) +``` + +This class is created based on the custom feature extractor [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-feature-extractor:~:text=Custom%20Feature%20Extractor-,%EF%83%81,-If%20you%20want), the CNN architecture is copied from Lucas Willems' [rl-starter-files](https://github.com/lcswillems/rl-starter-files/blob/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/model.py#L18). + +## Train a PPO Agent + +The using the custom feature extractor, we can train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment. The following code snippet shows how this can be done. + +```python +import minigrid +from minigrid.wrappers import ImgObsWrapper +from stable_baselines3 import PPO + +policy_kwargs = dict( + features_extractor_class=MinigridFeaturesExtractor, + features_extractor_kwargs=dict(features_dim=128), +) + +env = gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array") +env = ImgObsWrapper(env) + +model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1) +model.learn(2e5) +``` + +By default the observation of Minigrid environments are dictionaries. Since the `CnnPolicy` from StableBaseline3 by default takes in image observations, we need to wrap the environment using the `ImgObsWrapper` from the Minigrid library. This wrapper converts the dictionary observation to an image observation. + +## Further Reading + +One can also pass dictionary observations to StableBaseline3 policies, for a walkthrough the process of doing so see [here](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations). An implementation utilizing this functionality can be found [here](https://github.com/BolunDai0216/MinigridMiniworldTransfer/blob/main/minigrid_gotoobj_train.py). \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 82116865d..8ab58158e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -57,6 +57,7 @@ To cite this project please use: content/basic_usage content/publications content/create_env_tutorial +content/training ``` ```{toctree}