Skip to content

Commit

Permalink
doc + explicit mandatory input_shape for UNETR++
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Guibert committed Jul 4, 2024
1 parent 4ee1fec commit 2743812
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ pip install mfai

## Instanciate a model

Our [unit tests](tests/test_models.py#L39) provides an example of how to use the models in a PyTorch training loop. Our models are instanciated with 2 mandatory positional arguments: **in_channels** and **out_channels** respectively the number of input and output channels of the model. The other parameter is an instance of the model's settings class.
Our [unit tests](tests/test_models.py#L39) provides an example of how to use the models in a PyTorch training loop. Our models are instanciated with 2 mandatory positional arguments: **in_channels** and **out_channels** respectively the number of input and output channels/features of the model. A third **input_shape** parameter is either mandatory (**UNETR++** or **HalfUNet wtih abs pos embedding**) or optional for the other models. It describes the shape of the input tensor along its spatial dimensions.

The last parameter is an instance of the model's settings class and is a keyword argument with a default value set to the default settings.



Here is an example of how to instanciate the UNet model with a 3 channels input (like an RGB image) and 1 channel output with its default settings:

Expand All @@ -94,6 +98,13 @@ from mfai.torch.models import HalfUNet
halfunet = HalfUNet(in_channels=2, out_channels=2, settings=HalfUNet.settings_kls(num_filters=128, use_ghost=True))
```

Finally, to instanciate a model with the mandatory **input_shape** parameter, here is an example with the UNETR++ model working on 2d spatial data (256x256) with 3 channels input and 1 channel output:

```python
from mfai.torch.models import UNETRPP
unetrpp = UNETRPP(in_channels=3, out_channels=1, input_shape=(256, 256))
```

**_FEATURE:_** Each model has its settings class available under the **settings_kls** attribute.

You can use the **load_from_settings_file** function to instanciate a model with its settings from a json file:
Expand Down
11 changes: 9 additions & 2 deletions mfai/torch/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Optional, Tuple
from torch import nn
from .deeplabv3 import DeepLabV3, DeepLabV3Plus
from .half_unet import HalfUNet
Expand All @@ -20,7 +21,11 @@


def load_from_settings_file(
model_name: str, in_channels: int, out_channels: int, settings_path: Path
model_name: str,
in_channels: int,
out_channels: int,
settings_path: Path,
input_shape: Optional[Tuple[int, ...]] = None,
) -> nn.Module:
"""
Instanciate a model from a settings file with Schema validation.
Expand All @@ -41,4 +46,6 @@ def load_from_settings_file(
model_settings = model_kls.settings_kls.schema().loads(f.read())

# instanciate the model
return model_kls(in_channels, out_channels, settings=model_settings)
return model_kls(
in_channels, out_channels, input_shape=input_shape, settings=model_settings
)
4 changes: 1 addition & 3 deletions mfai/torch/models/unetrpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,6 @@ def __init__(
raise ValueError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
print("Hidden size is ", hidden_size)
print("Num heads is ", num_heads)
raise ValueError("hidden_size should be divisible by num_heads.")

self.norm = nn.LayerNorm(hidden_size)
Expand Down Expand Up @@ -524,7 +522,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
input_shape: Union[None, Tuple[int, int]] = None,
input_shape: Tuple[int, ...],
settings: UNETRPPSettings = UNETRPPSettings(),
) -> None:
"""
Expand Down

0 comments on commit 2743812

Please sign in to comment.