Skip to content

Commit

Permalink
Add unittests for compute_metrics_reloaded.py
Browse files Browse the repository at this point in the history
  • Loading branch information
valosekj committed Mar 6, 2024
1 parent cfa7f0f commit 0dbd013
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 1 deletion.
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r dataset_conversion/requirements.txt
git clone https://github.com/valosekj/MetricsReloaded.git
cd MetricsReloaded
git checkout jv/add_rel_vol_error_metric
python -m pip install .
- name: Run tests with unittest
run: |
python -m unittest tests/test_convert_bids_to_nnUNetV2.py
python -m unittest tests/test_convert_bids_to_nnUNetV2.py
python -m unittest tests/test_compute_metrics_reloaded.py
147 changes: 147 additions & 0 deletions tests/test_compute_metrics_reloaded.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#######################################################################
#
# Tests for the `compute_metrics/compute_metrics_reloaded.py` script
#
# RUN BY:
# python -m unittest tests/test_compute_metrics_reloaded.py
#######################################################################

import unittest
import os
import numpy as np
import nibabel as nib
from compute_metrics.compute_metrics_reloaded import compute_metrics_single_subject
import tempfile

METRICS = ['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error']


class TestComputeMetricsReloaded(unittest.TestCase):
def setUp(self):
# Use tempfile.NamedTemporaryFile to create temporary nifti files
self.ref_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False)
self.pred_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False)
self.metrics = METRICS

def create_dummy_nii(self, file_obj, data):
img = nib.Nifti1Image(data, np.eye(4))
nib.save(img, file_obj.name)
file_obj.seek(0) # Move back to the beginning of the file

def tearDown(self):
# Close and remove temporary files
self.ref_file.close()
os.unlink(self.ref_file.name)
self.pred_file.close()
os.unlink(self.pred_file.name)

def assert_metrics(self, metrics_dict, expected_metrics):
for metric in self.metrics:
# if value is nan, use np.isnan to check
if np.isnan(expected_metrics[metric]):
self.assertTrue(np.isnan(metrics_dict[1][metric]))
# if value is inf, use np.isinf to check
elif np.isinf(expected_metrics[metric]):
self.assertTrue(np.isinf(metrics_dict[1][metric]))
else:
self.assertAlmostEqual(metrics_dict[1][metric], expected_metrics[metric])

def test_empty_ref_and_pred(self):
"""
Empty reference and empty prediction
"""

expected_metrics = {'EmptyPred': True,
'EmptyRef': True,
'dsc': 1,
'fbeta': 1,
'nsd': np.nan,
'rel_vol_error': 0,
'vol_diff': np.nan}

# Create empty reference
self.create_dummy_nii(self.ref_file, np.zeros((10, 10, 10)))
# Create empty prediction
self.create_dummy_nii(self.pred_file, np.zeros((10, 10, 10)))
# Compute metrics
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics)
# Assert metrics
self.assert_metrics(metrics_dict, expected_metrics)

def test_empty_ref(self):
"""
Empty reference and non-empty prediction
"""

expected_metrics = {'EmptyPred': False,
'EmptyRef': True,
'dsc': 0.0,
'fbeta': 0,
'nsd': 0.0,
'rel_vol_error': 100,
'vol_diff': np.inf}

# Create empty reference
self.create_dummy_nii(self.ref_file, np.zeros((10, 10, 10)))
# Create non-empty prediction
pred = np.zeros((10, 10, 10))
pred[5:7, 2:5] = 1
self.create_dummy_nii(self.pred_file, pred)
# Compute metrics
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics)
# Assert metrics
self.assert_metrics(metrics_dict, expected_metrics)

def test_empty_pred(self):
"""
Non-empty reference and empty prediction
"""

expected_metrics = {'EmptyPred': True,
'EmptyRef': False,
'dsc': 0.0,
'fbeta': 0,
'nsd': 0.0,
'rel_vol_error': -100.0,
'vol_diff': 1.0}

# Create non-empty reference
ref = np.zeros((10, 10, 10))
ref[5:7, 2:5] = 1
self.create_dummy_nii(self.ref_file, ref)
# Create empty prediction
self.create_dummy_nii(self.pred_file, np.zeros((10, 10, 10)))
# Compute metrics
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics)
# Assert metrics
self.assert_metrics(metrics_dict, expected_metrics)

def test_non_empty_ref_and_pred(self):
"""
Non-empty reference and non-empty prediction
"""

expected_metrics = {'EmptyPred': False,
'EmptyRef': False,
'dsc': 0.26666666666666666,
'fbeta': 0.26666667461395266,
'nsd': 0.5373134328358209,
'rel_vol_error': 300.0,
'vol_diff': 3.0}

# Create non-empty reference
ref = np.zeros((10, 10, 10))
ref[4:5, 3:6] = 1
self.create_dummy_nii(self.ref_file, ref)
# Create non-empty prediction
pred = np.zeros((10, 10, 10))
pred[4:8, 2:5] = 1
self.create_dummy_nii(self.pred_file, pred)
# Compute metrics
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics)
# Assert metrics
self.assert_metrics(metrics_dict, expected_metrics)


if __name__ == '__main__':
unittest.main()

0 comments on commit 0dbd013

Please sign in to comment.