diff --git a/clinicadl/optim/__init__.py b/clinicadl/optim/__init__.py index 6715835a5..185b1c418 100644 --- a/clinicadl/optim/__init__.py +++ b/clinicadl/optim/__init__.py @@ -1 +1,4 @@ from .config import OptimizationConfig +from .early_stopping import EarlyStopping +from .lr_scheduler import create_lr_scheduler_config, get_lr_scheduler +from .optimizer import create_optimizer_config, get_optimizer diff --git a/clinicadl/optim/lr_scheduler/__init__.py b/clinicadl/optim/lr_scheduler/__init__.py index 9c5fabd2c..c26899e69 100644 --- a/clinicadl/optim/lr_scheduler/__init__.py +++ b/clinicadl/optim/lr_scheduler/__init__.py @@ -1,2 +1,3 @@ -from .config import ImplementedLRScheduler, LRSchedulerConfig +from .config import create_lr_scheduler_config +from .enum import ImplementedLRScheduler from .factory import get_lr_scheduler diff --git a/clinicadl/optim/lr_scheduler/config.py b/clinicadl/optim/lr_scheduler/config.py index 93fb3d9e1..073f92db7 100644 --- a/clinicadl/optim/lr_scheduler/config.py +++ b/clinicadl/optim/lr_scheduler/config.py @@ -1,7 +1,4 @@ -from __future__ import annotations - -from enum import Enum -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Type, Union from pydantic import ( BaseModel, @@ -10,71 +7,43 @@ NonNegativeInt, PositiveFloat, PositiveInt, + computed_field, field_validator, - model_validator, ) from clinicadl.utils.factories import DefaultFromLibrary +from .enum import ImplementedLRScheduler, Mode, ThresholdMode -class ImplementedLRScheduler(str, Enum): - """Implemented LR schedulers in ClinicaDL.""" - - CONSTANT = "ConstantLR" - LINEAR = "LinearLR" - STEP = "StepLR" - MULTI_STEP = "MultiStepLR" - PLATEAU = "ReduceLROnPlateau" - - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. Implemented LR schedulers are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class Mode(str, Enum): - """Supported mode for ReduceLROnPlateau.""" - - MIN = "min" - MAX = "max" - - -class ThresholdMode(str, Enum): - """Supported threshold mode for ReduceLROnPlateau.""" - - ABS = "abs" - REL = "rel" +__all__ = [ + "LRSchedulerConfig", + "ConstantLRConfig", + "LinearLRConfig", + "StepLRConfig", + "MultiStepLRConfig", + "ReduceLROnPlateauConfig", + "create_lr_scheduler_config", +] class LRSchedulerConfig(BaseModel): - """Config class to configure the optimizer.""" + """Base config class for the LR scheduler.""" - scheduler: Optional[ImplementedLRScheduler] = None - step_size: Optional[PositiveInt] = None gamma: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - milestones: Optional[List[PositiveInt]] = None factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES total_iters: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES last_epoch: Union[int, DefaultFromLibrary] = DefaultFromLibrary.YES - - mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES - patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES - threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES - cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES - min_lr: Union[ - NonNegativeFloat, Dict[str, PositiveFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES # pydantic config model_config = ConfigDict( validate_assignment=True, use_enum_values=True, validate_default=True ) + @computed_field + @property + def scheduler(self) -> Optional[ImplementedLRScheduler]: + """The name of the scheduler.""" + return None + @field_validator("last_epoch") @classmethod def validator_last_epoch(cls, v): @@ -84,31 +53,108 @@ def validator_last_epoch(cls, v): ), f"last_epoch must be -1 or a non-negative int but it has been set to {v}." return v - @field_validator("milestones") + +class ConstantLRConfig(LRSchedulerConfig): + """Config class for ConstantLR scheduler.""" + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.CONSTANT + + +class LinearLRConfig(LRSchedulerConfig): + """Config class for LinearLR scheduler.""" + + start_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + end_factor: Union[PositiveFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.LINEAR + + +class StepLRConfig(LRSchedulerConfig): + """Config class for StepLR scheduler.""" + + step_size: PositiveInt + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.STEP + + +class MultiStepLRConfig(LRSchedulerConfig): + """Config class for MultiStepLR scheduler.""" + + milestones: List[PositiveInt] + + @computed_field + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.MULTI_STEP + + @field_validator("milestones", mode="after") @classmethod def validator_milestones(cls, v): import numpy as np - if v is not None: - assert len(np.unique(v)) == len( - v - ), "Epoch(s) in milestones should be unique." - return sorted(v) - return v + assert len(np.unique(v)) == len(v), "Epoch(s) in milestones should be unique." + return sorted(v) + + +class ReduceLROnPlateauConfig(LRSchedulerConfig): + """Config class for ReduceLROnPlateau scheduler.""" + + mode: Union[Mode, DefaultFromLibrary] = DefaultFromLibrary.YES + patience: Union[PositiveInt, DefaultFromLibrary] = DefaultFromLibrary.YES + threshold: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES + threshold_mode: Union[ThresholdMode, DefaultFromLibrary] = DefaultFromLibrary.YES + cooldown: Union[NonNegativeInt, DefaultFromLibrary] = DefaultFromLibrary.YES + min_lr: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + eps: Union[NonNegativeFloat, DefaultFromLibrary] = DefaultFromLibrary.YES - @model_validator(mode="after") - def check_mandatory_args(self) -> LRSchedulerConfig: - if ( - self.scheduler == ImplementedLRScheduler.MULTI_STEP - and self.milestones is None - ): - raise ValueError( - """If you chose MultiStepLR as LR scheduler, you should pass milestones - (see PyTorch documentation for more details).""" - ) - elif self.scheduler == ImplementedLRScheduler.STEP and self.step_size is None: - raise ValueError( - """If you chose StepLR as LR scheduler, you should pass a step_size - (see PyTorch documentation for more details).""" - ) - return self + @property + def scheduler(self) -> ImplementedLRScheduler: + """The name of the scheduler.""" + return ImplementedLRScheduler.PLATEAU + + +def create_lr_scheduler_config( + scheduler: Optional[Union[str, ImplementedLRScheduler]], +) -> Type[LRSchedulerConfig]: + """ + A factory function to create a config class suited for the LR scheduler. + + Parameters + ---------- + scheduler : Optional[Union[str, ImplementedLRScheduler]] + The name of the LR scheduler. + Can be None if no LR scheduler will be used. + + Returns + ------- + Type[LRSchedulerConfig] + The config class. + + Raises + ------ + ValueError + If `scheduler` is not supported. + """ + if scheduler is None: + return LRSchedulerConfig + + scheduler = ImplementedLRScheduler(scheduler) + config_name = "".join([scheduler, "Config"]) + config = globals()[config_name] + + return config diff --git a/clinicadl/optim/lr_scheduler/enum.py b/clinicadl/optim/lr_scheduler/enum.py new file mode 100644 index 000000000..a70bb1801 --- /dev/null +++ b/clinicadl/optim/lr_scheduler/enum.py @@ -0,0 +1,32 @@ +from enum import Enum + + +class ImplementedLRScheduler(str, Enum): + """Implemented LR schedulers in ClinicaDL.""" + + CONSTANT = "ConstantLR" + LINEAR = "LinearLR" + STEP = "StepLR" + MULTI_STEP = "MultiStepLR" + PLATEAU = "ReduceLROnPlateau" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented LR schedulers are: " + + ", ".join([repr(m.value) for m in cls]) + ) + + +class Mode(str, Enum): + """Supported mode for ReduceLROnPlateau.""" + + MIN = "min" + MAX = "max" + + +class ThresholdMode(str, Enum): + """Supported threshold mode for ReduceLROnPlateau.""" + + ABS = "abs" + REL = "rel" diff --git a/clinicadl/optim/lr_scheduler/factory.py b/clinicadl/optim/lr_scheduler/factory.py index a26948deb..2eeaccd55 100644 --- a/clinicadl/optim/lr_scheduler/factory.py +++ b/clinicadl/optim/lr_scheduler/factory.py @@ -56,6 +56,6 @@ def get_lr_scheduler( config_dict_["min_lr"].append(default_min_lr) scheduler = scheduler_class(optimizer, **config_dict_) - updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict) + updated_config = config.model_copy(update=config_dict) return scheduler, updated_config diff --git a/clinicadl/optim/optimizer/__init__.py b/clinicadl/optim/optimizer/__init__.py index 2c9cce3ba..504c60999 100644 --- a/clinicadl/optim/optimizer/__init__.py +++ b/clinicadl/optim/optimizer/__init__.py @@ -1,2 +1,3 @@ -from .config import ImplementedOptimizer, OptimizerConfig +from .config import create_optimizer_config +from .enum import ImplementedOptimizer from .factory import get_optimizer diff --git a/clinicadl/optim/optimizer/config.py b/clinicadl/optim/optimizer/config.py index 46aa5958c..b0a55f034 100644 --- a/clinicadl/optim/optimizer/config.py +++ b/clinicadl/optim/optimizer/config.py @@ -1,38 +1,32 @@ -from enum import Enum -from typing import Dict, List, Optional, Tuple, Union +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Type, Union from pydantic import ( BaseModel, ConfigDict, NonNegativeFloat, PositiveFloat, + computed_field, field_validator, ) from clinicadl.utils.factories import DefaultFromLibrary +from .enum import ImplementedOptimizer -class ImplementedOptimizer(str, Enum): - """Implemented optimizers in ClinicaDL.""" +__all__ = [ + "OptimizerConfig", + "AdadeltaConfig", + "AdagradConfig", + "AdamConfig", + "RMSpropConfig", + "SGDConfig", + "create_optimizer_config", +] - ADADELTA = "Adadelta" - ADAGRAD = "Adagrad" - ADAM = "Adam" - RMS_PROP = "RMSprop" - SGD = "SGD" - @classmethod - def _missing_(cls, value): - raise ValueError( - f"{value} is not implemented. Implemented optimizers are: " - + ", ".join([repr(m.value) for m in cls]) - ) - - -class OptimizerConfig(BaseModel): - """Config class to configure the optimizer.""" - - optimizer: ImplementedOptimizer = ImplementedOptimizer.ADAM +class OptimizerConfig(BaseModel, ABC): + """Base config class for the optimizer.""" lr: Union[ PositiveFloat, Dict[str, PositiveFloat], DefaultFromLibrary @@ -40,36 +34,9 @@ class OptimizerConfig(BaseModel): weight_decay: Union[ NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary ] = DefaultFromLibrary.YES - betas: Union[ - Tuple[NonNegativeFloat, NonNegativeFloat], - Dict[str, Tuple[NonNegativeFloat, NonNegativeFloat]], - DefaultFromLibrary, - ] = DefaultFromLibrary.YES - alpha: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - momentum: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - rho: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - lr_decay: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES eps: Union[ NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary ] = DefaultFromLibrary.YES - dampening: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - initial_accumulator_value: Union[ - NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary - ] = DefaultFromLibrary.YES - - centered: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES - nesterov: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES - amsgrad: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES foreach: Union[ Optional[bool], Dict[str, Optional[bool]], DefaultFromLibrary ] = DefaultFromLibrary.YES @@ -81,14 +48,19 @@ class OptimizerConfig(BaseModel): bool, Dict[str, bool], DefaultFromLibrary ] = DefaultFromLibrary.YES fused: Union[ - Optional[bool], Dict[str, bool], DefaultFromLibrary + Optional[bool], Dict[str, Optional[bool]], DefaultFromLibrary ] = DefaultFromLibrary.YES # pydantic config model_config = ConfigDict( validate_assignment=True, use_enum_values=True, validate_default=True ) - @field_validator("betas", "rho", "alpha", "dampening") + @computed_field + @property + @abstractmethod + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + @classmethod def validator_proba(cls, v, ctx): name = ctx.field_name @@ -128,3 +100,131 @@ def get_all_groups(self) -> List[str]: groups.update(set(value.keys())) return list(groups) + + +class AdadeltaConfig(OptimizerConfig): + """Config class for Adadelta optimizer.""" + + rho: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADADELTA + + @field_validator("rho") + def validator_rho(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class AdagradConfig(OptimizerConfig): + """Config class for Adagrad optimizer.""" + + lr_decay: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + initial_accumulator_value: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADAGRAD + + +class AdamConfig(OptimizerConfig): + """Config class for Adam optimizer.""" + + betas: Union[ + Tuple[NonNegativeFloat, NonNegativeFloat], + Dict[str, Tuple[NonNegativeFloat, NonNegativeFloat]], + DefaultFromLibrary, + ] = DefaultFromLibrary.YES + amsgrad: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.ADAM + + @field_validator("betas") + def validator_betas(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class RMSpropConfig(OptimizerConfig): + """Config class for RMSprop optimizer.""" + + alpha: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + momentum: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + centered: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.RMS_PROP + + @field_validator("alpha") + def validator_alpha(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +class SGDConfig(OptimizerConfig): + """Config class for SGD optimizer.""" + + momentum: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + dampening: Union[ + NonNegativeFloat, Dict[str, NonNegativeFloat], DefaultFromLibrary + ] = DefaultFromLibrary.YES + nesterov: Union[bool, Dict[str, bool], DefaultFromLibrary] = DefaultFromLibrary.YES + + @computed_field + @property + def optimizer(self) -> ImplementedOptimizer: + """The name of the optimizer.""" + return ImplementedOptimizer.SGD + + @field_validator("dampening") + def validator_dampening(cls, v, ctx): + return cls.validator_proba(v, ctx) + + +def create_optimizer_config( + optimizer: Union[str, ImplementedOptimizer], +) -> Type[OptimizerConfig]: + """ + A factory function to create a config class suited for the optimizer. + + Parameters + ---------- + optimizer : Union[str, ImplementedOptimizer] + The name of the optimizer. + + Returns + ------- + Type[OptimizerConfig] + The config class. + + Raises + ------ + ValueError + If `optimizer` is not supported. + """ + optimizer = ImplementedOptimizer(optimizer) + config_name = "".join([optimizer, "Config"]) + config = globals()[config_name] + + return config diff --git a/clinicadl/optim/optimizer/enum.py b/clinicadl/optim/optimizer/enum.py new file mode 100644 index 000000000..a397dbe85 --- /dev/null +++ b/clinicadl/optim/optimizer/enum.py @@ -0,0 +1,18 @@ +from enum import Enum + + +class ImplementedOptimizer(str, Enum): + """Implemented optimizers in ClinicaDL.""" + + ADADELTA = "Adadelta" + ADAGRAD = "Adagrad" + ADAM = "Adam" + RMS_PROP = "RMSprop" + SGD = "SGD" + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not implemented. Implemented optimizers are: " + + ", ".join([repr(m.value) for m in cls]) + ) diff --git a/clinicadl/optim/optimizer/factory.py b/clinicadl/optim/optimizer/factory.py index 3afd6a848..3123781ed 100644 --- a/clinicadl/optim/optimizer/factory.py +++ b/clinicadl/optim/optimizer/factory.py @@ -1,5 +1,6 @@ -from typing import Any, Dict, Tuple +from typing import Any, Dict, Iterable, Iterator, List, Tuple +import torch import torch.nn as nn import torch.optim as optim @@ -31,6 +32,11 @@ def get_optimizer( The updated config class: the arguments set to default will be updated with their effective values (the default values from the library). Useful for reproducibility. + + Raises + ------ + AttributeError + If a parameter group mentioned in the config class cannot be found in the network. """ optimizer_class = getattr(optim, config.optimizer) expected_args, default_args = get_args_and_defaults(optimizer_class.__init__) @@ -58,7 +64,7 @@ def get_optimizer( list_args_groups.append({"params": other_params}) optimizer = optimizer_class(list_args_groups, **args_global) - updated_config = OptimizerConfig(optimizer=config.optimizer, **default_args) + updated_config = config.model_copy(update=default_args) return optimizer, updated_config @@ -119,3 +125,90 @@ def _regroup_args( args_global[arg] = value return args_groups, args_global + + +def _get_params_in_group( + network: nn.Module, group: str +) -> Tuple[Iterator[torch.Tensor], List[str]]: + """ + Gets the parameters of a specific group of a neural network. + + Parameters + ---------- + network : nn.Module + The neural network. + group : str + The name of the group, e.g. a layer or a block. + If it is a sub-block, the hierarchy should be + specified with "." (see examples). + Will work even if the group is reduced to a base layer + (e.g. group = "dense.weight" or "dense.bias"). + + Returns + ------- + Iterator[torch.Tensor] + A generator that contains the parameters of the group. + List[str] + The name of all the parameters in the group. + + Raises + ------ + AttributeError + If `group` cannot be found in the network. + + Examples + -------- + >>> net = nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(1, 1, kernel_size=3)), + ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))), + ] + ) + ) + >>> generator, params_names = _get_params_in_group(network, "final.dense1") + >>> params_names + ["final.dense1.weight", "final.dense1.bias"] + """ + group_hierarchy = group.split(".") + for name in group_hierarchy: + try: + network = getattr(network, name) + except AttributeError as exc: + raise AttributeError( + f"There is no such group as {group} in the network." + ) from exc + + try: + params = network.parameters() + params_names = [ + ".".join([group, name]) for name, _ in network.named_parameters() + ] + except AttributeError: # we already reached params + params = (param for param in [network]) + params_names = [group] + + return params, params_names + + +def _get_params_not_in_group( + network: nn.Module, group: Iterable[str] +) -> Iterator[torch.Tensor]: + """ + Finds the parameters of a neural networks that + are not in a group. + + Parameters + ---------- + network : nn.Module + The neural network. + group : List[str] + The group of parameters. + + Returns + ------- + Iterator[torch.Tensor] + A generator of all the parameters that are not in the input + group. + """ + return (param[1] for param in network.named_parameters() if param[0] not in group) diff --git a/tests/unittests/optim/lr_scheduler/test_config.py b/tests/unittests/optim/lr_scheduler/test_config.py index 270ccc27c..dbf96ccc8 100644 --- a/tests/unittests/optim/lr_scheduler/test_config.py +++ b/tests/unittests/optim/lr_scheduler/test_config.py @@ -1,37 +1,114 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.lr_scheduler import LRSchedulerConfig - - -def test_LRSchedulerConfig(): - config = LRSchedulerConfig( - scheduler="ReduceLROnPlateau", - mode="max", - patience=1, - threshold_mode="rel", - milestones=[4, 3, 2], - min_lr={"param_0": 1e-1, "ELSE": 1e-2}, - ) - assert config.scheduler == "ReduceLROnPlateau" - assert config.mode == "max" - assert config.patience == 1 - assert config.threshold_mode == "rel" - assert config.milestones == [2, 3, 4] - assert config.min_lr == {"param_0": 1e-1, "ELSE": 1e-2} - assert config.threshold == "DefaultFromLibrary" +from clinicadl.optim.lr_scheduler.config import ( + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + create_lr_scheduler_config, +) +BAD_INPUTS = { + "milestones": [3, 2, 4], + "gamma": 0, + "last_epoch": -2, + "step_size": 0, + "factor": 0, + "total_iters": 0, + "start_factor": 0, + "end_factor": 0, + "mode": "abc", + "patience": 0, + "threshold": -1, + "threshold_mode": "abc", + "cooldown": -1, + "eps": -0.1, + "min_lr": -0.1, +} + +GOOD_INPUTS = { + "milestones": [1, 4, 5], + "gamma": 0.1, + "last_epoch": -1, + "step_size": 1, + "factor": 0.1, + "total_iters": 1, + "start_factor": 0.1, + "end_factor": 0.2, + "mode": "min", + "patience": 1, + "threshold": 0, + "threshold_mode": "abs", + "cooldown": 0, + "eps": 0, + "min_lr": 0, +} + + +@pytest.mark.parametrize( + "config", + [ + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + ], +) +def test_validation_fail(config): + fields = config.model_fields + inputs = {key: value for key, value in BAD_INPUTS.items() if key in fields} with pytest.raises(ValidationError): - LRSchedulerConfig(last_epoch=-2) - with pytest.raises(ValueError): - LRSchedulerConfig(scheduler="abc") - with pytest.raises(ValueError): - LRSchedulerConfig(mode="abc") - with pytest.raises(ValueError): - LRSchedulerConfig(threshold_mode="abc") - with pytest.raises(ValidationError): - LRSchedulerConfig(milestones=[10, 10]) - with pytest.raises(ValidationError): - LRSchedulerConfig(scheduler="MultiStepLR") + config(**inputs) + + # test dict inputs for min_lr + if "min_lr" in inputs: + inputs["min_lr"] = {"group_1": inputs["min_lr"]} + with pytest.raises(ValidationError): + config(**inputs) + + +def test_validation_fail_special(): with pytest.raises(ValidationError): - LRSchedulerConfig(scheduler="StepLR") + MultiStepLRConfig(milestones=[0, 1]) + + +@pytest.mark.parametrize( + "config", + [ + ConstantLRConfig, + LinearLRConfig, + MultiStepLRConfig, + ReduceLROnPlateauConfig, + StepLRConfig, + ], +) +def test_validation_pass(config): + fields = config.model_fields + inputs = {key: value for key, value in GOOD_INPUTS.items() if key in fields} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + # test dict inputs + if "min_lr" in inputs: + inputs["min_lr"] = {"group_1": inputs["min_lr"]} + c = config(**inputs) + assert getattr(c, "min_lr") == inputs["min_lr"] + + +@pytest.mark.parametrize( + "name,expected_class", + [ + ("ConstantLR", ConstantLRConfig), + ("LinearLR", LinearLRConfig), + ("MultiStepLR", MultiStepLRConfig), + ("ReduceLROnPlateau", ReduceLROnPlateauConfig), + ("StepLR", StepLRConfig), + ], +) +def test_create_optimizer_config(name, expected_class): + config = create_lr_scheduler_config(name) + assert config == expected_class diff --git a/tests/unittests/optim/lr_scheduler/test_factory.py b/tests/unittests/optim/lr_scheduler/test_factory.py index 76df845cd..cffb3d138 100644 --- a/tests/unittests/optim/lr_scheduler/test_factory.py +++ b/tests/unittests/optim/lr_scheduler/test_factory.py @@ -6,7 +6,7 @@ from clinicadl.optim.lr_scheduler import ( ImplementedLRScheduler, - LRSchedulerConfig, + create_lr_scheduler_config, get_lr_scheduler, ) @@ -37,17 +37,12 @@ def test_get_lr_scheduler(): lr=10.0, ) - for scheduler in [e.value for e in ImplementedLRScheduler]: - if scheduler == "StepLR": - config = LRSchedulerConfig(scheduler=scheduler, step_size=1) - elif scheduler == "MultiStepLR": - config = LRSchedulerConfig(scheduler=scheduler, milestones=[1, 2, 3]) - else: - config = LRSchedulerConfig(scheduler=scheduler) - get_lr_scheduler(optimizer, config) + args = {"step_size": 1, "milestones": [1, 2]} + for scheduler in ImplementedLRScheduler: + config = create_lr_scheduler_config(scheduler=scheduler)(**args) + _ = get_lr_scheduler(optimizer, config) - config = LRSchedulerConfig( - scheduler="ReduceLROnPlateau", + config = create_lr_scheduler_config(scheduler="ReduceLROnPlateau")( mode="max", factor=0.123, threshold=1e-1, @@ -83,7 +78,8 @@ def test_get_lr_scheduler(): scheduler, updated_config = get_lr_scheduler(optimizer, config) assert scheduler.min_lrs == [1.0, 1.0, 1.0] - config = LRSchedulerConfig() + # no lr scheduler + config = create_lr_scheduler_config(None)() scheduler, updated_config = get_lr_scheduler(optimizer, config) assert isinstance(scheduler, LambdaLR) assert updated_config.scheduler is None diff --git a/tests/unittests/optim/optimizer/test_config.py b/tests/unittests/optim/optimizer/test_config.py index e162fe58d..bf1dbcd8f 100644 --- a/tests/unittests/optim/optimizer/test_config.py +++ b/tests/unittests/optim/optimizer/test_config.py @@ -1,34 +1,118 @@ import pytest from pydantic import ValidationError -from clinicadl.optim.optimizer import OptimizerConfig - - -def test_OptimizerConfig(): - config = OptimizerConfig( - optimizer="SGD", - lr=1e-3, - weight_decay={"param_0": 1e-3, "param_1": 1e-2}, - momentum={"param_1": 1e-1}, - lr_decay=1e-4, - ) - assert config.optimizer == "SGD" - assert config.lr == 1e-3 - assert config.weight_decay == {"param_0": 1e-3, "param_1": 1e-2} - assert config.momentum == {"param_1": 1e-1} - assert config.lr_decay == 1e-4 - assert config.alpha == "DefaultFromLibrary" - assert sorted(config.get_all_groups()) == ["param_0", "param_1"] +from clinicadl.optim.optimizer.config import ( + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + create_optimizer_config, +) +BAD_INPUTS = { + "lr": 0, + "rho": 1.1, + "eps": -0.1, + "weight_decay": -0.1, + "lr_decay": -0.1, + "initial_accumulator_value": -0.1, + "betas": (0.9, 1.0), + "alpha": 1.1, + "momentum": -0.1, + "dampening": 0.1, +} + +GOOD_INPUTS_1 = { + "lr": 0.1, + "rho": 0, + "eps": 0, + "weight_decay": 0, + "foreach": None, + "capturable": False, + "maximize": True, + "differentiable": False, + "fused": None, + "lr_decay": 0, + "initial_accumulator_value": 0, + "betas": (0.0, 0.0), + "amsgrad": True, + "alpha": 0.0, + "momentum": 0, + "centered": True, + "dampening": 0, + "nesterov": True, +} + +GOOD_INPUTS_2 = { + "foreach": True, + "fused": False, +} + + +@pytest.mark.parametrize( + "config", + [ + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + ], +) +def test_validation_fail(config): + fields = config.model_fields + inputs = {key: value for key, value in BAD_INPUTS.items() if key in fields} with pytest.raises(ValidationError): - OptimizerConfig(betas={"params_0": (0.9, 1.01), "params_1": (0.9, 0.99)}) - with pytest.raises(ValidationError): - OptimizerConfig(betas=0.9) - with pytest.raises(ValidationError): - OptimizerConfig(rho=1.01) - with pytest.raises(ValidationError): - OptimizerConfig(alpha=1.01) + config(**inputs) + + # test dict inputs + inputs = {key: {"group_1": value} for key, value in inputs.items()} with pytest.raises(ValidationError): - OptimizerConfig(dampening={"params_0": 0.1, "params_1": 2}) - with pytest.raises(ValueError): - OptimizerConfig(optimizer="abc") + config(**inputs) + + +@pytest.mark.parametrize( + "config", + [ + AdadeltaConfig, + AdagradConfig, + AdamConfig, + RMSpropConfig, + SGDConfig, + ], +) +@pytest.mark.parametrize( + "good_inputs", + [ + GOOD_INPUTS_1, + GOOD_INPUTS_2, + ], +) +def test_validation_pass(config, good_inputs): + fields = config.model_fields + inputs = {key: value for key, value in good_inputs.items() if key in fields} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + # test dict inputs + inputs = {key: {"group_1": value} for key, value in inputs.items()} + c = config(**inputs) + for arg, value in inputs.items(): + assert getattr(c, arg) == value + + +@pytest.mark.parametrize( + "name,expected_class", + [ + ("Adadelta", AdadeltaConfig), + ("Adagrad", AdagradConfig), + ("Adam", AdamConfig), + ("RMSprop", RMSpropConfig), + ("SGD", SGDConfig), + ], +) +def test_create_optimizer_config(name, expected_class): + config = create_optimizer_config(name) + assert config == expected_class diff --git a/tests/unittests/optim/optimizer/test_factory.py b/tests/unittests/optim/optimizer/test_factory.py index 7dbf149e9..47b44a00a 100644 --- a/tests/unittests/optim/optimizer/test_factory.py +++ b/tests/unittests/optim/optimizer/test_factory.py @@ -1,12 +1,25 @@ from collections import OrderedDict import pytest +import torch import torch.nn as nn +from torch.optim import Adagrad + +from clinicadl.optim.optimizer import ( + ImplementedOptimizer, + create_optimizer_config, + get_optimizer, +) +from clinicadl.optim.optimizer.factory import ( + _get_params_in_group, + _get_params_not_in_group, + _regroup_args, +) @pytest.fixture def network(): - network = nn.Sequential( + net = nn.Sequential( OrderedDict( [ ("conv1", nn.Conv2d(1, 1, kernel_size=3)), @@ -14,31 +27,22 @@ def network(): ] ) ) - network.add_module( + net.add_module( "final", nn.Sequential( OrderedDict([("dense2", nn.Linear(10, 5)), ("dense3", nn.Linear(5, 3))]) ), ) - return network + return net def test_get_optimizer(network): - from torch.optim import Adagrad - - from clinicadl.optim.optimizer import ( - ImplementedOptimizer, - OptimizerConfig, - get_optimizer, - ) - - for optimizer in [e.value for e in ImplementedOptimizer]: - config = OptimizerConfig(optimizer=optimizer) + for optimizer in ImplementedOptimizer: + config = create_optimizer_config(optimizer=optimizer)() optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 1 - config = OptimizerConfig( - optimizer="Adagrad", + config = create_optimizer_config("Adagrad")( lr=1e-5, weight_decay={"final.dense3.weight": 1.0, "dense1": 0.1}, lr_decay={"dense1": 10, "ELSE": 100}, @@ -82,28 +86,23 @@ def test_get_optimizer(network): assert not updated_config.maximize assert not updated_config.differentiable - # special case : only ELSE - config = OptimizerConfig( - optimizer="Adagrad", + # special cases 1 + config = create_optimizer_config("Adagrad")( lr_decay={"ELSE": 100}, ) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 1 assert optimizer.param_groups[0]["lr_decay"] == 100 - # special case : the params mentioned form all the network - config = OptimizerConfig( - optimizer="Adagrad", + # special cases 2 + config = create_optimizer_config("Adagrad")( lr_decay={"conv1": 100, "dense1": 10, "final": 1}, ) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 3 # special case : no ELSE mentioned - config = OptimizerConfig( - optimizer="Adagrad", - lr_decay={"conv1": 100}, - ) + config = create_optimizer_config("Adagrad")(lr_decay={"conv1": 100}) optimizer, _ = get_optimizer(network, config) assert len(optimizer.param_groups) == 2 assert optimizer.param_groups[0]["lr_decay"] == 100 @@ -111,8 +110,6 @@ def test_get_optimizer(network): def test_regroup_args(): - from clinicadl.optim.optimizer.factory import _regroup_args - args = { "weight_decay": {"params_0": 0.0, "params_1": 1.0}, "alpha": {"params_1": 0.5, "ELSE": 0.1}, @@ -134,3 +131,48 @@ def test_regroup_args(): {"weight_decay": {"params_0": 0.0, "params_1": 1.0}} ) assert len(args_global) == 0 + + +def test_get_params_in_block(network): + generator, list_layers = _get_params_in_group(network, "dense1") + assert next(iter(generator)).shape == torch.Size((10, 10)) + assert next(iter(generator)).shape == torch.Size((10,)) + assert sorted(list_layers) == sorted(["dense1.weight", "dense1.bias"]) + + generator, list_layers = _get_params_in_group(network, "dense1.weight") + assert next(iter(generator)).shape == torch.Size((10, 10)) + assert sum(1 for _ in generator) == 0 + assert sorted(list_layers) == sorted(["dense1.weight"]) + + generator, list_layers = _get_params_in_group(network, "final.dense3") + assert next(iter(generator)).shape == torch.Size((3, 5)) + assert next(iter(generator)).shape == torch.Size((3,)) + assert sorted(list_layers) == sorted(["final.dense3.weight", "final.dense3.bias"]) + + generator, list_layers = _get_params_in_group(network, "final") + assert sum(1 for _ in generator) == 4 + assert sorted(list_layers) == sorted( + [ + "final.dense2.weight", + "final.dense2.bias", + "final.dense3.weight", + "final.dense3.bias", + ] + ) + + +def test_find_params_not_in_group(network): + params = _get_params_not_in_group( + network, + [ + "final.dense2.weight", + "final.dense2.bias", + "conv1.bias", + "final.dense3.weight", + "dense1.weight", + "dense1.bias", + ], + ) + assert next(iter(params)).shape == torch.Size((1, 1, 3, 3)) + assert next(iter(params)).shape == torch.Size((3,)) + assert sum(1 for _ in params) == 0 # no more params