Loss functions module #640

Merged: thibaultdvx merged 59 commits into aramis-lab:refactoring from thibaultdvx:clinicadl_models on Aug 26, 2024.
Commits (59 in total; changes shown from 57 commits):
a9a23b9 Bump sqlparse from 0.4.4 to 0.5.0 (#558) (dependabot[bot])
36eb46f Bump tqdm from 4.66.1 to 4.66.3 (#569) (dependabot[bot])
fa7f0f1 Bump werkzeug from 3.0.1 to 3.0.3 (#570) (dependabot[bot])
a05fcd5 Bump jinja2 from 3.1.3 to 3.1.4 (#571) (dependabot[bot])
b2fc3e6 Bump mlflow from 2.10.1 to 2.12.1 (#575) (dependabot[bot])
495d5b9 Bump gunicorn from 21.2.0 to 22.0.0 (#576) (dependabot[bot])
bdd102a Bump requests from 2.31.0 to 2.32.0 (#578) (dependabot[bot])
beccd4c [CI] Run tests through GitHub Actions (#573) (NicolasGensollen)
2861e9d [CI] Skip tests when PR is in draft mode (#592) (NicolasGensollen)
f5de251 [CI] Test train workflow on GPU machine (#590) (NicolasGensollen)
69b3538 [CI] Port remaining GPU tests to GitHub Actions (#593) (NicolasGensollen)
c9d9252 [CI] Remove GPU pipeline from Jenkinsfile (#594) (NicolasGensollen)
753f04e [CI] Port remaining non GPU tests to GitHub Actions (#581) (NicolasGensollen)
c424d77 [CI] Remove jenkins related things (#595) (NicolasGensollen)
4281c73 add simulate-gpu option (thibaultdvx)
52d7561 Add flags to run CI tests locally (#596) (thibaultdvx)
39d22fd [CI] Remove duplicated verbose flag in test pipelines (#598) (NicolasGensollen)
571662c [DOC] Update the Python version used for creating the conda environme… (NicolasGensollen)
567467e Merge remote-tracking branch 'upstream/dev' into dev (thibaultdvx)
d54d59c Flag for local tests (#608) (thibaultdvx)
78f2928 Merge remote-tracking branch 'upstream/dev' into dev (thibaultdvx)
f20e7fb Update quality_check.py (#609) (HuguesRoy)
f6f382a Fix issue in compare_folders (#610) (thibaultdvx)
cd3a538 Merge remote-tracking branch 'upstream/dev' into dev (thibaultdvx)
f7eb225 Merge branch 'dev' into refactoring (thibaultdvx)
523563d revert change on poetry (thibaultdvx)
4971fa7 correction of wrong conflict choice in rebasing (thibaultdvx)
c60d53c Merge remote-tracking branch 'upstream/refactoring' into refactoring (thibaultdvx)
52f9492 [INFRA] Update the Makefile `check.lock` target (#603) (NicolasGensollen)
b0d3490 Merge remote-tracking branch 'upstream/dev' into dev (thibaultdvx)
02c4e30 Merge branch 'refactoring' of https://github.com/aramis-lab/clinicadl… (thibaultdvx)
95dc7f4 Merge remote-tracking branch 'upstream/refactoring' into refactoring (thibaultdvx)
996cdd5 [CI] Run unit tests and linter on refactoring branch (#618) (NicolasGensollen)
6166462 Merge remote-tracking branch 'upstream/dev' into dev (thibaultdvx)
e4ce5dc Merge branch 'dev' into refactoring (thibaultdvx)
3b7763e Merge remote-tracking branch 'upstream/refactoring' into refactoring (thibaultdvx)
752bc2b remove outdated tests (thibaultdvx)
405f4d8 new model module (thibaultdvx)
15516b1 Merge remote-tracking branch 'upstream/refactoring' into refactoring (thibaultdvx)
e027552 add otrch interface in loss (thibaultdvx)
17ea4c8 modify nn (thibaultdvx)
c646edd add nn module (thibaultdvx)
d6b5593 unittests (thibaultdvx)
dd38e06 Merge remote-tracking branch 'upstream/refactoring' into refactoring (thibaultdvx)
0732f15 Merge branch 'refactoring' into clinicadl_models (thibaultdvx)
4b2cb1f remove losses (thibaultdvx)
391eafd remove losses (thibaultdvx)
456431d remove losses (thibaultdvx)
03ec7b9 solve problem in unittests (two files with the same name) (thibaultdvx)
7734528 add loss module (thibaultdvx)
ed466b4 add unittest (thibaultdvx)
322cf7e Merge remote-tracking branch 'upstream/refactoring' into refactoring (thibaultdvx)
06afe02 Merge branch 'refactoring' into clinicadl_models (thibaultdvx)
4545ac7 modify docstring (thibaultdvx)
8e7f921 modify docstring (thibaultdvx)
032f6ba update config with package default (thibaultdvx)
3989db5 add last params in config (thibaultdvx)
b8d0204 change dict to config and add losses (thibaultdvx)
5490ab5 add unit test (thibaultdvx)
New file (+2 lines):

```python
from .config import ClassificationLoss, ImplementedLoss, LossConfig
from .factory import get_loss_function
```
New file (+85 lines):

```python
from enum import Enum
from typing import List, Optional, Union

from pydantic import (
    BaseModel,
    ConfigDict,
    NonNegativeFloat,
    NonNegativeInt,
    PositiveFloat,
    field_validator,
)

from clinicadl.utils.enum import BaseEnum
from clinicadl.utils.factories import DefaultFromLibrary


class ClassificationLoss(str, BaseEnum):
    """Losses that can be used only for classification."""

    CROSS_ENTROPY = "CrossEntropyLoss"
    MULTI_MARGIN = "MultiMarginLoss"


class ImplementedLoss(str, BaseEnum):
    """Implemented losses in ClinicaDL."""

    CROSS_ENTROPY = "CrossEntropyLoss"
    MULTI_MARGIN = "MultiMarginLoss"
    L1 = "L1Loss"
    MSE = "MSELoss"
    HUBER = "HuberLoss"
    SMOOTH_L1 = "SmoothL1Loss"

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not implemented. Implemented losses are: "
            + ", ".join([repr(m.value) for m in cls])
        )


class Reduction(str, Enum):
    """Supported reduction methods in ClinicaDL."""

    NONE = "none"
    MEAN = "mean"
    SUM = "sum"


class Order(int, Enum):
    """Supported orders of the L-norm for MultiMarginLoss."""

    ONE = 1
    TWO = 2


class LossConfig(BaseModel):
    """Config class to configure the loss function."""

    loss: ImplementedLoss = ImplementedLoss.MSE
    reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
    delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    beta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
    p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES
    margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES
    weight: Union[
        Optional[List[NonNegativeFloat]], DefaultFromLibrary
    ] = DefaultFromLibrary.YES
    ignore_index: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES
    label_smoothing: Union[
        NonNegativeFloat, DefaultFromLibrary
    ] = DefaultFromLibrary.YES
    # pydantic config
    model_config = ConfigDict(
        validate_assignment=True, use_enum_values=True, validate_default=True
    )

    @field_validator("label_smoothing")
    @classmethod
    def validator_label_smoothing(cls, v):
        if isinstance(v, float):
            assert (
                0 <= v <= 1
            ), f"label_smoothing must be between 0 and 1 but it has been set to {v}."
        return v
```
New file (+50 lines):

```python
from copy import deepcopy
from typing import Any, Dict, Tuple

import torch

from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults

from .config import LossConfig


def get_loss_function(config: LossConfig) -> Tuple[torch.nn.Module, Dict[str, Any]]:
    """
    Factory function to get a loss function from PyTorch.

    Parameters
    ----------
    config : LossConfig
        The config class with the parameters of the loss function.

    Returns
    -------
    nn.Module
        The loss function.
    Dict[str, Any]
        The config dict with only the parameters relevant to the selected
        loss function.
    """
    loss_class = getattr(torch.nn, config.loss)
    expected_args, config_dict = get_args_and_defaults(loss_class.__init__)
    for arg, value in config.model_dump().items():
        if arg in expected_args and value != DefaultFromLibrary.YES:
            config_dict[arg] = value

    config_dict_ = deepcopy(config_dict)
    if "weight" in config_dict and config_dict["weight"] is not None:
        config_dict_["weight"] = torch.Tensor(config_dict_["weight"])
    loss = loss_class(**config_dict_)

    config_dict["loss"] = config.loss

    return loss, config_dict


# TODO: what about them?
# "KLDivLoss",
# "BCEWithLogitsLoss",
# "VAEGaussianLoss",
# "VAEBernoulliLoss",
# "VAEContinuousBernoulliLoss",
```

Review comment (on the trailing `# TODO` block): ??

Reply (thibaultdvx): Are you talking about the other losses? I've just added more losses from PyTorch in my last commit.
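The factory flow in this file (read the library's own defaults from the constructor signature, overwrite them with user-set values, then instantiate) can be sketched without torch. Everything below is a stand-in: `FakeMarginLoss` is a hypothetical class mimicking the constructor signature of `torch.nn.MultiMarginLoss`, and the string sentinel plays the role of `DefaultFromLibrary.YES`.

```python
import inspect
from copy import deepcopy


SENTINEL = "DefaultFromLibrary"  # stands in for DefaultFromLibrary.YES


class FakeMarginLoss:
    """Stand-in with defaults similar to torch.nn.MultiMarginLoss."""

    def __init__(self, p=1, margin=1.0, weight=None, reduction="mean"):
        self.p, self.margin, self.weight, self.reduction = p, margin, weight, reduction


def get_loss_function(loss_class, user_config):
    # Collect the constructor's arguments and their library defaults.
    signature = inspect.signature(loss_class.__init__)
    expected_args = list(signature.parameters)
    config_dict = {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    # Overwrite only the values the user actually set.
    for arg, value in user_config.items():
        if arg in expected_args and value != SENTINEL:
            config_dict[arg] = value
    loss = loss_class(**deepcopy(config_dict))
    return loss, config_dict


loss, resolved = get_loss_function(
    FakeMarginLoss, {"reduction": "sum", "p": 2, "margin": SENTINEL}
)
print(resolved)  # {'p': 2, 'margin': 1.0, 'weight': None, 'reduction': 'sum'}
```

The returned dict records the effective value of every argument, including the ones left to the library's defaults, which is what the real factory stores alongside the instantiated loss.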
New file (+34 lines):

```python
import inspect
from enum import Enum
from typing import Any, Callable, Dict, List, Tuple


class DefaultFromLibrary(str, Enum):
    YES = "DefaultFromLibrary"


def get_args_and_defaults(func: Callable) -> Tuple[List[str], Dict[str, Any]]:
    """
    Gets the arguments of a function, as well as the default
    values possibly attached to them.

    Parameters
    ----------
    func : Callable
        The function.

    Returns
    -------
    List[str]
        The names of the arguments.
    Dict[str, Any]
        The default values in a dict.
    """
    signature = inspect.signature(func)
    args = list(signature.parameters.keys())
    defaults = {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    return args, defaults
```
Empty file.
New test file (+17 lines):

```python
import pytest
from pydantic import ValidationError


def test_LossConfig():
    from clinicadl.losses import LossConfig

    LossConfig(reduction="none", p=2, weight=[0.1, 0.1, 0.8])
    LossConfig(
        loss="SmoothL1Loss", margin=10.0, delta=2.0, beta=3.0, label_smoothing=0.5
    )
    with pytest.raises(ValueError):
        LossConfig(loss="abc")
    with pytest.raises(ValueError):
        LossConfig(weight=[0.1, -0.1, 0.8])
    with pytest.raises(ValidationError):
        LossConfig(label_smoothing=1.1)
```
New test file (+27 lines):

```python
from torch import Tensor
from torch.nn import MultiMarginLoss

from clinicadl.losses import ImplementedLoss, LossConfig, get_loss_function


def test_get_loss_function():
    for loss in [e.value for e in ImplementedLoss]:
        config = LossConfig(loss=loss)
        get_loss_function(config)

    config = LossConfig(loss="MultiMarginLoss", reduction="sum", weight=[1, 2, 3], p=2)
    loss, config_dict = get_loss_function(config)
    assert isinstance(loss, MultiMarginLoss)
    assert loss.reduction == "sum"
    assert loss.p == 2
    assert loss.margin == 1.0
    assert (loss.weight == Tensor([1, 2, 3])).all()
    assert config_dict == {
        "loss": "MultiMarginLoss",
        "reduction": "sum",
        "p": 2,
        "margin": 1.0,
        "weight": [1, 2, 3],
        "size_average": None,
        "reduce": None,
    }
```
New test file (+10 lines):

```python
from clinicadl.utils.factories import get_args_and_defaults


def test_get_default_args():
    def f(a, b="b", c=0, d=None):
        return None

    args, defaults = get_args_and_defaults(f)
    assert args == ["a", "b", "c", "d"]
    assert defaults == {"b": "b", "c": 0, "d": None}
```
Review comment (camillebrianceau): As the loss is already initialized here, do we need to return a dictionary? I thought we wanted to get rid of all the dictionaries.

Reply (thibaultdvx): Hi @camillebrianceau, thanks for the review. You're right; I changed that to put the arguments of the loss function in a config class rather than a dict.
Even though the loss is initialized, we still need the config class to store the configuration in the MAPS.
The difference between the input and the output config class is that the values set to "DefaultFromLibrary" have been replaced with their effective values. It is not essential, but it is better for reproducibility, I think (for example, in the unlikely event that PyTorch changes some default values).
Tell me what you think!
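The reproducibility argument can be made concrete with a stdlib-only sketch (all names hypothetical): once the effective defaults are recorded at training time, rebuilding the loss from the stored config no longer depends on the library's current defaults.

```python
import inspect
import json


def resolve_defaults(func, user_kwargs):
    """Record the library's current defaults, overridden by user values."""
    params = inspect.signature(func).parameters
    resolved = {
        k: p.default
        for k, p in params.items()
        if p.default is not inspect.Parameter.empty
    }
    resolved.update(user_kwargs)
    return resolved


# Hypothetical "library v1" loss constructor, as seen at training time:
def huber_v1(reduction="mean", delta=1.0):
    return ("huber", reduction, delta)


# Capture the effective configuration, e.g. to be written into the MAPS:
stored = json.dumps(resolve_defaults(huber_v1, {"reduction": "sum"}))


# Later, "library v2" silently changes the default delta:
def huber_v2(reduction="mean", delta=2.0):
    return ("huber", reduction, delta)


# Rebuilding from the stored config reproduces the original delta=1.0;
# relying on the library default again would silently give 2.0 instead.
loss = huber_v2(**json.loads(stored))
print(loss)  # ('huber', 'sum', 1.0)
```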