diff --git a/tests/test_utils/test_training_utils.py b/tests/test_utils/test_training_utils.py
index 003485b..ab7415e 100644
--- a/tests/test_utils/test_training_utils.py
+++ b/tests/test_utils/test_training_utils.py
@@ -13,7 +13,7 @@
 import os
 import tempfile
 import lightning.pytorch as pl
-from lightning.pytorch.callbacks import (EarlyStopping, ModelCheckpoint)
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
 from lightning.pytorch import Trainer
 from torchmetrics import Accuracy, R2Score
 import numpy as np
@@ -50,13 +50,11 @@ class SomeFusionModelClass:
     fold = 1
     extra_log_string_dict = {"param1": "value1", "param2": 42}
 
-    checkpoint_name = set_checkpoint_name(
-        fusion_model, fold, extra_log_string_dict
-    )
+    checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict)
 
     assert (
-            checkpoint_name
-            == "SomeFusionModelClass_fold_1_param1_value1_param2_42_{epoch:02d}"
+        checkpoint_name
+        == "SomeFusionModelClass_fold_1_param1_value1_param2_42_{epoch:02d}"
     )
 
 
@@ -68,9 +66,7 @@ class SomeFusionModelClass:
     fold = None
     extra_log_string_dict = {"param1": "value1", "param2": 42}
 
-    checkpoint_name = set_checkpoint_name(
-        fusion_model, fold, extra_log_string_dict
-    )
+    checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict)
 
     assert checkpoint_name == "SomeFusionModelClass_param1_value1_param2_42_{epoch:02d}"
 
@@ -83,9 +79,7 @@ class SomeFusionModelClass:
     fold = 2
     extra_log_string_dict = None
 
-    checkpoint_name = set_checkpoint_name(
-        fusion_model, fold, extra_log_string_dict
-    )
+    checkpoint_name = set_checkpoint_name(fusion_model, fold, extra_log_string_dict)
 
     assert checkpoint_name == "SomeFusionModelClass_fold_2_{epoch:02d}"
 
@@ -149,7 +143,8 @@ class SomeFusionModelClass:
     k = None
 
     checkpoint_filenames = get_checkpoint_filenames_for_subspace_models(
-        subspace_method, k)
+        subspace_method, k
+    )
 
     expected_filenames = [
         "subspace_SomeFusionModelClass_SubspaceModel1_key_value",
@@ -228,10 +223,10 @@ def test_get_checkpoint_filename_for_trained_fusion_model_not_found(params, mode
 
     # Attempt to get a checkpoint filename when no matching file exists
     with pytest.raises(
-            ValueError, match=r"Could not find checkpoint file with name .*"
+        ValueError, match=r"Could not find checkpoint file with name .*"
     ):
         get_checkpoint_filename_for_trained_fusion_model(
-            params['checkpoint_dir'], model, checkpoint_file_suffix
+            params["checkpoint_dir"], model, checkpoint_file_suffix
         )
 
 
@@ -253,7 +248,7 @@ def test_get_checkpoint_filename_for_trained_fusion_model_multiple_files(params,
     # Attempt to get a checkpoint filename when multiple matching files exist
     with pytest.raises(
-            ValueError, match=r"Found multiple checkpoint files with name .*"
+        ValueError, match=r"Found multiple checkpoint files with name .*"
     ):
         get_checkpoint_filename_for_trained_fusion_model(
             params["checkpoint_dir"], model, checkpoint_file_suffix
         )
@@ -279,7 +274,10 @@ def mock_logger():
 
 def test_init_trainer_default(mock_logger):
     # Test initializing trainer with default parameters
-    trainer = init_trainer(mock_logger, output_paths={}, )
+    trainer = init_trainer(
+        mock_logger,
+        output_paths={},
+    )
     assert trainer is not None
     assert isinstance(trainer, Trainer)
     assert trainer.max_epochs == 1000
@@ -289,14 +287,18 @@
     assert trainer.checkpoint_callback is not None
 
 
-@pytest.mark.filterwarnings("ignore:.*GPU available but not used*.", )
+@pytest.mark.filterwarnings(
+    "ignore:.*GPU available but not used*.",
+)
 def test_init_trainer_custom_early_stopping(mock_logger):
     # Test initializing trainer with a custom early stopping callback
     # custom_early_stopping = Mock()
-    custom_early_stopping = EarlyStopping(monitor="val_loss",
-                                          patience=3,
-                                          verbose=True,
-                                          mode="max", )
+    custom_early_stopping = EarlyStopping(
+        monitor="val_loss",
+        patience=3,
+        verbose=True,
+        mode="max",
+    )
     trainer = init_trainer(
         mock_logger, output_paths={}, own_early_stopping_callback=custom_early_stopping
     )
@@ -311,7 +313,10 @@ def test_init_trainer_custom_early_stopping(mock_logger):
     assert isinstance(trainer.callbacks[0], EarlyStopping)
     assert trainer.callbacks[0] == custom_early_stopping
     for key in custom_early_stopping.__dict__:
-        assert custom_early_stopping.__dict__[key] == trainer.early_stopping_callback.__dict__[key]
+        assert (
+            custom_early_stopping.__dict__[key]
+            == trainer.early_stopping_callback.__dict__[key]
+        )
     assert trainer.checkpoint_callback is not None
 
 
@@ -339,7 +344,11 @@ def test_init_trainer_with_accelerator_and_devices(mock_logger):
     # Test initializing trainer with custom accelerator and devices
     params = {"accelerator": "cpu", "devices": 3}
 
-    trainer = init_trainer(mock_logger, output_paths={}, training_modifications={"accelerator": "cpu", "devices": 3})
+    trainer = init_trainer(
+        mock_logger,
+        output_paths={},
+        training_modifications={"accelerator": "cpu", "devices": 3},
+    )
 
     assert trainer is not None
     assert isinstance(trainer, Trainer)
@@ -441,7 +450,7 @@ def __init__(self, model):
 
     # Get the final validation metrics
     with pytest.raises(
-            ValueError,
-            match=r"not in trainer.callback_metrics.keys()",
+        ValueError,
+        match=r"not in trainer.callback_metrics.keys()",
     ):
         get_final_val_metrics(trainer)