Skip to content

Commit

Permalink
Merge pull request scilus#767 from frheault/compare_tractogram_big
Browse files Browse the repository at this point in the history
Compare tractograms (TODI, TDI, ACC)
  • Loading branch information
arnaudbore authored Mar 20, 2024
2 parents 4965e01 + b67a138 commit 234b88c
Show file tree
Hide file tree
Showing 9 changed files with 599 additions and 5 deletions.
36 changes: 35 additions & 1 deletion scilpy/image/tests/test_volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import nibabel as nib
import numpy as np
from dipy.io.gradients import read_bvals_bvecs
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_almost_equal

from scilpy import SCILPY_HOME
from scilpy.image.volume_operations import (apply_transform,
compute_snr,
crop_volume,
flip_volume,
merge_metrics,
normalize_metric,
resample_volume)
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
from scilpy.utils.util import compute_nifti_bounding_box
Expand Down Expand Up @@ -127,3 +129,35 @@ def test_resample_volume():
resampled_img = resample_volume(moving3d_img, res=(2, 2, 2), interp='nn')

assert_equal(resampled_img.get_fdata(), ref3d)


def test_normalize_metric_basic():
metric = np.array([1, 2, 3, 4, 5])
expected_output = np.array([0., 0.25, 0.5, 0.75, 1.])
normalized_metric = normalize_metric(metric)
assert_almost_equal(normalized_metric, expected_output)


def test_normalize_metric_nan_handling():
metric = np.array([1, np.nan, 3, np.nan, 5])
expected_output = np.array([0., np.nan, 0.5, np.nan, 1.])
normalized_metric = normalize_metric(metric)

assert_almost_equal(normalized_metric, expected_output)


def test_merge_metrics_basic():
arrays = [np.array([1, 2, 3]), np.array([4, 5, 6])]
# Geometric mean boosted by beta=1
expected_output = np.array([2.0, 3.162278, 4.242641])
merged_metric = merge_metrics(*arrays)

assert_almost_equal(merged_metric, expected_output, decimal=6)


def test_merge_metrics_nan_propagation():
arrays = [np.array([1, np.nan, 3]), np.array([4, 5, 6])]
expected_output = np.array([2., np.nan, 4.242641]) # NaN replaced with -2
merged_metric = merge_metrics(*arrays)

assert_almost_equal(merged_metric, expected_output, decimal=6)
67 changes: 67 additions & 0 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dipy.segment.mask import crop, median_otsu
import nibabel as nib
import numpy as np
from numpy import ma
from scipy.ndimage import binary_dilation, gaussian_filter

from scilpy.image.reslice import reslice # Don't use Dipy's reslice. Buggy.
Expand Down Expand Up @@ -528,3 +529,69 @@ def crop_data_with_default_cube(data):
roi_mask = _mask_from_roi(shape, roi_center, roi_radii)

return data * roi_mask


def normalize_metric(metric, reverse=False):
"""
Normalize a metric array to a range between 0 and 1,
optionally reversing the normalization.
Parameters
----------
metric : ndarray
The input metric array to be normalized.
reverse : bool, optional
If True, reverse the normalization (i.e., 1 - normalized value).
Default is False.
Returns
-------
ndarray
The normalized (and possibly reversed) metric array.
NaN values in the input are retained.
"""
mask = np.isnan(metric)
masked_metric = ma.masked_array(metric, mask)

min_val, max_val = masked_metric.min(), masked_metric.max()
normalized_metric = (masked_metric - min_val) / (max_val - min_val)

if reverse:
normalized_metric = 1 - normalized_metric

return ma.filled(normalized_metric, fill_value=np.nan)


def merge_metrics(*arrays, beta=1.0):
"""
Merge an arbitrary number of metrics into a single heatmap using a weighted
geometric mean, ignoring NaN values. Each input array contributes equally
to the geometric mean, and the result is boosted by a specified factor.
Parameters
----------
*arrays : ndarray
An arbitrary number of input arrays (ndarrays).
All arrays must have the same shape.
beta : float, optional
Boosting factor for the geometric mean. The default is 1.0.
Returns
-------
ndarray
Boosted geometric mean of the inputs (same shape as the input arrays)
NaN values in any input array are propagated to the output.
"""

# Create a mask for NaN values in any of the arrays
mask = np.any([np.isnan(arr) for arr in arrays], axis=0)
masked_arrays = [ma.masked_array(arr, mask) for arr in arrays]

# Calculate the product of the arrays for the geometric mean
array_product = np.prod(masked_arrays, axis=0)

# Calculate the geometric mean for valid data
geometric_mean = np.power(array_product, 1 / len(arrays))
boosted_mean = geometric_mean ** beta

return ma.filled(boosted_mean, fill_value=np.nan)
Loading

0 comments on commit 234b88c

Please sign in to comment.