Tweaks to the documentation (#43)
* Various small tweaks

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix broken link in README.md

Signed-off-by: Fabrice Normandin <[email protected]>

* Tweak site name and description

Signed-off-by: Fabrice Normandin <[email protected]>

* Move the docs for the Jax example

Signed-off-by: Fabrice Normandin <[email protected]>

* Revert changes to pyproject.toml

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
lebrice authored Sep 12, 2024
1 parent e6925d9 commit 64c6f36
Showing 11 changed files with 91 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -67,7 +67,7 @@ repos:
     rev: 0.7.17
     hooks:
       - id: mdformat
-        exclude: "SUMMARY.md|testing.md"
+        exclude: "SUMMARY.md|testing.md|jax.md"
         args: ["--number"]
         additional_dependencies:
           - mdformat-gfm
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ Please note: This is a Work-in-Progress. The goal is to make a first release by
This is a template repository for a research project in machine learning. It is meant to be a starting point for new ML researchers who run jobs on SLURM clusters.
The main target audience is [Mila](https://mila.quebec/en) researchers and students, but this should still be useful to anyone who uses PyTorch-Lightning with Hydra.

-For more context, see [this introduction to the project.](https://mila-iqia.github.io/ResearchTemplate/overview/intro).
+For more context, see [this introduction to the project](https://mila-iqia.github.io/ResearchTemplate/intro).

## Overview

18 changes: 0 additions & 18 deletions docs/examples/jax.md

This file was deleted.

51 changes: 51 additions & 0 deletions docs/features/jax.md
@@ -0,0 +1,51 @@
# Using Jax with PyTorch-Lightning

You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the conveniences of PyTorch-Lightning.

**How does this work?**
Well, we use [torch-jax-interop](https://www.github.com/lebrice/torch_jax_interop), another package developed here at Mila, which allows easy interoperability between torch and jax code. See the README of that repo for more details.
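
As a rough illustration of the kind of zero-copy exchange this interop relies on, here is a minimal sketch using the standard DLPack APIs directly (torch-jax-interop's own helpers and their exact names may differ):

```python
# Zero-copy exchange between torch and jax via DLPack.
# This only illustrates the mechanism; torch-jax-interop wraps it (and more) for you.
import jax.dlpack
import jax.numpy as jnp
import torch
import torch.utils.dlpack

x_torch = torch.arange(6, dtype=torch.float32)

# torch -> jax: both arrays share the same underlying memory.
x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))

# jax -> torch: compute in jax, then hand the result back to torch.
y_jax = jnp.square(x_jax)
y_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(y_jax))
```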

You can use Jax in your network or learning algorithm, for example in your forward and backward passes, to update parameters, etc., but not in the training loop itself, since that is handled by the [lightning.Trainer][lightning.pytorch.trainer.trainer.Trainer].
There are lots of good reasons to let Lightning handle the training loop, which are very well described [here](https://lightning.ai/docs/pytorch/stable/).
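
To make that division of labor concrete, here is a generic toy example (not the template's own code): everything inside the module is yours to write in torch or jax, while the loop itself belongs to the `Trainer`.

```python
import torch
from lightning import LightningModule, Trainer


class ToyAlgorithm(LightningModule):
    """You write the "what": the forward pass, the loss, the optimizers."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


# Lightning owns the "how": the loop itself, checkpointing, logging, devices.
trainer = Trainer(max_epochs=1)
# trainer.fit(ToyAlgorithm(), train_dataloaders=...)
```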

??? note "What about end-to-end training in Jax?"

    This template doesn't include a way to do end-to-end, fully-jitted training in Jax. However, it _might_ be possible to do so in this way:

    - add a new configuration in the `trainer` config group, with a `_target_` pointing to a trainer-like object with `fit`, `evaluate` and `test` methods mimicking those of PyTorch-Lightning (see the sketch at the end of this note);
    - add a new configuration in the `algorithm` config group pointing to a learning algorithm class that isn't a LightningModule.

    If you want an example of how to do this, please open an issue (or upvote an existing one) on GitHub.
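
    The sketch below is a purely hypothetical illustration of such a trainer-like interface; nothing like it ships with the template yet:

    ```python
    from typing import Any, Protocol


    class JaxTrainerLike(Protocol):
        """Hypothetical stand-in for lightning.Trainer, for a fully-jitted training loop."""

        def fit(self, algorithm: Any, datamodule: Any) -> None:
            """Run the (jitted) training loop."""

        def evaluate(self, algorithm: Any, datamodule: Any) -> dict[str, float]:
            """Run validation and return the metrics."""

        def test(self, algorithm: Any, datamodule: Any) -> dict[str, float]:
            """Run the test loop and return the metrics."""
    ```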

## `JaxExample`: a LightningModule that uses Jax

The [JaxExample][project.algorithms.jax_example.JaxExample] algorithm uses a network which is a [flax.linen.Module](https://flax.readthedocs.io/en/latest/).
The network is wrapped with `torch_jax_interop.JaxFunction`, so that it accepts torch tensors as inputs, produces torch tensors as outputs, and stores its parameters as `torch.nn.Parameter`s (which use the same underlying memory as the jax arrays).
In this example, the loss function and optimizers are in PyTorch, while the network forward and backward passes are written in Jax.

The loss that is returned in the training step is used by Lightning in the usual way. The backward
pass uses Jax to calculate the gradients, and the weights are updated by a PyTorch optimizer.

!!! note

    You could also very well do both the forward **and** backward passes in Jax! To do this, [use the 'manual optimization' mode of PyTorch-Lightning](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html) and perform the parameter updates yourself. For the rest of Lightning to work, just make sure to store the parameters as `torch.nn.Parameter`s. An example of how to do this will be added shortly; in the meantime, the sketch below gives the general idea.
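
    A minimal, CPU-only sketch of that setup (the toy `jax_loss_fn` and `ManualJaxModule` below are illustrative assumptions, not the template's code; the round-trips through numpy are for clarity, whereas the real interop shares memory between frameworks):

    ```python
    import jax
    import jax.numpy as jnp
    import numpy as np
    import torch
    from lightning import LightningModule


    def jax_loss_fn(w: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """Toy forward pass and loss, written entirely in Jax."""
        return jnp.mean((x @ w) ** 2)


    class ManualJaxModule(LightningModule):
        def __init__(self):
            super().__init__()
            # Keep the parameters as torch.nn.Parameters so that checkpointing
            # and torch optimizers keep working.
            self.w = torch.nn.Parameter(torch.randn(8, 1))
            self.automatic_optimization = False  # manual optimization mode

        def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
            # Forward *and* backward in Jax:
            loss, grad = jax.value_and_grad(jax_loss_fn)(
                jnp.asarray(self.w.detach().cpu().numpy()),
                jnp.asarray(batch.detach().cpu().numpy()),
            )
            # Hand the gradient back to torch and step the optimizer ourselves.
            self.w.grad = torch.tensor(np.asarray(grad))
            optimizer = self.optimizers()
            optimizer.step()
            optimizer.zero_grad()
            self.log("train/loss", float(loss))

        def configure_optimizers(self):
            return torch.optim.SGD([self.w], lr=1e-2)
    ```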

### Jax Network

{{ inline('project.algorithms.jax_example.CNN') }}

### Jax Algorithm

{{ inline('project.algorithms.jax_example.JaxExample') }}

### Configs

#### JaxExample algorithm config

{{ inline('project/configs/algorithm/jax_example.yaml') }}

## Running the example

```console
$ python project/main.py algorithm=jax_example network=jax_cnn datamodule=cifar10
```
9 changes: 9 additions & 0 deletions docs/features/testing.md
@@ -62,6 +62,9 @@ the "test explorer" tab to your editor. Then, you'll be able to see and debug th

## Unit tests

To run the unit tests, use the following:

```console
pytest -x -v
```

## Regression Tests

@@ -81,6 +84,12 @@ pytest --regen-all

## Integration tests

To run slower integration tests, use the following:

```console
pytest -x -v --slow
```

## Continuous Integration

<!--
7 changes: 3 additions & 4 deletions docs/index.md
@@ -38,17 +38,16 @@ For more context, see [this introduction to the project.](intro.md).
You can use both PyTorch and Jax for your algorithms!
([Lightning](https://lightning.ai/docs/pytorch/stable/) handles the rest.)

-[:octicons-arrow-right-24: Check out the Jax example](examples/jax.md)
+[:octicons-arrow-right-24: Check out the Jax example](features/jax.md)

- :fontawesome-solid-plane-departure:{ .lg .middle } __Ready-to-use examples__

---

-Includes examples for Supervised learning(1) and NLP 🤗(2), with unsupervised learning and RL coming soon.
+Includes examples for Supervised learning(1) and NLP 🤗, with unsupervised learning and RL coming soon.
{ .annotate }

1. The source code for the example is available [here](https://github.com/mila-iqia/ResearchTemplate/blob/master/project/algorithms/example.py)
-2. 👷 Coming soon-ish

[:octicons-arrow-right-24: Check out the examples here](examples/examples.md)

@@ -69,7 +68,7 @@ This project makes use of the following libraries:

- [Hydra](https://hydra.cc/) is used to configure the project. It allows you to define configuration files and override them from the command line.
- [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) is used as the training framework. It provides a high-level interface to organize ML research code.
-- 🔥 Please note: You can also use [Jax](https://jax.readthedocs.io/en/latest/) with this repo, as is shown in the [Jax example](examples/jax.md) 🔥
+- 🔥 Please note: You can also use [Jax](https://jax.readthedocs.io/en/latest/) with this repo, as described in the [Jax example](features/jax.md) 🔥
- [Weights & Biases](https://wandb.ai) is used to log metrics and visualize results.
- [pytest](https://docs.pytest.org/en/stable/) is used for testing.

2 changes: 1 addition & 1 deletion docs/intro.md
@@ -17,7 +17,7 @@ Other good reads:

You are welcome (and encouraged) to use other similar templates which, at the time of writing this, have significantly better documentation. However, there are several advantages to using this particular template:

-- ❗Support for both Jax and Torch with PyTorch-Lightning (See the [Jax example](examples/jax.md))❗
+- ❗Support for both Jax and Torch with PyTorch-Lightning (See the [Jax example](features/jax.md))❗
- Your Hydra configs will have an [Auto-Generated YAML schemas](features/auto_schema.md) 🔥
- A comprehensive suite of automated tests for new algorithms, datasets and networks
- 🤖 [Thoroughly tested on the Mila cluster directly with GitHub CI](features/testing.md#automated-testing-on-slurm-clusters-with-github-ci)
32 changes: 18 additions & 14 deletions docs/macros.py
@@ -15,7 +15,6 @@


 def define_env(env: MacrosPlugin):
-    @env.macro
     def inline(module_or_file: str, indent: int = 0):
         block_type: str | None = None
         # print(f"Inlining reference: {module_or_file}")
@@ -36,24 +35,29 @@ def inline(module_or_file: str, indent: int = 0):
         else:
             block_type = block_type or "python3"
             obj: Any = get_object_from_reference(module_or_file)
-            content = "".join(inspect.getsourcelines(obj)[0])
+            logger.info(f"inlining code for {obj}")
+            content = inspect.getsource(obj)
+            # BUG: Sometimes using {{ inline('some_module.SomeClass.some_method') }} will show the
+            # incorrect source code: it will show the method *above* the one we're looking for.
+            # content = "".join(inspect.getsourcelines(obj)[0])

         content = f"```{block_type}\n" + textwrap.indent(content + "\n```", " " * indent)
         return content

+    env.macro(inline, name="inline")


 def get_object_from_reference(reference: str):
     """taken from https://github.com/mkdocs/mkdocs/issues/692"""
-    split = reference.split(".")
-    right = []
-    module = None
-    while split:
+    parts = reference.split(".")
+    for i in range(1, len(parts)):
+        module_name = ".".join(parts[:i])
+        obj_path = parts[i:]
         try:
-            module = importlib.import_module(".".join(split))
-            break
-        except ModuleNotFoundError:
-            right.append(split.pop())
-    if module:
-        for entry in reversed(right):
-            module = getattr(module, entry)
-    return module
+            obj = importlib.import_module(module_name)
+            for part in obj_path:
+                obj = getattr(obj, part)
+            return obj
+        except (ModuleNotFoundError, AttributeError):
+            continue
+    raise RuntimeError(f"Unable to import the {reference=}")
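
A quick illustration of the rewritten lookup, using standard-library references (the `docs.macros` import path is an assumption):

```python
from docs.macros import get_object_from_reference

# Tries "textwrap" as a module, then walks the remaining attribute path:
indent_fn = get_object_from_reference("textwrap.indent")

# Longer attribute chains work the same way:
bind_method = get_object_from_reference("inspect.Signature.bind")
```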
9 changes: 6 additions & 3 deletions mkdocs.yml
@@ -1,5 +1,5 @@
-site_name: Research Project Template (wip)
-site_description: A project template and directory structure for Python data science projects. (Work-in-Progress)
+site_name: Research Project Template
+site_description: Template for an ML research project.
site_url: https://mila-iqia.github.io/ResearchTemplate/
repo_url: https://www.github.com/mila-iqia/ResearchTemplate
# edit_uri: edit/master/docs
@@ -40,6 +40,9 @@ markdown_extensions:
       anchor_linenums: true
       line_spans: __span
       pygments_lang_class: true
+  - pymdownx.inlinehilite
+  - pymdownx.snippets
+  - pymdownx.superfences
   - pymdownx.magiclink
   - attr_list
   - md_in_html
@@ -57,7 +60,7 @@ markdown_extensions:
 plugins:
   - search
   - awesome-pages
-  - macros: #https://mkdocs-macros-plugin.readthedocs.io/en/latest/#declaration-of-the-macros-plugin
+  - macros: # https://mkdocs-macros-plugin.readthedocs.io/en/latest/#declaration-of-the-macros-plugin
       module_name: docs/macros
   - autorefs
   - gen-files:
3 changes: 0 additions & 3 deletions project/main.py
@@ -22,9 +22,6 @@
 from project.utils.hydra_utils import resolve_dictconfig
 from project.utils.utils import print_config

-if os.environ.get("CUDA_VISIBLE_DEVICES", "").startswith("MIG-"):
-    # NOTE: Perhaps unsetting it would also work, but this works atm.
-    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 logger = get_logger(__name__)

 PROJECT_NAME = Path(__file__).parent.name
2 changes: 1 addition & 1 deletion project/main_test.py
@@ -5,6 +5,7 @@

 import hydra_zen
 import pytest
+import torch
 from omegaconf import DictConfig

 from project.algorithms.example import ExampleAlgorithm
@@ -31,7 +32,6 @@ def test_jax_can_use_the_GPU():

 def test_torch_can_use_the_GPU():
     """Test that torch can use the GPU if we have one."""
-    import torch

     assert torch.cuda.is_available() == bool(shutil.which("nvidia-smi"))

