Add some docs and improve docstrings
ibro45 committed Dec 4, 2024
1 parent 7cc0b9d commit 1e8b3d0
Showing 23 changed files with 463 additions and 226 deletions.
15 changes: 14 additions & 1 deletion docs/advanced/callbacks.md
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
🚧 Under construction 🚧
# Callbacks

Callbacks in Lighter allow you to customize and extend the training process. You can define custom actions to be executed at various stages of the training loop.

## Freezer Callback
The `LighterFreezer` callback allows you to freeze certain layers of the model during training. This can be useful for transfer learning or fine-tuning.
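
As a rough illustration of the selection rule that the `LighterFreezer` docstring describes (freeze by exact name or prefix, with optional exclusions), the sketch below decides which parameter names would be frozen. The `select_frozen` helper and the example parameter names are hypothetical, and letting exclusions take precedence is an assumption of this sketch rather than the callback's confirmed behavior.

```python
from typing import Iterable, List, Sequence


def select_frozen(
    param_names: Iterable[str],
    names: Sequence[str] = (),
    name_starts_with: Sequence[str] = (),
    except_names: Sequence[str] = (),
    except_name_starts_with: Sequence[str] = (),
) -> List[str]:
    """Return the parameter names that would be frozen (hypothetical helper)."""
    frozen = []
    for name in param_names:
        # A parameter is a candidate if it matches an exact name or a prefix.
        selected = name in names or any(name.startswith(p) for p in name_starts_with)
        # Assumption: exclusions take precedence over selections in this sketch.
        excluded = name in except_names or any(name.startswith(p) for p in except_name_starts_with)
        if selected and not excluded:
            frozen.append(name)
    return frozen


# Example: freeze the whole backbone except its last block.
params = ["backbone.layer1.weight", "backbone.layer4.weight", "head.fc.weight"]
frozen = select_frozen(
    params,
    name_starts_with=["backbone."],
    except_name_starts_with=["backbone.layer4"],
)
print(frozen)
```

In a real model the names would typically come from `model.named_parameters()`, and freezing amounts to setting `requires_grad = False` on the selected parameters.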

## Writer Callbacks
Lighter provides writer callbacks to save predictions in different formats. The `LighterFileWriter` and `LighterTableWriter` are examples of such callbacks.

- **LighterFileWriter**: Writes predictions to files, supporting formats like images, videos, and ITK images.
- **LighterTableWriter**: Saves predictions in a table format, such as CSV.

For more details on how to implement and use callbacks, refer to the [PyTorch Lightning Callback documentation](https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html).
12 changes: 11 additions & 1 deletion docs/advanced/inferer.md
@@ -1 +1,11 @@
🚧 Under construction 🚧
# Inferer

The inferer in Lighter is used for making predictions on data. It is typically used in validation, testing, and prediction workflows.

## Using Inferers
Inferers must be classes with a `__call__` method that accepts two arguments: the input to infer over and the model itself. They are used to handle complex inference scenarios, such as patch-based or sliding window inference.
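
A minimal sketch of the interface described above: a class whose `__call__` takes the input and the model, in that order. `ChunkedInferer`, its `chunk_size` parameter, and the list-based toy model are hypothetical and only illustrate the call signature; a real inferer would operate on tensors.

```python
from typing import Callable, List, Sequence


class ChunkedInferer:
    """Hypothetical inferer that runs the model on fixed-size chunks of the input."""

    def __init__(self, chunk_size: int) -> None:
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        self.chunk_size = chunk_size

    def __call__(
        self,
        input: Sequence[float],
        model: Callable[[Sequence[float]], List[float]],
    ) -> List[float]:
        outputs: List[float] = []
        # Walk over the input in non-overlapping windows and stitch the results.
        for start in range(0, len(input), self.chunk_size):
            outputs.extend(model(input[start : start + self.chunk_size]))
        return outputs


def toy_model(chunk: Sequence[float]) -> List[float]:
    # Stand-in for a neural network: doubles every element.
    return [2 * value for value in chunk]


inferer = ChunkedInferer(chunk_size=2)
prediction = inferer([1, 2, 3, 4, 5], toy_model)
print(prediction)
```

Keeping the windowing logic in the inferer rather than in the model means the same model can be reused with whole-input or windowed inference.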

## MONAI Inferers
Lighter integrates with MONAI inferers, which cover most common inference scenarios. You can use MONAI's sliding window or patch-based inferers directly in your Lighter configuration.

For more information on MONAI inferers, visit the [MONAI documentation](https://docs.monai.io/en/stable/inferers.html).
26 changes: 25 additions & 1 deletion docs/advanced/postprocessing.md
@@ -1 +1,25 @@
🚧 Under construction 🚧
# Postprocessing

Postprocessing in Lighter allows you to apply custom transformations to data at various stages of the workflow. This can include modifying inputs, targets, predictions, or entire batches.

## Defining Postprocessing Functions
Postprocessing functions can be defined in the configuration file under the `postprocessing` key. They can be applied to:
- **Batch**: Modify the entire batch before it is passed to the model.
- **Criterion**: Modify inputs, targets, or predictions before loss calculation.
- **Metrics**: Modify inputs, targets, or predictions before metric calculation.
- **Logging**: Modify inputs, targets, or predictions before logging.

## Example
```yaml
postprocessing:
batch:
train: '$lambda x: {"input": x[0], "target": x[1]}'
criterion:
input: '$lambda x: x / 255.0'
metrics:
pred: '$lambda x: x.argmax(dim=1)'
logging:
target: '$lambda x: x.cpu().numpy()'
```
For more information on how to use postprocessing in Lighter, refer to the [Lighter documentation](./config.md).
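
To make the `batch` entry concrete, here is a plain-Python sketch of what the lambda in the example does to a tuple-style batch. The sample values are made up, and how Lighter evaluates the `$`-prefixed string is not shown here.

```python
# The same expression as in the YAML above, written as ordinary Python.
to_dict = lambda x: {"input": x[0], "target": x[1]}

# A made-up (image, label) batch in the tuple format.
batch = ([0.1, 0.2, 0.3], 1)
converted = to_dict(batch)
print(converted)
```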
10 changes: 8 additions & 2 deletions docs/basics/config.md
@@ -2,7 +2,11 @@

Lighter is a configuration-centric framework: the config sets up the entire machine learning workflow, from model architecture selection, loss function, optimizer, and dataset preparation to running the training/evaluation/inference process.

Our configuration system is heavily based on MONAI bundle parser but with a standardized structure. For every configuration, we expect several items to be mandatorily defined.
Our configuration system is heavily based on the MONAI bundle parser but with a standardized structure. For every configuration, we expect several items to be mandatorily defined.

The configuration is divided into two main components:
- **Trainer**: Handles the training process, including epochs, devices, etc.
- **LighterSystem**: Encapsulates the model, optimizer, datasets, and other components.

Let us take a simple example config to dig deeper into the configuration system of Lighter. You can go through the config and click on the + for more information about specific concepts.

@@ -49,7 +53,9 @@ system:
1. `_target_` is a special reserved keyword that initializes a python object from the provided text. In this case, a `Trainer` object from the `pytorch_lightning` library is initialized
2. `max_epochs` is an argument of the `Trainer` class which is passed through this format. Any argument for the class can be passed similarly.
3. `$@` is a combination of `$` which evaluates a python expression and `@` which references a python object. In this case we first reference the model with `@model` which is the `torchvision.models.resnet18` defined earlier and then access its parameters using `$@model.parameters()`
4. YAML allows passing a list in the format below where each `_target_` specifices a transform that is added to the list of transforms in `Compose`. The `torchvision.datasets.CIFAR10` accepts these with a `transform` argument and applies them to each item.
4. YAML allows passing a list in the format below where each `_target_` specifies a transform that is added to the list of transforms in `Compose`. The `torchvision.datasets.CIFAR10` accepts these with a `transform` argument and applies them to each item.

5. Datasets are defined for different modes: train, val, test, and predict. Each dataset can have its own transforms and configurations.
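
The idea behind `_target_` in item 1 can be sketched in a few lines of Python: import the dotted path and call the result with the remaining keys as keyword arguments. This only illustrates the concept and is not the MONAI bundle parser's implementation; the `instantiate` helper is hypothetical, and the standard-library class `fractions.Fraction` stands in for `Trainer`.

```python
import importlib
from typing import Any, Dict


def instantiate(config: Dict[str, Any]) -> Any:
    """Build an object from a dict with a `_target_` key (hypothetical helper)."""
    module_path, _, attr_name = config["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    # Every key other than `_target_` is passed as a keyword argument.
    kwargs = {key: value for key, value in config.items() if key != "_target_"}
    return target(**kwargs)


obj = instantiate({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})
print(obj)
```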

## Configuration Concepts
As seen in the [Quickstart](./quickstart.md), Lighter has two main components:
4 changes: 2 additions & 2 deletions docs/basics/projects.md
@@ -7,7 +7,7 @@ With Lighter, you can be as hands-on as you wish when using it in your project.
- [x] Train on your data + Add a custom model architecture + Add a complex loss function
- [x] Customization per your imagination!

Lets start by looking at each of these one by one. At the end of this, you will hopefully have a better idea of how best you can leverage lighter
Let's start by looking at each of these one by one. At the end of this, you will hopefully have a better idea of how best you can leverage Lighter.

### Training on your own dataset

@@ -90,7 +90,7 @@ class MyXRayDataset(Dataset):
```

!!! note
Lighter works with the default torchvision format of (image, target) and also with `dict` with `input` and `target` keys. The input/target key or tuple can contain complex input/target organization, e.g. multiple images for input and multiple labels for target
Lighter works with the default torchvision format of (image, target) and also with `dict` with `input` and `target` keys. The input/target key or tuple can contain complex input/target organization, e.g., multiple images for input and multiple labels for target.


Now that you have built your dataset, all you need to do is add it to the lighter config! But wait, how will Lighter know where your code is? All lighter configs contain a `project` key that takes the full path to where your python code is located. Once you set this up, call `project.my_xray_dataset.` and Lighter will pick up the dataset.
11 changes: 6 additions & 5 deletions docs/basics/quickstart.md
@@ -16,10 +16,10 @@ pip install project-lighter --pre
## Building a config
Key to the Lighter ecosystem is a YAML file that serves as the central point of control for your experiments. It allows you to define, manage, and modify all aspects of your experiment without diving deep into the code.

A Lighter config contains two main components:
A Lighter config contains two main components:

- Trainer
- LighterSystem
- **Trainer**: Manages the training loop and related settings.
- **LighterSystem**: Defines the model, datasets, optimizer, and other components.

### Trainer
Trainer contains all the information about running a training/evaluation/inference process and is a crucial component of training automation in PyTorch Lightning. Please refer to [PyTorch Lightning's Trainer documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html) for more information.
@@ -76,10 +76,11 @@ system:
# takes a dictionary as input.
batch:
train: '$lambda x: {"input": x[0], "target": x[1]}'

```
For more information about each of the LighterSystem components and how to override them, see [here](./config.md)
5. Postprocessing functions can be defined for different stages like batch, criterion, metrics, and logging. These functions allow you to modify data at various points in the workflow.
For more information about each of the LighterSystem components and how to override them, see [here](./config.md).
## Running this experiment with Lighter
We just combine the Trainer and LighterSystem into a single YAML and run the command in the terminal as shown,
18 changes: 15 additions & 3 deletions docs/basics/workflows.md
@@ -1,15 +1,27 @@
# Running Workflows

Once the configuration is establised, Lighter can run different deep learning workflows. The following workflows are supported:
Once the configuration is established, Lighter can run different deep learning workflows. The following workflows are supported:

1. fit
2. validate
3. test
4. predict

These workflows are inherited from the Pytorch lightning trainer and can be found in the [PL docs](https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods)
These workflows are inherited from the PyTorch Lightning trainer and can be found in the [PL docs](https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods).

We also show below how you can run these workflows and what are some "required" definitions in the config while running these workflows.
Below, we show how you can run these workflows and what are some "required" definitions in the config while running these workflows.

## Fit workflow
The fit workflow is used for training the model. It requires the `trainer` and `system` configurations to be defined in the YAML file.

## Validate workflow
The validate workflow is used to evaluate the model on a validation dataset. Ensure that the `val` dataset is defined in the `system` configuration.

## Test workflow
The test workflow is used to evaluate the model on a test dataset. Ensure that the `test` dataset is defined in the `system` configuration.

## Predict workflow
The predict workflow is used to make predictions on new data. Ensure that the `predict` dataset is defined in the `system` configuration.

## Fit workflow

4 changes: 4 additions & 0 deletions lighter/__init__.py
@@ -1,3 +1,7 @@
"""
Lighter is a framework for streamlining deep learning experiments with configuration files.
"""

__version__ = "0.0.3a9"

from .utils.logging import _setup_logging
44 changes: 27 additions & 17 deletions lighter/callbacks/freezer.py
@@ -1,3 +1,7 @@
"""
This module provides the LighterFreezer callback, which allows freezing model parameters during training.
"""

from typing import Any, List, Optional, Union

from loguru import logger
@@ -10,18 +14,16 @@

class LighterFreezer(Callback):
"""
Callback to freeze the parameters/layers of a model. Can be run indefinitely or until a specified step or epoch.
`names` and`name_starts_with` can be used to specify which parameters to freeze.
If both are specified, the parameters that match any of the two will be frozen.
Callback to freeze model parameters during training. Parameters can be frozen by exact name or prefix.
Freezing can be applied indefinitely or until a specified step/epoch.
Args:
names (str, List[str], optional): Names of the parameters to be frozen. Defaults to None.
name_starts_with (str, List[str], optional): Prefixes of the parameter names to be frozen. Defaults to None.
except_names (str, List[str], optional): Names of the parameters to be excluded from freezing. Defaults to None.
except_name_starts_with (str, List[str], optional): Prefixes of the parameter names to be excluded from freezing.
Defaults to None.
until_step (int, optional): Maximum step to freeze parameters until. Defaults to None.
until_epoch (int, optional): Maximum epoch to freeze parameters until. Defaults to None.
names (Optional[Union[str, List[str]]]): Full names of parameters to freeze.
name_starts_with (Optional[Union[str, List[str]]]): Prefixes of parameter names to freeze.
except_names (Optional[Union[str, List[str]]]): Names of parameters to exclude from freezing.
except_name_starts_with (Optional[Union[str, List[str]]]): Prefixes of parameter names to exclude from freezing.
until_step (int): Maximum step to freeze parameters until.
until_epoch (int): Maximum epoch to freeze parameters until.
Raises:
ValueError: If neither `names` nor `name_starts_with` are specified.
@@ -56,6 +58,15 @@ def __init__(
self._frozen_state = False

def on_train_batch_start(self, trainer: Trainer, pl_module: LighterSystem, batch: Any, batch_idx: int) -> None:
"""
Called at the start of each training batch to potentially freeze parameters.
Args:
trainer (Trainer): The trainer instance.
pl_module (LighterSystem): The LighterSystem instance.
batch (Any): The current batch.
batch_idx (int): The index of the batch.
"""
self._on_batch_start(trainer, pl_module)

def on_validation_batch_start(
@@ -75,11 +86,11 @@ def on_predict_batch_start(

def _on_batch_start(self, trainer: Trainer, pl_module: LighterSystem) -> None:
"""
Freezes the parameters of the model at the start of each training batch.
Freezes or unfreezes model parameters based on the current step or epoch.
Args:
trainer (Trainer): Trainer instance.
pl_module (LighterSystem): LighterSystem instance.
trainer (Trainer): The trainer instance.
pl_module (LighterSystem): The LighterSystem instance.
"""
current_step = trainer.global_step
current_epoch = trainer.current_epoch
@@ -101,12 +112,11 @@ def _on_batch_start(self, trainer: Trainer, pl_module: LighterSystem) -> None:

def _set_model_requires_grad(self, model: Union[Module, LighterSystem], requires_grad: bool) -> None:
"""
Sets the requires_grad attribute of the model's parameters.
Sets the requires_grad attribute for model parameters, effectively freezing or unfreezing them.
Args:
model (Module): PyTorch model whose parameters need to be frozen.
requires_grad (bool): Whether to freeze the parameters or not.
model (Union[Module, LighterSystem]): The model whose parameters to modify.
requires_grad (bool): Whether to allow gradients (unfreeze) or not (freeze).
"""
# If the model is a `LighterSystem`, get the underlying PyTorch model.
if isinstance(model, LighterSystem):
23 changes: 15 additions & 8 deletions lighter/callbacks/utils.py
@@ -1,30 +1,37 @@
"""
This module provides utility functions for callbacks, including mode conversion and image preprocessing.
"""

import torch
import torchvision
from torch import Tensor


def get_lighter_mode(lightning_stage: str) -> str:
"""Converts the name of a PyTorch Lightnig stage to the name of its corresponding Lighter mode.
"""
Converts a PyTorch Lightning stage name to the corresponding Lighter mode name.
Args:
lightning_stage (str): Stage in which the Trainer is. Can be accessed using `trainer.state.stage`.
lightning_stage (str): The Lightning stage in which the Trainer is operating.
Returns:
Lighter mode name.
str: The corresponding Lighter mode name.
"""
lightning_to_lighter = {"train": "train", "validate": "val", "test": "test"}
return lightning_to_lighter[lightning_stage]


def preprocess_image(image: Tensor) -> Tensor:
"""Preprocess the image before logging it. If it is a batch of multiple images,
it will create a grid image of them. In case of 3D, a single image is displayed
with slices stacked vertically, while a batch of 3D images as a grid where each
column is a different 3D image.
"""
Preprocess image for logging. For multiple 2D images, creates a grid.
For 3D images, stacks slices vertically. For multiple 3D images, creates a grid
with each column showing a different 3D image stacked vertically.
Args:
image (Tensor): A 2D or 3D image tensor.
Returns:
The image ready for logging.
Tensor: The preprocessed image ready for logging.
"""
# If 3D (BCDHW), concat the images vertically and horizontally.
if image.ndim == 5:
31 changes: 21 additions & 10 deletions lighter/callbacks/writer/base.py
@@ -1,3 +1,7 @@
"""
This module provides the base class for defining custom writers in Lighter, allowing predictions to be saved in various formats.
"""

from typing import Any, Callable, Dict, Union

import gc
@@ -21,8 +25,8 @@ class LighterBaseWriter(ABC, Callback):
2) `self.write()` method to specify the saving strategy for a prediction.
Args:
path (Union[str, Path]): Path for saving. It can be a directory or a specific file.
writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
path (Union[str, Path]): Path for saving predictions.
writer (Union[str, Callable]): Writer function or name of a registered writer.
"""

def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
@@ -63,11 +67,12 @@ def write(self, tensor: Tensor, id: int) -> None:

def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
"""
Callback function to set up necessary prerequisites: prediction count and prediction file or directory.
When executing in a distributed environment, it ensures that:
1. Each distributed node initializes a prediction count based on its rank.
2. All distributed nodes write predictions to the same path.
3. The path is accessible to all nodes, i.e., all nodes share the same storage.
Sets up the writer, ensuring the path is ready for saving predictions.
Args:
trainer (Trainer): The trainer instance.
pl_module (LighterSystem): The LighterSystem instance.
stage (str): The current stage of training.
"""
if stage != "predict":
return
@@ -99,9 +104,15 @@ def on_predict_batch_end(
self, trainer: Trainer, pl_module: LighterSystem, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
"""
Callback method executed at the end of each prediction batch/step.
If the IDs are not provided, it generates global unique IDs based on the prediction count.
Finally, it writes the predictions using the specified writer.
Callback method executed at the end of each prediction batch to write predictions with unique IDs.
Args:
trainer (Trainer): The trainer instance.
pl_module (LighterSystem): The LighterSystem instance.
outputs (Any): The outputs from the prediction step.
batch (Any): The current batch.
batch_idx (int): The index of the batch.
dataloader_idx (int): The index of the dataloader.
"""
# If the IDs are not provided, generate global unique IDs based on the prediction count. DDP supported.
if outputs["id"] is None: