Skip to content

Commit

Permalink
Add docstrings to tests. Itk saving check for MetaTensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Dec 11, 2024
1 parent 389aac3 commit f5db261
Show file tree
Hide file tree
Showing 12 changed files with 924 additions and 194 deletions.
8 changes: 5 additions & 3 deletions lighter/callbacks/writer/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
import torchvision
from monai.data import metatensor_to_itk_image
from monai.data import MetaTensor, metatensor_to_itk_image
from monai.transforms import DivisiblePad
from torch import Tensor

Expand Down Expand Up @@ -110,15 +110,17 @@ def write_video(path, tensor):
torchvision.io.write_video(str(path), tensor, fps=24)


def write_itk_image(path: str, tensor: Tensor, suffix) -> None:
def write_itk_image(path: str, tensor: MetaTensor, suffix) -> None:
"""
Writes a tensor as an ITK image file.
Args:
path: The path to save the ITK image.
tensor: The tensor representing the image.
tensor: The tensor representing the image. Must be in MONAI MetaTensor format.
suffix: The file suffix indicating the format.
"""
path = path.with_suffix(suffix)
if not isinstance(tensor, MetaTensor):
raise TypeError("Tensor must be in MONAI MetaTensor format.")
itk_image = metatensor_to_itk_image(tensor, channel_dim=0, dtype=tensor.dtype)
OPTIONAL_IMPORTS["itk"].imwrite(itk_image, str(path), True)
90 changes: 74 additions & 16 deletions tests/unit/test_callbacks_freezer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@


class DummyModel(Module):
"""
A simple neural network model for testing purposes.
Contains three linear layers that can be selectively frozen during training.
"""

def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(10, 10)
Expand All @@ -23,6 +29,12 @@ def forward(self, x):


class DummyDataset(Dataset):
"""
A dummy dataset that generates random tensors for testing.
Returns random input tensors of size 10 and target tensors of 0.
"""

def __len__(self):
return 10

Expand All @@ -32,6 +44,12 @@ def __getitem__(self, idx):

@pytest.fixture
def dummy_system():
"""
Fixture that creates a LighterSystem instance with a dummy model for testing.
Returns:
LighterSystem: A configured system with DummyModel, SGD optimizer, and DummyDataset.
"""
model = DummyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataset = DummyDataset()
Expand All @@ -40,6 +58,14 @@ def dummy_system():


def test_freezer_initialization():
"""
Test the initialization of LighterFreezer with various parameter combinations.
Verifies:
- Raises ValueError when neither names nor name_starts_with is specified
- Raises ValueError when both until_step and until_epoch are specified
- Correctly stores the names parameter
"""
with pytest.raises(ValueError, match="At least one of `names` or `name_starts_with` must be specified."):
LighterFreezer()

Expand All @@ -50,6 +76,13 @@ def test_freezer_initialization():


def test_freezer_functionality(dummy_system):
"""
Test the basic functionality of LighterFreezer during training.
Verifies:
- Specified layers are correctly frozen (requires_grad=False)
- Non-specified layers remain unfrozen (requires_grad=True)
"""
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"])
trainer = Trainer(callbacks=[freezer], max_epochs=1)
trainer.fit(dummy_system)
Expand All @@ -59,6 +92,11 @@ def test_freezer_functionality(dummy_system):


def test_freezer_exceed_until_step(dummy_system):
"""
Test that layers are unfrozen after exceeding the specified step limit.
Verifies that layers become trainable (requires_grad=True) after the until_step threshold.
"""
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"], until_step=0)
trainer = Trainer(callbacks=[freezer], max_epochs=1)
trainer.fit(dummy_system)
Expand All @@ -67,6 +105,11 @@ def test_freezer_exceed_until_step(dummy_system):


def test_freezer_exceed_until_epoch(dummy_system):
"""
Test that layers are unfrozen after exceeding the specified epoch limit.
Verifies that layers become trainable (requires_grad=True) after the until_epoch threshold.
"""
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"], until_epoch=0)
trainer = Trainer(callbacks=[freezer], max_epochs=1)
trainer.fit(dummy_system)
Expand All @@ -75,6 +118,13 @@ def test_freezer_exceed_until_epoch(dummy_system):


def test_freezer_set_model_requires_grad(dummy_system):
"""
Test the internal _set_model_requires_grad method of LighterFreezer.
Verifies:
- Method correctly freezes specified parameters
- Method correctly unfreezes specified parameters
"""
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"])
freezer._set_model_requires_grad(dummy_system.model, requires_grad=False)
assert not dummy_system.model.layer1.weight.requires_grad
Expand All @@ -84,23 +134,15 @@ def test_freezer_set_model_requires_grad(dummy_system):
assert dummy_system.model.layer1.bias.requires_grad


def test_freezer_until_step(dummy_system):
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"], until_step=0)
trainer = Trainer(callbacks=[freezer], max_epochs=1)
trainer.fit(dummy_system)
assert dummy_system.model.layer1.weight.requires_grad
assert dummy_system.model.layer1.bias.requires_grad


def test_freezer_until_epoch(dummy_system):
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"], until_epoch=0)
trainer = Trainer(callbacks=[freezer], max_epochs=1)
trainer.fit(dummy_system)
assert dummy_system.model.layer1.weight.requires_grad
assert dummy_system.model.layer1.bias.requires_grad


def test_freezer_with_exceptions(dummy_system):
"""
Test LighterFreezer with exception patterns for layer freezing.
Verifies:
- Layers matching name_starts_with are frozen
- Layers in except_names remain unfrozen
- Other layers behave as expected
"""
freezer = LighterFreezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"])
trainer = Trainer(callbacks=[freezer], max_epochs=1)
trainer.fit(dummy_system)
Expand All @@ -113,6 +155,14 @@ def test_freezer_with_exceptions(dummy_system):


def test_freezer_except_name_starts_with(dummy_system):
"""
Test LighterFreezer with except_name_starts_with parameter.
Verifies:
- Layers matching name_starts_with are frozen
- Layers matching except_name_starts_with remain unfrozen
- Other layers behave as expected
"""
freezer = LighterFreezer(name_starts_with=["layer"], except_name_starts_with=["layer2"])
trainer = Trainer(callbacks=[freezer], max_epochs=1)
trainer.fit(dummy_system)
Expand All @@ -125,6 +175,14 @@ def test_freezer_except_name_starts_with(dummy_system):


def test_freezer_set_model_requires_grad_with_exceptions(dummy_system):
"""
Test the _set_model_requires_grad method with various exception patterns.
Verifies:
- Correct handling of specific parameter exceptions
- Proper behavior with name_starts_with and except_names combinations
- Consistent freezing/unfreezing across multiple configurations
"""
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"], except_names=["layer1.bias"])
freezer._set_model_requires_grad(dummy_system.model, requires_grad=False)
assert not dummy_system.model.layer1.weight.requires_grad
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_callbacks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@


def test_get_lighter_mode():
"""
Test the get_lighter_mode function's stage name mapping and error handling.
Tests:
- Mapping of 'train' stage
- Mapping of 'validate' stage to 'val'
- Mapping of 'test' stage
- Raising KeyError for invalid stage names
"""
assert get_lighter_mode("train") == "train"
assert get_lighter_mode("validate") == "val"
assert get_lighter_mode("test") == "test"
Expand All @@ -16,6 +25,16 @@ def test_get_lighter_mode():


def test_preprocess_image_single_3d():
"""
Test preprocess_image function with a single 3D image input.
Tests the reshaping of a single 3D image with dimensions:
- Input: (1, 1, depth, height, width)
- Expected output: (1, depth*height, width)
The function verifies that a 3D medical image is correctly
reshaped while preserving spatial relationships.
"""
depth = 20
height = 64
width = 64
Expand All @@ -25,6 +44,23 @@ def test_preprocess_image_single_3d():


def test_preprocess_image_batch_3d():
"""
Test preprocess_image function with a batch of 3D images.
Tests the reshaping of multiple 3D images with dimensions:
- Input: (batch_size, 1, depth, height, width)
- Expected output: (1, depth*height, batch_size*width)
The function verifies that a batch of 3D medical images is correctly
reshaped into a single 2D representation while maintaining the
spatial relationships and batch information.
Args used in test:
batch_size: 8
depth: 20
height: 64
width: 64
"""
batch_size = 8
depth = 20
height = 64
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/test_callbacks_writer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,58 @@

@pytest.fixture
def target_path():
"""
Fixture that provides a test path for the writer.
Returns:
Path: A Path object pointing to "test" directory
"""
return Path("test")


class MockWriter(LighterBaseWriter):
"""
Mock implementation of LighterBaseWriter for testing purposes.
This class provides a minimal implementation of the abstract base class
with a simple tensor writer function.
"""

@property
def writers(self):
"""
Define available writers for the mock class.
Returns:
dict: Dictionary containing writer name and corresponding function
"""
return {"tensor": lambda x: None}

def write(self, tensor, id):
"""
Mock implementation of the write method.
Args:
tensor: The tensor to write
id: Identifier for the tensor
"""
pass


def test_writer_initialization(target_path):
"""
Test the initialization of writers.
Tests that:
- MockWriter initializes correctly with valid writer
- Base class raises TypeError when instantiated directly
Args:
target_path (Path): Fixture providing test directory path
Raises:
TypeError: When attempting to instantiate abstract base class
"""
# Test initialization with a valid writer
writer = MockWriter(path=target_path, writer="tensor")
assert callable(writer.writer)
Expand All @@ -31,6 +70,17 @@ def test_writer_initialization(target_path):


def test_on_predict_batch_end(target_path):
"""
Test the on_predict_batch_end callback functionality.
Verifies that:
- Prediction IDs are properly assigned
- Prediction counter increments correctly
- Trainer's prediction list is maintained
Args:
target_path (Path): Fixture providing test directory path
"""
logging.basicConfig(level=logging.INFO)
trainer = MagicMock()
trainer.world_size = 1
Expand All @@ -54,6 +104,18 @@ def test_on_predict_batch_end(target_path):


def test_writer_setup_predict(target_path, caplog):
"""
Test writer setup for prediction stage.
Verifies that:
- Writer initializes correctly for prediction
- Prediction counter is properly reset
- Global synchronization works as expected
Args:
target_path (Path): Fixture providing test directory path
caplog: Pytest fixture for capturing log output
"""
trainer = MagicMock()
trainer.world_size = 1
trainer.is_global_zero = True
Expand All @@ -69,6 +131,17 @@ def test_writer_setup_predict(target_path, caplog):


def test_writer_setup_non_predict(target_path):
"""
Test writer setup for non-prediction stages.
Verifies that:
- Writer initializes correctly for non-prediction stages (e.g., train)
- Prediction counter remains None
- Path is properly set
Args:
target_path (Path): Fixture providing test directory path
"""
trainer = MagicMock()
trainer.world_size = 1
trainer.is_global_zero = True
Expand Down
Loading

0 comments on commit f5db261

Please sign in to comment.