Loss functions module #640

Merged · 59 commits · Aug 26, 2024

Changes from 57 commits
Commits
a9a23b9
Bump sqlparse from 0.4.4 to 0.5.0 (#558)
dependabot[bot] Apr 22, 2024
36eb46f
Bump tqdm from 4.66.1 to 4.66.3 (#569)
dependabot[bot] May 4, 2024
fa7f0f1
Bump werkzeug from 3.0.1 to 3.0.3 (#570)
dependabot[bot] May 7, 2024
a05fcd5
Bump jinja2 from 3.1.3 to 3.1.4 (#571)
dependabot[bot] May 7, 2024
b2fc3e6
Bump mlflow from 2.10.1 to 2.12.1 (#575)
dependabot[bot] May 17, 2024
495d5b9
Bump gunicorn from 21.2.0 to 22.0.0 (#576)
dependabot[bot] May 17, 2024
bdd102a
Bump requests from 2.31.0 to 2.32.0 (#578)
dependabot[bot] May 21, 2024
beccd4c
[CI] Run tests through GitHub Actions (#573)
NicolasGensollen May 22, 2024
2861e9d
[CI] Skip tests when PR is in draft mode (#592)
NicolasGensollen May 23, 2024
f5de251
[CI] Test train workflow on GPU machine (#590)
NicolasGensollen May 23, 2024
69b3538
[CI] Port remaining GPU tests to GitHub Actions (#593)
NicolasGensollen May 23, 2024
c9d9252
[CI] Remove GPU pipeline from Jenkinsfile (#594)
NicolasGensollen May 24, 2024
753f04e
[CI] Port remaining non GPU tests to GitHub Actions (#581)
NicolasGensollen May 24, 2024
c424d77
[CI] Remove jenkins related things (#595)
NicolasGensollen May 24, 2024
4281c73
add simulate-gpu option
thibaultdvx May 28, 2024
52d7561
Add flags to run CI tests locally (#596)
thibaultdvx May 30, 2024
39d22fd
[CI] Remove duplicated verbose flag in test pipelines (#598)
NicolasGensollen May 30, 2024
571662c
[DOC] Update the Python version used for creating the conda environme…
NicolasGensollen May 30, 2024
567467e
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx May 30, 2024
d54d59c
Flag for local tests (#608)
thibaultdvx May 31, 2024
78f2928
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx May 31, 2024
f20e7fb
Update quality_check.py (#609)
HuguesRoy Jun 4, 2024
f6f382a
Fix issue in compare_folders (#610)
thibaultdvx Jun 4, 2024
cd3a538
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx Jun 4, 2024
f7eb225
Merge branch 'dev' into refactoring
thibaultdvx Jun 4, 2024
523563d
revert change on poetry
thibaultdvx Jun 4, 2024
4971fa7
correction of wrong conflict choice in rebasing
thibaultdvx Jun 4, 2024
c60d53c
Merge remote-tracking branch 'upstream/refactoring' into refactoring
thibaultdvx Jun 4, 2024
52f9492
[INFRA] Update the Makefile `check.lock` target (#603)
NicolasGensollen Jun 4, 2024
b0d3490
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx Jun 5, 2024
02c4e30
Merge branch 'refactoring' of https://github.com/aramis-lab/clinicadl…
thibaultdvx Jun 6, 2024
95dc7f4
Merge remote-tracking branch 'upstream/refactoring' into refactoring
thibaultdvx Jun 7, 2024
996cdd5
[CI] Run unit tests and linter on refactoring branch (#618)
NicolasGensollen Jun 7, 2024
6166462
Merge remote-tracking branch 'upstream/dev' into dev
thibaultdvx Jun 7, 2024
e4ce5dc
Merge branch 'dev' into refactoring
thibaultdvx Jun 7, 2024
3b7763e
Merge remote-tracking branch 'upstream/refactoring' into refactoring
thibaultdvx Jun 7, 2024
752bc2b
remove outdated tests
thibaultdvx Jun 7, 2024
405f4d8
new model module
thibaultdvx Jun 12, 2024
15516b1
Merge remote-tracking branch 'upstream/refactoring' into refactoring
thibaultdvx Jun 12, 2024
e027552
add torch interface in loss
thibaultdvx Jun 12, 2024
17ea4c8
modify nn
thibaultdvx Jun 13, 2024
c646edd
add nn module
thibaultdvx Jul 22, 2024
d6b5593
unittests
thibaultdvx Jul 22, 2024
dd38e06
Merge remote-tracking branch 'upstream/refactoring' into refactoring
thibaultdvx Jul 22, 2024
0732f15
Merge branch 'refactoring' into clinicadl_models
thibaultdvx Jul 22, 2024
4b2cb1f
remove losses
thibaultdvx Jul 22, 2024
391eafd
remove losses
thibaultdvx Jul 22, 2024
456431d
remove losses
thibaultdvx Jul 22, 2024
03ec7b9
solve problem in unittests (two files with the same name)
thibaultdvx Jul 22, 2024
7734528
add loss module
thibaultdvx Jul 22, 2024
ed466b4
add unittest
thibaultdvx Jul 22, 2024
322cf7e
Merge remote-tracking branch 'upstream/refactoring' into refactoring
thibaultdvx Jul 22, 2024
06afe02
Merge branch 'refactoring' into clinicadl_models
thibaultdvx Jul 22, 2024
4545ac7
modify docstring
thibaultdvx Jul 22, 2024
8e7f921
modify docstring
thibaultdvx Jul 22, 2024
032f6ba
update config with package default
thibaultdvx Jul 23, 2024
3989db5
add last params in config
thibaultdvx Jul 23, 2024
b8d0204
change dict to config and add losses
thibaultdvx Aug 2, 2024
5490ab5
add unit test
thibaultdvx Aug 2, 2024
2 changes: 2 additions & 0 deletions clinicadl/losses/__init__.py
@@ -0,0 +1,2 @@
from .config import ClassificationLoss, ImplementedLoss, LossConfig
from .factory import get_loss_function
85 changes: 85 additions & 0 deletions clinicadl/losses/config.py
@@ -0,0 +1,85 @@
from enum import Enum
from typing import List, Optional, Union

from pydantic import (
    BaseModel,
    ConfigDict,
    NonNegativeFloat,
    NonNegativeInt,
    PositiveFloat,
    field_validator,
)

from clinicadl.utils.enum import BaseEnum
from clinicadl.utils.factories import DefaultFromLibrary


class ClassificationLoss(str, BaseEnum):
    """Losses that can only be used for classification."""

    CROSS_ENTROPY = "CrossEntropyLoss"
    MULTI_MARGIN = "MultiMarginLoss"


class ImplementedLoss(str, BaseEnum):
    """Losses implemented in ClinicaDL."""

    CROSS_ENTROPY = "CrossEntropyLoss"
    MULTI_MARGIN = "MultiMarginLoss"
    L1 = "L1Loss"
    MSE = "MSELoss"
    HUBER = "HuberLoss"
    SMOOTH_L1 = "SmoothL1Loss"

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not implemented. Implemented losses are: "
            + ", ".join([repr(m.value) for m in cls])
        )


class Reduction(str, Enum):
    """Reduction methods supported in ClinicaDL."""

    NONE = "none"
    MEAN = "mean"
    SUM = "sum"


class Order(int, Enum):
    """Supported orders of the L-norm for MultiMarginLoss."""

    ONE = 1
    TWO = 2


class LossConfig(BaseModel):
    """Config class to configure the loss function."""

    loss: ImplementedLoss = ImplementedLoss.MSE
    reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
    delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    beta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES
    margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES
    weight: Union[
        Optional[List[NonNegativeFloat]], DefaultFromLibrary
    ] = DefaultFromLibrary.YES
    ignore_index: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
    label_smoothing: Union[
        NonNegativeFloat, DefaultFromLibrary
    ] = DefaultFromLibrary.YES
    # pydantic config
    model_config = ConfigDict(
        validate_assignment=True, use_enum_values=True, validate_default=True
    )

    @field_validator("label_smoothing")
    @classmethod
    def validator_label_smoothing(cls, v):
        if isinstance(v, float):
            assert (
                0 <= v <= 1
            ), f"label_smoothing must be between 0 and 1 but it has been set to {v}."
        return v
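
To make the DefaultFromLibrary sentinel concrete, here is a minimal usage sketch (not part of the diff; the parameter values are illustrative only):

from clinicadl.losses import LossConfig

# Fields left unset keep the DefaultFromLibrary sentinel, so PyTorch's own
# defaults will be used when the loss is eventually instantiated.
config = LossConfig(loss="CrossEntropyLoss", weight=[0.2, 0.8])

# validate_assignment=True means assignments are re-validated:
config.label_smoothing = 0.1  # fine, within [0, 1]
# config.label_smoothing = 1.5  # would raise a pydantic ValidationError
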
50 changes: 50 additions & 0 deletions clinicadl/losses/factory.py
@@ -0,0 +1,50 @@
import inspect
from copy import deepcopy
from typing import Any, Dict, Tuple

import torch

from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults

from .config import LossConfig


def get_loss_function(config: LossConfig) -> Tuple[torch.nn.Module, Dict[str, Any]]:
    """
    Factory function to get a loss function from PyTorch.

    Parameters
    ----------
    config : LossConfig
        The config class with the parameters of the loss function.

    Returns
    -------
    nn.Module
        The loss function.
    Dict[str, Any]
        The config dict with only the parameters relevant to the selected
        loss function.
    """
    loss_class = getattr(torch.nn, config.loss)
    expected_args, config_dict = get_args_and_defaults(loss_class.__init__)
    for arg, value in config.model_dump().items():
        if arg in expected_args and value != DefaultFromLibrary.YES:
            config_dict[arg] = value

    config_dict_ = deepcopy(config_dict)
    if "weight" in config_dict and config_dict["weight"] is not None:
        config_dict_["weight"] = torch.Tensor(config_dict_["weight"])
    loss = loss_class(**config_dict_)

Comment from a Collaborator:
As the loss is already initialized here, do we need to return a dictionary? I thought we wanted to get rid of all the dictionaries.

Reply from the Author:
Hi @camillebrianceau, thanks for the review. You're right: I changed that to put the arguments of the loss function in a config class rather than a dict.

Even if the loss is already initialized, we still need the config class to store the configuration in the MAPS.

The difference between the input and the output config class is that the values set to "DefaultFromLibrary" have been replaced with their effective values. This is not essential, but I think it is better for reproducibility (for example, in the unlikely event that PyTorch changes some default values).

Tell me what you think!

config_dict["loss"] = config.loss

return loss, config_dict


# TODO : what about them?
Comment from a Collaborator:
??

Reply from the Author:
Are you talking about the other losses? I've just added more losses from PyTorch in my last commit.

# "KLDivLoss",
# "BCEWithLogitsLoss",
# "VAEGaussianLoss",
# "VAEBernoulliLoss",
# "VAEContinuousBernoulliLoss",
34 changes: 34 additions & 0 deletions clinicadl/utils/factories.py
@@ -0,0 +1,34 @@
import inspect
from enum import Enum
from typing import Any, Callable, Dict, List, Tuple


class DefaultFromLibrary(str, Enum):
    YES = "DefaultFromLibrary"


def get_args_and_defaults(func: Callable) -> Tuple[List[str], Dict[str, Any]]:
    """
    Gets the arguments of a function, as well as the default
    values possibly attached to them.

    Parameters
    ----------
    func : Callable
        The function.

    Returns
    -------
    List[str]
        The names of the arguments.
    Dict[str, Any]
        The default values in a dict.
    """
    signature = inspect.signature(func)
    args = list(signature.parameters.keys())
    defaults = {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    return args, defaults
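
As an illustration (a sketch, not part of the diff), applying get_args_and_defaults to torch.nn.MultiMarginLoss.__init__ should give the following, for the signature assumed by the unit test further below:

import torch

from clinicadl.utils.factories import get_args_and_defaults

args, defaults = get_args_and_defaults(torch.nn.MultiMarginLoss.__init__)
# args     -> ["self", "p", "margin", "weight", "size_average", "reduce", "reduction"]
# defaults -> {"p": 1, "margin": 1.0, "weight": None,
#              "size_average": None, "reduce": None, "reduction": "mean"}
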
Empty file.
17 changes: 17 additions & 0 deletions tests/unittests/losses/test_config.py
@@ -0,0 +1,17 @@
import pytest
from pydantic import ValidationError


def test_LossConfig():
    from clinicadl.losses import LossConfig

    LossConfig(reduction="none", p=2, weight=[0.1, 0.1, 0.8])
    LossConfig(
        loss="SmoothL1Loss", margin=10.0, delta=2.0, beta=3.0, label_smoothing=0.5
    )
    with pytest.raises(ValueError):
        LossConfig(loss="abc")
    with pytest.raises(ValueError):
        LossConfig(weight=[0.1, -0.1, 0.8])
    with pytest.raises(ValidationError):
        LossConfig(label_smoothing=1.1)
27 changes: 27 additions & 0 deletions tests/unittests/losses/test_factory.py
@@ -0,0 +1,27 @@
from torch import Tensor
from torch.nn import MultiMarginLoss

from clinicadl.losses import ImplementedLoss, LossConfig, get_loss_function


def test_get_loss_function():
    for loss in [e.value for e in ImplementedLoss]:
        config = LossConfig(loss=loss)
        get_loss_function(config)

    config = LossConfig(loss="MultiMarginLoss", reduction="sum", weight=[1, 2, 3], p=2)
    loss, config_dict = get_loss_function(config)
    assert isinstance(loss, MultiMarginLoss)
    assert loss.reduction == "sum"
    assert loss.p == 2
    assert loss.margin == 1.0
    assert (loss.weight == Tensor([1, 2, 3])).all()
    assert config_dict == {
        "loss": "MultiMarginLoss",
        "reduction": "sum",
        "p": 2,
        "margin": 1.0,
        "weight": [1, 2, 3],
        "size_average": None,
        "reduce": None,
    }
10 changes: 10 additions & 0 deletions tests/unittests/utils/test_factories_utils.py
@@ -0,0 +1,10 @@
from clinicadl.utils.factories import get_args_and_defaults


def test_get_default_args():
    def f(a, b="b", c=0, d=None):
        return None

    args, defaults = get_args_and_defaults(f)
    assert args == ["a", "b", "c", "d"]
    assert defaults == {"b": "b", "c": 0, "d": None}