formatting
florencejt committed May 15, 2024
1 parent 84af094 commit c208ab0
Showing 1 changed file with 35 additions and 26 deletions.
61 changes: 35 additions & 26 deletions tests/test_utils/test_training_utils.py
@@ -13,7 +13,7 @@
 import os
 import tempfile
 import lightning.pytorch as pl
-from lightning.pytorch.callbacks import (EarlyStopping, ModelCheckpoint)
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
 from lightning.pytorch import Trainer
 from torchmetrics import Accuracy, R2Score
 import numpy as np
@@ -50,13 +50,11 @@ class SomeFusionModelClass:
     fold = 1
     extra_log_string_dict = {"param1": "value1", "param2": 42}

-    checkpoint_name = set_checkpoint_name(
-        fusion_model, fold, extra_log_string_dict
-    )
+    checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict)

     assert (
-        checkpoint_name
-        == "SomeFusionModelClass_fold_1_param1_value1_param2_42_{epoch:02d}"
+        checkpoint_name
+        == "SomeFusionModelClass_fold_1_param1_value1_param2_42_{epoch:02d}"
     )


@@ -68,9 +66,7 @@ class SomeFusionModelClass:
     fold = None
     extra_log_string_dict = {"param1": "value1", "param2": 42}

-    checkpoint_name = set_checkpoint_name(
-        fusion_model, fold, extra_log_string_dict
-    )
+    checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict)

     assert checkpoint_name == "SomeFusionModelClass_param1_value1_param2_42_{epoch:02d}"

@@ -83,9 +79,7 @@ class SomeFusionModelClass:
     fold = 2
     extra_log_string_dict = None

-    checkpoint_name = set_checkpoint_name(
-        fusion_model, fold, extra_log_string_dict
-    )
+    checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict)

     assert checkpoint_name == "SomeFusionModelClass_fold_2_{epoch:02d}"

@@ -149,7 +143,8 @@ class SomeFusionModelClass:
     k = None

     checkpoint_filenames = get_checkpoint_filenames_for_subspace_models(
-        subspace_method, k)
+        subspace_method, k
+    )

         "subspace_SomeFusionModelClass_SubspaceModel1_key_value",
         "subspace_SomeFusionModelClass_SubspaceModel2_key_value",
@@ -228,10 +223,10 @@ def test_get_checkpoint_filename_for_trained_fusion_model_not_found(params, mode

     # Attempt to get a checkpoint filename when no matching file exists
     with pytest.raises(
-        ValueError, match=r"Could not find checkpoint file with name .*"
+        ValueError, match=r"Could not find checkpoint file with name .*"
     ):
         get_checkpoint_filename_for_trained_fusion_model(
-            params['checkpoint_dir'], model, checkpoint_file_suffix
+            params["checkpoint_dir"], model, checkpoint_file_suffix
         )


@@ -253,7 +248,7 @@ def test_get_checkpoint_filename_for_trained_fusion_model_multiple_files(params,

     # Attempt to get a checkpoint filename when multiple matching files exist
     with pytest.raises(
-        ValueError, match=r"Found multiple checkpoint files with name .*"
+        ValueError, match=r"Found multiple checkpoint files with name .*"
     ):
         get_checkpoint_filename_for_trained_fusion_model(
             params["checkpoint_dir"], model, checkpoint_file_suffix
@@ -279,7 +274,10 @@ def mock_logger():

 def test_init_trainer_default(mock_logger):
     # Test initializing trainer with default parameters
-    trainer = init_trainer(mock_logger, output_paths={}, )
+    trainer = init_trainer(
+        mock_logger,
+        output_paths={},
+    )
     assert trainer is not None
     assert isinstance(trainer, Trainer)
     assert trainer.max_epochs == 1000
@@ -289,14 +287,18 @@ def test_init_trainer_default(mock_logger):
     assert trainer.checkpoint_callback is not None


-@pytest.mark.filterwarnings("ignore:.*GPU available but not used*.", )
+@pytest.mark.filterwarnings(
+    "ignore:.*GPU available but not used*.",
+)
 def test_init_trainer_custom_early_stopping(mock_logger):
     # Test initializing trainer with a custom early stopping callback
     # custom_early_stopping = Mock()
-    custom_early_stopping = EarlyStopping(monitor="val_loss",
-                                          patience=3,
-                                          verbose=True,
-                                          mode="max", )
+    custom_early_stopping = EarlyStopping(
+        monitor="val_loss",
+        patience=3,
+        verbose=True,
+        mode="max",
+    )
     trainer = init_trainer(
         mock_logger, output_paths={}, own_early_stopping_callback=custom_early_stopping
     )
@@ -311,7 +313,10 @@ def test_init_trainer_custom_early_stopping(mock_logger):
     assert isinstance(trainer.callbacks[0], EarlyStopping)
     assert trainer.callbacks[0] == custom_early_stopping
     for key in custom_early_stopping.__dict__:
-        assert custom_early_stopping.__dict__[key] == trainer.early_stopping_callback.__dict__[key]
+        assert (
+            custom_early_stopping.__dict__[key]
+            == trainer.early_stopping_callback.__dict__[key]
+        )

     assert trainer.checkpoint_callback is not None

@@ -339,7 +344,11 @@ def test_init_trainer_with_accelerator_and_devices(mock_logger):
     # Test initializing trainer with custom accelerator and devices

     params = {"accelerator": "cpu", "devices": 3}
-    trainer = init_trainer(mock_logger, output_paths={}, training_modifications={"accelerator": "cpu", "devices": 3})
+    trainer = init_trainer(
+        mock_logger,
+        output_paths={},
+        training_modifications={"accelerator": "cpu", "devices": 3},
+    )

     assert trainer is not None
     assert isinstance(trainer, Trainer)
@@ -441,7 +450,7 @@ def __init__(self, model):

     # Get the final validation metrics
     with pytest.raises(
-        ValueError,
-        match=r"not in trainer.callback_metrics.keys()",
+        ValueError,
+        match=r"not in trainer.callback_metrics.keys()",
     ):
         get_final_val_metrics(trainer)
