Skip to content

Commit

Permalink
Merge pull request scilus#1023 from frheault/distance_map
Browse files Browse the repository at this point in the history
Distance map
  • Loading branch information
frheault authored Sep 4, 2024
2 parents 5da7219 + a8bbb9f commit cc8e4f3
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 1 deletion.
60 changes: 59 additions & 1 deletion scilpy/image/tests/test_volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
crop_volume, flip_volume,
merge_metrics, normalize_metric,
resample_volume, register_image,
mask_data_with_default_cube)
mask_data_with_default_cube,
compute_distance_map)
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
from scilpy.image.utils import compute_nifti_bounding_box

Expand Down Expand Up @@ -240,3 +241,60 @@ def test_mask_data_with_default_cube():
assert out[0, 0, 0] == 0
assert out[-1, -1, -1] == 0
assert out[6, 6, 6] == 1


def test_distance_map_smallest_first():
mask_1 = np.zeros((3, 3, 3))
mask_1[0, 0, 0] = 1

mask_2 = np.zeros((3, 3, 3))
mask_2[1:3, 1:3, 1:3] = 1

distance = compute_distance_map(mask_1, mask_2)
assert np.abs(np.sum(distance) - 1.732050) < 1e-6


def test_compute_distance_map_biggest_first():
# Swap both masks
mask_2 = np.zeros((3, 3, 3))
mask_2[0, 0, 0] = 1

mask_1 = np.zeros((3, 3, 3))
mask_1[1:3, 1:3, 1:3] = 1

distance = compute_distance_map(mask_1, mask_2)
assert np.abs(np.sum(distance) - 21.544621) < 1e-6


def test_compute_distance_map_symmetric():
mask_1 = np.zeros((3, 3, 3))
mask_1[0, 0, 0] = 1

mask_2 = np.zeros((3, 3, 3))
mask_2[1:3, 1:3, 1:3] = 1

distance = compute_distance_map(mask_1, mask_2, symmetric=True)
assert np.abs(np.sum(distance) - 23.276672) < 1e-6


def test_compute_distance_map_overlap():
mask_1 = np.zeros((3, 3, 3))
mask_1[1, 1, 1] = 1

mask_2 = np.zeros((3, 3, 3))
mask_2[1:3, 1:3, 1:3] = 1

distance = compute_distance_map(mask_1, mask_2)
assert np.all(distance == 0)


def test_compute_distance_map_wrong_shape():
mask_1 = np.zeros((3, 3, 3))
mask_2 = np.zeros((3, 3, 4))

# Different shapes, test should fail
try:
compute_distance_map(mask_1, mask_2)
assert False
except ValueError:
assert True
48 changes: 48 additions & 0 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
from numpy import ma
from scipy.ndimage import binary_dilation, gaussian_filter
from scipy.spatial import KDTree
from sklearn import linear_model

from scilpy.image.reslice import reslice # Don't use Dipy's reslice. Buggy.
Expand Down Expand Up @@ -688,3 +689,50 @@ def merge_metrics(*arrays, beta=1.0):
boosted_mean = geometric_mean ** beta

return ma.filled(boosted_mean, fill_value=np.nan)


def compute_distance_map(mask_1, mask_2, symmetric=False,
max_distance=np.inf):
"""
Compute the distance map between two binary masks.
The distance is computed using the Euclidean distance between the
first mask and the closest point in the second mask.
Use the symmetric flag to compute the distance map in both directions.
WARNING: This function will work even if inputs are not binary masks,
just make sure that you know what you are doing.
Parameters
----------
mask_1: np.ndarray
First binary mask.
mask_2: np.ndarray
Second binary mask.
symmetric: bool, optional
If True, compute the symmetric distance map. Default is np.inf
max_distance: float, optional
Maximum distance to consider for kdtree exploration. Default is None.
Returns
-------
distance_map: np.ndarray
Distance map between the two masks.
"""
if mask_1.shape != mask_2.shape:
raise ValueError("Masks must have the same shape.")

tree = KDTree(np.argwhere(mask_2))
distance_map = np.zeros(mask_1.shape)
distance = tree.query(np.argwhere(mask_1),
distance_upper_bound=max_distance)[0]
distance_map[np.where(mask_1)] = distance

if symmetric:
# Compute the symmetric distance map and merge it with the previous one
tree = KDTree(np.argwhere(mask_1))
distance = tree.query(np.argwhere(mask_2),
distance_upper_bound=max_distance)[0]
distance_map[np.where(mask_2)] = distance

return distance_map
84 changes: 84 additions & 0 deletions scripts/scil_volume_distance_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Compute distance map between two binary masks. The distance map is the
Euclidean distance from each voxel of the first mask to the closest
voxel of the second mask.
Slowest scenarios are 1) two very large masks that are far appart or 2) a very
small mask completely inside a very large mask (around 20-30 seconds).
Take this command as an example:
scil_volume_distance_map.py brain_mask.nii.gz AF_L.nii.gz \
AF_L_to_brain_mask.nii.gz
We have a brain mask and a bundle, the second is 100% inside the first.
The output will be a distance map from the brain mask to the bundle.
If we take the bundle as the first input and the brain mask as the second,
The output will be a distance map from the bundle to the brain mask, which
will be all zeros (because the bundle is fully inside the brain mask).
If you want both distance maps at once, you can use the --symmetric_distance
option.
"""

import argparse
import logging

import nibabel as nib
import numpy as np

from scilpy.image.volume_operations import compute_distance_map
from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import add_overwrite_arg, add_verbose_arg, \
assert_headers_compatible, assert_inputs_exist, assert_outputs_exist


def _build_arg_parser():
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
p.add_argument('in_mask_1', metavar='IN_SOURCE',
help='Input file name, in nifti format.')
p.add_argument('in_mask_2', metavar='IN_TARGET',
help='Input file name, in nifti format.')
p.add_argument('out_distance', metavar='OUT_DISTANCE_MAP',
help='Input file name, in nifti format.')

p.add_argument('--symmetric_distance', action='store_true',
help='Compute the distance from mask 1 to mask 2 and the '
'distance from mask 2 to mask 1 and sum them up.')
add_verbose_arg(p)
add_overwrite_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

assert_inputs_exist(parser, [args.in_mask_1, args.in_mask_2])
assert_outputs_exist(parser, args, args.out_distance)
assert_headers_compatible(parser, [args.in_mask_1, args.in_mask_2])

img_1 = nib.load(args.in_mask_1)
img_2 = nib.load(args.in_mask_2)

mask_1 = get_data_as_mask(img_1)
mask_2 = get_data_as_mask(img_2)
logging.debug(f'Loaded two masks with {np.count_nonzero(mask_1)} and '
f'{np.count_nonzero(mask_2)} voxels')

# Compute distance map using KDTree
distance_map = compute_distance_map(mask_1, mask_2,
args.symmetric_distance)

out_img = nib.Nifti1Image(distance_map.astype(float), img_1.affine)
nib.save(out_img, args.out_distance)


if __name__ == "__main__":
main()
37 changes: 37 additions & 0 deletions scripts/tests/test_volume_distance_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import tempfile

import nibabel as nib

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['tractograms.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_volume_distance_map.py', '--help')
assert ret.success


def test_execution(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_mask_1 = os.path.join(SCILPY_HOME, 'tractograms',
'streamline_and_mask_operations',
'bundle_4_head_tail.nii.gz')
in_mask_2 = os.path.join(SCILPY_HOME, 'tractograms',
'streamline_and_mask_operations',
'bundle_4_center.nii.gz')
ret = script_runner.run('scil_volume_distance_map.py',
in_mask_1, in_mask_2,
'distance_map.nii.gz')

img = nib.load('distance_map.nii.gz')
data = img.get_fdata()
assert data[data > 0].mean() - 17.7777 < 0.0001
assert ret.success

0 comments on commit cc8e4f3

Please sign in to comment.