-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unittests for compute_metrics_reloaded.py
- Loading branch information
Showing
2 changed files
with
153 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
####################################################################### | ||
# | ||
# Tests for the `compute_metrics/compute_metrics_reloaded.py` script | ||
# | ||
# RUN BY: | ||
# python -m unittest tests/test_compute_metrics_reloaded.py | ||
####################################################################### | ||
|
||
import unittest | ||
import os | ||
import numpy as np | ||
import nibabel as nib | ||
from compute_metrics.compute_metrics_reloaded import compute_metrics_single_subject | ||
import tempfile | ||
|
||
METRICS = ['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error'] | ||
|
||
|
||
class TestComputeMetricsReloaded(unittest.TestCase): | ||
def setUp(self): | ||
# Use tempfile.NamedTemporaryFile to create temporary nifti files | ||
self.ref_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) | ||
self.pred_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) | ||
self.metrics = METRICS | ||
|
||
def create_dummy_nii(self, file_obj, data): | ||
img = nib.Nifti1Image(data, np.eye(4)) | ||
nib.save(img, file_obj.name) | ||
file_obj.seek(0) # Move back to the beginning of the file | ||
|
||
def tearDown(self): | ||
# Close and remove temporary files | ||
self.ref_file.close() | ||
os.unlink(self.ref_file.name) | ||
self.pred_file.close() | ||
os.unlink(self.pred_file.name) | ||
|
||
def assert_metrics(self, metrics_dict, expected_metrics): | ||
for metric in self.metrics: | ||
# if value is nan, use np.isnan to check | ||
if np.isnan(expected_metrics[metric]): | ||
self.assertTrue(np.isnan(metrics_dict[1][metric])) | ||
# if value is inf, use np.isinf to check | ||
elif np.isinf(expected_metrics[metric]): | ||
self.assertTrue(np.isinf(metrics_dict[1][metric])) | ||
else: | ||
self.assertAlmostEqual(metrics_dict[1][metric], expected_metrics[metric]) | ||
|
||
def test_empty_ref_and_pred(self): | ||
""" | ||
Empty reference and empty prediction | ||
""" | ||
|
||
expected_metrics = {'EmptyPred': True, | ||
'EmptyRef': True, | ||
'dsc': 1, | ||
'fbeta': 1, | ||
'nsd': np.nan, | ||
'rel_vol_error': 0, | ||
'vol_diff': np.nan} | ||
|
||
# Create empty reference | ||
self.create_dummy_nii(self.ref_file, np.zeros((10, 10, 10))) | ||
# Create empty prediction | ||
self.create_dummy_nii(self.pred_file, np.zeros((10, 10, 10))) | ||
# Compute metrics | ||
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) | ||
# Assert metrics | ||
self.assert_metrics(metrics_dict, expected_metrics) | ||
|
||
def test_empty_ref(self): | ||
""" | ||
Empty reference and non-empty prediction | ||
""" | ||
|
||
expected_metrics = {'EmptyPred': False, | ||
'EmptyRef': True, | ||
'dsc': 0.0, | ||
'fbeta': 0, | ||
'nsd': 0.0, | ||
'rel_vol_error': 100, | ||
'vol_diff': np.inf} | ||
|
||
# Create empty reference | ||
self.create_dummy_nii(self.ref_file, np.zeros((10, 10, 10))) | ||
# Create non-empty prediction | ||
pred = np.zeros((10, 10, 10)) | ||
pred[5:7, 2:5] = 1 | ||
self.create_dummy_nii(self.pred_file, pred) | ||
# Compute metrics | ||
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) | ||
# Assert metrics | ||
self.assert_metrics(metrics_dict, expected_metrics) | ||
|
||
def test_empty_pred(self): | ||
""" | ||
Non-empty reference and empty prediction | ||
""" | ||
|
||
expected_metrics = {'EmptyPred': True, | ||
'EmptyRef': False, | ||
'dsc': 0.0, | ||
'fbeta': 0, | ||
'nsd': 0.0, | ||
'rel_vol_error': -100.0, | ||
'vol_diff': 1.0} | ||
|
||
# Create non-empty reference | ||
ref = np.zeros((10, 10, 10)) | ||
ref[5:7, 2:5] = 1 | ||
self.create_dummy_nii(self.ref_file, ref) | ||
# Create empty prediction | ||
self.create_dummy_nii(self.pred_file, np.zeros((10, 10, 10))) | ||
# Compute metrics | ||
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) | ||
# Assert metrics | ||
self.assert_metrics(metrics_dict, expected_metrics) | ||
|
||
def test_non_empty_ref_and_pred(self): | ||
""" | ||
Non-empty reference and non-empty prediction | ||
""" | ||
|
||
expected_metrics = {'EmptyPred': False, | ||
'EmptyRef': False, | ||
'dsc': 0.26666666666666666, | ||
'fbeta': 0.26666667461395266, | ||
'nsd': 0.5373134328358209, | ||
'rel_vol_error': 300.0, | ||
'vol_diff': 3.0} | ||
|
||
# Create non-empty reference | ||
ref = np.zeros((10, 10, 10)) | ||
ref[4:5, 3:6] = 1 | ||
self.create_dummy_nii(self.ref_file, ref) | ||
# Create non-empty prediction | ||
pred = np.zeros((10, 10, 10)) | ||
pred[4:8, 2:5] = 1 | ||
self.create_dummy_nii(self.pred_file, pred) | ||
# Compute metrics | ||
metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) | ||
# Assert metrics | ||
self.assert_metrics(metrics_dict, expected_metrics) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |