Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Review of loss module #655

Merged
merged 13 commits into from
Sep 23, 2024
3 changes: 2 additions & 1 deletion clinicadl/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .config import ClassificationLoss, ImplementedLoss, LossConfig
from .config import create_loss_config
from .enum import ClassificationLoss, ImplementedLoss
from .factory import get_loss_function
278 changes: 201 additions & 77 deletions clinicadl/losses/config.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,87 @@
from enum import Enum
from typing import List, Optional, Union
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Type, Union

from pydantic import (
BaseModel,
ConfigDict,
NonNegativeFloat,
PositiveFloat,
computed_field,
field_validator,
model_validator,
)

from clinicadl.utils.enum import BaseEnum
from clinicadl.utils.factories import DefaultFromLibrary

from .enum import ImplementedLoss, Order, Reduction

class ClassificationLoss(str, BaseEnum):
"""Losses that can be used only for classification."""
__all__ = [
"LossConfig",
"NLLLossConfig",
"CrossEntropyLossConfig",
"BCELossConfig",
"BCEWithLogitsLossConfig",
"MultiMarginLossConfig",
"KLDivLossConfig",
"HuberLossConfig",
"SmoothL1LossConfig",
"L1LossConfig",
"MSELossConfig",
"create_loss_config",
]

CROSS_ENTROPY = "CrossEntropyLoss" # for multi-class classification, inputs are unormalized logits and targets are int (same dimension without the class channel)
MULTI_MARGIN = "MultiMarginLoss" # no particular restriction on the input, targets are int (same dimension without th class channel)
BCE = "BCELoss" # for binary classification, targets and inputs should be probabilities and have same shape
BCE_LOGITS = "BCEWithLogitsLoss" # for binary classification, targets should be probabilities and inputs logits, and have the same shape. More stable numerically

class LossConfig(BaseModel, ABC):
"""Base config class for the loss function."""

class ImplementedLoss(str, Enum):
"""Implemented losses in ClinicaDL."""

CROSS_ENTROPY = "CrossEntropyLoss"
MULTI_MARGIN = "MultiMarginLoss"
BCE = "BCELoss"
BCE_LOGITS = "BCEWithLogitsLoss"
L1 = "L1Loss"
MSE = "MSELoss"
HUBER = "HuberLoss"
SMOOTH_L1 = "SmoothL1Loss"
KLDIV = "KLDivLoss" # if log_target=False, target must be positive

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not implemented. Implemented losses are: "
+ ", ".join([repr(m.value) for m in cls])
)
reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
weight: Union[
Optional[List[NonNegativeFloat]], DefaultFromLibrary
] = DefaultFromLibrary.YES
# pydantic config
model_config = ConfigDict(
validate_assignment=True, use_enum_values=True, validate_default=True
)

@computed_field
@property
@abstractmethod
def loss(self) -> ImplementedLoss:
"""ImplementedLoss.e name of the loss."""

class Reduction(str, Enum):
"""Supported reduction method in ClinicaDL."""

MEAN = "mean"
SUM = "sum"
class NLLLossConfig(LossConfig):
"""Config class for Negative Log Likelihood loss."""

ignore_index: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES

class Order(int, Enum):
"""Supported order of L-norm for MultiMarginLoss."""
@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.NLL

ONE = 1
TWO = 2
@field_validator("ignore_index")
@classmethod
def validator_ignore_index(cls, v):
if isinstance(v, int):
assert (
v == -100 or 0 <= v
), "ignore_index must be a positive int (or -100 when disabled)."
return v


class LossConfig(BaseModel):
"""Config class to configure the loss function."""
class CrossEntropyLossConfig(NLLLossConfig):
"""Config class for Cross Entropy loss."""

loss: ImplementedLoss = ImplementedLoss.MSE
reduction: Union[Reduction, DefaultFromLibrary] = DefaultFromLibrary.YES
delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
beta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES
p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES
margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES
weight: Union[
Optional[List[NonNegativeFloat]], DefaultFromLibrary
] = DefaultFromLibrary.YES # a weight for each class
ignore_index: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES
label_smoothing: Union[
NonNegativeFloat, DefaultFromLibrary
] = DefaultFromLibrary.YES
log_target: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES
pos_weight: Union[
Optional[List[NonNegativeFloat]], DefaultFromLibrary
] = DefaultFromLibrary.YES # a positive weight for each class
# pydantic config
model_config = ConfigDict(
validate_assignment=True, use_enum_values=True, validate_default=True
)

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.CROSS_ENTROPY

@field_validator("label_smoothing")
@classmethod
Expand All @@ -92,26 +92,150 @@ def validator_label_smoothing(cls, v):
), f"label_smoothing must be between 0 and 1 but it has been set to {v}."
return v

@field_validator("ignore_index")

class BCELossConfig(LossConfig):
"""Config class for Binary Cross Entropy loss."""

weight: Optional[List[NonNegativeFloat]] = None

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.BCE

@field_validator("weight")
@classmethod
def validator_ignore_index(cls, v):
if isinstance(v, int):
assert (
v == -100 or 0 <= v
), "ignore_index must be a positive int (or -100 when disabled)."
def validator_weight(cls, v):
if v is not None:
raise ValueError(
"Cannot use weight with BCEWithLogitsLoss. If you want more flexibility, please use API mode."
)
return v

@model_validator(mode="after")
def model_validator(self):
if (
self.loss == ImplementedLoss.BCE_LOGITS
and self.weight is not None
and self.weight != DefaultFromLibrary.YES
):
raise ValueError("Cannot use weight with BCEWithLogitsLoss.")
elif (
self.loss == ImplementedLoss.BCE
and self.weight is not None
and self.weight != DefaultFromLibrary.YES
):
raise ValueError("Cannot use weight with BCELoss.")

class BCEWithLogitsLossConfig(BCELossConfig):
"""Config class for Binary Cross Entropy With Logits loss."""

pos_weight: Union[Optional[List[Any]], DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.BCE_LOGITS

@field_validator("pos_weight")
@classmethod
def validator_pos_weight(cls, v):
if isinstance(v, list):
check = cls._recursive_float_check(v)
if not check:
raise ValueError(
f"elements in pos_weight must be non-negative float, got: {v}"
)
return v

@classmethod
def _recursive_float_check(cls, item):
if isinstance(item, list):
return all(cls._recursive_float_check(i) for i in item)
else:
return (isinstance(item, float) or isinstance(item, int)) and item >= 0


class MultiMarginLossConfig(LossConfig):
"""Config class for Multi Margin loss."""

p: Union[Order, DefaultFromLibrary] = DefaultFromLibrary.YES
margin: Union[float, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.MULTI_MARGIN


class KLDivLossConfig(LossConfig):
"""Config class for Kullback-Leibler Divergence loss."""

log_target: Union[bool, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.KLDIV


class HuberLossConfig(LossConfig):
"""Config class for Huber loss."""

delta: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.HUBER


class SmoothL1LossConfig(LossConfig):
"""Config class for Smooth L1 loss."""

beta: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.SMOOTH_L1


class L1LossConfig(LossConfig):
"""Config class for L1 loss."""

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.L1


class MSELossConfig(LossConfig):
"""Config class for Mean Squared Error loss."""

@computed_field
@property
def loss(self) -> ImplementedLoss:
"""The name of the loss."""
return ImplementedLoss.MSE


def create_loss_config(
loss: Union[str, ImplementedLoss],
) -> Type[LossConfig]:
"""
A factory function to create a config class suited for the loss.

Parameters
----------
loss : Union[str, ImplementedLoss]
The name of the loss.

Returns
-------
Type[LossConfig]
The config class.

Raises
------
ValueError
If `loss` is not supported.
"""
loss = ImplementedLoss(loss)
config_name = "".join([loss, "Config"])
config = globals()[config_name]

return config
50 changes: 50 additions & 0 deletions clinicadl/losses/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from enum import Enum

from clinicadl.utils.enum import BaseEnum


class ClassificationLoss(str, BaseEnum):
"""Losses that can be used only for classification."""

CROSS_ENTROPY = "CrossEntropyLoss" # for multi-class classification, inputs are unormalized logits and targets are int (same dimension without the class channel)
NLL = "NLLLoss" # for multi-class classification, inputs are log-probabilities and targets are int (same dimension without the class channel)
MULTI_MARGIN = "MultiMarginLoss" # no particular restriction on the input, targets are int (same dimension without th class channel)
BCE = "BCELoss" # for binary classification, targets and inputs should be probabilities and have same shape
BCE_LOGITS = "BCEWithLogitsLoss" # for binary classification, targets should be probabilities and inputs logits, and have the same shape. More stable numerically


class ImplementedLoss(str, Enum):
"""Implemented losses in ClinicaDL."""

CROSS_ENTROPY = "CrossEntropyLoss"
NLL = "NLLLoss"
MULTI_MARGIN = "MultiMarginLoss"
BCE = "BCELoss"
BCE_LOGITS = "BCEWithLogitsLoss"

L1 = "L1Loss"
MSE = "MSELoss"
HUBER = "HuberLoss"
SMOOTH_L1 = "SmoothL1Loss"
KLDIV = "KLDivLoss" # if log_target=False, target must be positive

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not implemented. Implemented losses are: "
+ ", ".join([repr(m.value) for m in cls])
)


class Reduction(str, Enum):
"""Supported reduction method in ClinicaDL."""

MEAN = "mean"
SUM = "sum"


class Order(int, Enum):
"""Supported order of L-norm for MultiMarginLoss."""

ONE = 1
TWO = 2
2 changes: 1 addition & 1 deletion clinicadl/losses/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ def get_loss_function(config: LossConfig) -> Tuple[torch.nn.Module, LossConfig]:
config_dict_["pos_weight"] = torch.Tensor(config_dict_["pos_weight"])
loss = loss_class(**config_dict_)

updated_config = LossConfig(loss=config.loss, **config_dict)
updated_config = config.model_copy(update=config_dict)

return loss, updated_config
Loading
Loading