Skip to content

Commit

Permalink
Merge pull request #23 from mu40/warp
Browse files Browse the repository at this point in the history
Build on Yujing's work, derive `Warp` from `FramedImage`.
  • Loading branch information
jnolan14 authored Feb 29, 2024
2 parents fde466d + 76a815e commit 3c89ab8
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 412 deletions.
2 changes: 2 additions & 0 deletions surfa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .core import LabelRecoder

from .transform import Affine
from .transform import Warp
from .transform import Space
from .transform import ImageGeometry

Expand All @@ -29,6 +30,7 @@
from .io import load_affine
from .io import load_label_lookup
from .io import load_mesh
from .io import load_warp

from . import vis
from . import freesurfer
Expand Down
39 changes: 18 additions & 21 deletions surfa/image/framed.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def transform(self, trf=None, method='linear', rotation='corner', resample=True,
**Note on trf argument:** It accepts Affine/Warp object, or deformation fields (4D numpy array).
Pass trf argument as a numpy array is deprecated and will be removed in the future.
It is assumed that the deformation fields represent a *displacement* vector field in voxel space.
It is assumed that the deformation fields represent a *displacement* vector field in voxel space.
So under the hood, images will be moved into the space of the deformation field if the image geometries differ.
Parameters
Expand Down Expand Up @@ -420,28 +420,26 @@ def transform(self, trf=None, method='linear', rotation='corner', resample=True,
'contact andrew if you need this')

image = self.copy()
transformer = trf
if isinstance(transformer, Affine):
return transformer.transform(image, method, rotation, resample, fill)

if isinstance(trf, Affine):
return trf.transform(image, method, rotation, resample, fill)

from surfa.transform.warp import Warp
if isinstance(transformer, np.ndarray):
if isinstance(trf, np.ndarray):
warnings.warn('The option to pass \'trf\' argument as a numpy array is deprecated. '
'Pass \'trf\' as either an Affine or Warp object',
DeprecationWarning, stacklevel=2)

deformation = cast_image(trf, fallback_geom=self.geom)
image = image.resample_like(deformation)
transformer = Warp(data=trf,
source=image.geom,
target=deformation.geom,
format=Warp.Format.disp_crs)

if not isinstance(transformer, Warp):
raise ValueError("Pass \'trf\' as either an Affine or Warp object")

return transformer.transform(image, method, fill)

trf = Warp(data=trf,
source=image.geom,
target=deformation.geom,
format=Warp.Format.disp_crs)

if isinstance(trf, Warp):
return trf.transform(image, method, fill)

raise ValueError("Pass \'trf\' as either an Affine or Warp object")

def reorient(self, orientation, copy=True):
"""
Expand Down Expand Up @@ -707,20 +705,19 @@ def barycenters(self, labels=None, space='image'):
one, the barycenter array will be of shape $(F, L, D)$.
"""
if labels is not None:
#

if not np.issubdtype(self.dtype, np.integer):
raise ValueError('expected int dtype for computing barycenters on 1D, '
f'but got dtype {self.dtype}')
weights = np.ones(self.baseshape, dtype=np.float32)
centers = [center_of_mass(weights, self.framed_data[..., i], labels) for i in range(self.nframes)]
else:
#

centers = [center_of_mass(self.framed_data[..., i]) for i in range(self.nframes)]

#

centers = np.squeeze(centers)

#
space = cast_space(space)
if space != 'image':
centers = self.geom.affine('image', space)(centers)
Expand Down
12 changes: 6 additions & 6 deletions surfa/image/interp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0):
raise ValueError('interpolation requires an affine transform and/or displacement field')

if method not in ('linear', 'nearest'):
raise ValueError(f'interp method must be linear or nearest, but got {method}')
raise ValueError(f'interp method must be linear or nearest, got {method}')

if not isinstance(source, np.ndarray):
raise ValueError(f'source data must a numpy array, but got input of type {source.__class__.__name__}')
raise ValueError(f'source data must be a numpy array, got {source.__class__.__name__}')

if source.ndim != 4:
raise ValueError(f'source data must be 4D, but got input of shape {target_shape}')
raise ValueError(f'source data must be 4D, but got input of shape {source.shape}')

target_shape = tuple(target_shape)
if len(target_shape) != 3:
Expand All @@ -53,7 +53,7 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0):
use_affine = affine is not None
if use_affine:
if not isinstance(affine, np.ndarray):
raise ValueError(f'affine must a numpy array, but got input of type {source.__class__.__name__}')
raise ValueError(f'affine must be a numpy array, got {affine.__class__.__name__}')
if not np.array_equal(affine.shape, (4, 4)):
raise ValueError(f'affine must be 4x4, but got input of shape {affine.shape}')
# only supports float32 affines for now
Expand All @@ -63,9 +63,9 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0):
use_disp = disp is not None
if use_disp:
if not isinstance(disp, np.ndarray):
raise ValueError(f'source data must a numpy array, but got input of type {source.__class__.__name__}')
raise ValueError(f'source data must be a numpy array, got {disp.__class__.__name__}')
if not np.array_equal(disp.shape[:-1], target_shape):
raise ValueError(f'displacement field shape {disp.shape[:-1]} must match target shape {target_shape}')
raise ValueError(f'warp shape {disp.shape[:-1]} must match target shape {target_shape}')

# TODO: figure out what would cause this
if not disp.flags.c_contiguous and not disp.flags.f_contiguous:
Expand Down
1 change: 1 addition & 0 deletions surfa/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .framed import load_volume
from .framed import load_slice
from .framed import load_overlay
from .framed import load_warp
from .labels import load_label_lookup
from .mesh import load_mesh
88 changes: 57 additions & 31 deletions surfa/io/framed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from surfa import Volume
from surfa import Slice
from surfa import Overlay
from surfa import ImageGeometry
from surfa import Warp
from surfa.core.array import pad_vector_length
from surfa.core.framed import FramedArray
from surfa.core.framed import FramedArrayIntents
Expand All @@ -17,8 +17,8 @@
from surfa.io.utils import write_int
from surfa.io.utils import read_bytes
from surfa.io.utils import write_bytes
from surfa.io.utils import read_volgeom
from surfa.io.utils import write_volgeom
from surfa.io.utils import read_geom
from surfa.io.utils import write_geom
from surfa.io.utils import check_file_readability


Expand All @@ -31,8 +31,7 @@ def load_volume(filename, fmt=None):
filename : str
File path to read.
fmt : str, optional
Explicit file format. If None (default), the format is extrapolated
from the file extension.
Explicit file format. If None, we extrapolate from the file extension.
Returns
-------
Expand All @@ -51,8 +50,7 @@ def load_slice(filename, fmt=None):
filename : str
File path to read.
fmt : str, optional
Explicit file format. If None (default), the format is extrapolated
from the file extension.
Explicit file format. If None, we extrapolate from the file extension.
Returns
-------
Expand All @@ -71,8 +69,7 @@ def load_overlay(filename, fmt=None):
filename : str
File path to read.
fmt : str, optional
Explicit file format. If None (default), the format is extrapolated
from the file extension.
Explicit file format. If None, we extrapolate from the file extension.
Returns
-------
Expand All @@ -82,6 +79,25 @@ def load_overlay(filename, fmt=None):
return load_framed_array(filename=filename, atype=Overlay, fmt=fmt)


def load_warp(filename, fmt=None):
"""
Load an image `Warp` from a 3D or 4D array file.
Parameters
----------
filename : str
File path to read.
fmt : str, optional
Explicit file format. If None, we extrapolate from the file extension.
Returns
-------
Warp
Loaded warp.
"""
return load_framed_array(filename=filename, atype=Warp, fmt=fmt)


def load_framed_array(filename, atype, fmt=None):
"""
Generic loader for `FramedArray` objects.
Expand All @@ -93,13 +109,12 @@ def load_framed_array(filename, atype, fmt=None):
atype : class
Particular FramedArray subclass to read into.
fmt : str, optional
Forced file format. If None (default), file format is extrapolated
from extension.
Explicit file format. If None, we extrapolate from the file extension.
Returns
-------
FramedArray
Loaded framed array.
Loaded framed array.
"""
check_file_readability(filename)

Expand Down Expand Up @@ -166,11 +181,15 @@ def framed_array_from_4d(atype, data):
Returns
-------
FramedArray
Squeezed framed array.
Squeezed framed array.
"""
# this code is a bit ugly - it does the job but should probably be cleaned up
if atype == Volume:
return atype(data)
if atype == Warp:
if data.ndim == 4 and data.shape[-1] == 2:
data = data.squeeze(-2)
return atype(data)
# slice
if data.ndim == 3:
data = np.expand_dims(data, -1)
Expand Down Expand Up @@ -286,7 +305,7 @@ def load(self, filename, atype):
# it's also not required in the freesurfer definition, so we'll
# use the read() function directly in case end-of-file is reached
file.read(np.dtype('>f4').itemsize)

# update image-specific information
if isinstance(arr, FramedImage):
arr.geom.update(**geom_params)
Expand Down Expand Up @@ -323,17 +342,19 @@ def load(self, filename, atype):

# gcamorph src & trg geoms (mgz warp)
elif tag == fsio.tags.gcamorph_geom:
# read src vol geom
arr.metadata['gcamorph_volgeom_src'] = read_volgeom(file)
arr.source, valid, fname = read_geom(file)
arr.metadata['source-valid'] = valid
arr.metadata['source-fname'] = fname

arr.target, valid, fname = read_geom(file)
arr.metadata['target-valid'] = valid
arr.metadata['target-fname'] = fname

# read trg vol geom
arr.metadata['gcamorph_volgeom_trg'] = read_volgeom(file)

# gcamorph meta (mgz warp: int int float)
elif tag == fsio.tags.gcamorph_meta:
arr.metadata['warpfield_dtfmt'] = read_bytes(file, dtype='>i4')
arr.metadata['gcamorph_spacing'] = read_bytes(file, dtype='>i4')
arr.metadata['gcamorph_exp_k'] = read_bytes(file, dtype='>f4')
arr.format = read_bytes(file, dtype='>i4')
arr.metadata['spacing'] = read_bytes(file, dtype='>i4')
arr.metadata['exp_k'] = read_bytes(file, dtype='>f4')

# skip everything else
else:
Expand Down Expand Up @@ -438,18 +459,24 @@ def save(self, arr, filename, intent=FramedArrayIntents.mri):
write_bytes(file, arr.metadata.get('field-strength', 0.0), '>f4')

# gcamorph geom and gcamorph meta for mgz warp
if (intent == FramedArrayIntents.warpmap):
if intent == FramedArrayIntents.warpmap:
# gcamorph src & trg geoms (mgz warp)
fsio.write_tag(file, fsio.tags.gcamorph_geom)
write_volgeom(file, arr.metadata['gcamorph_volgeom_src'])
write_volgeom(file, arr.metadata['gcamorph_volgeom_trg'])

write_geom(file,
geom=arr.source,
valid=arr.metadata.get('source-valid', True),
fname=arr.metadata.get('source-fname', ''))
write_geom(file,
geom=arr.target,
valid=arr.metadata.get('target-valid', True),
fname=arr.metadata.get('target-fname', ''))

# gcamorph meta (mgz warp: int int float)
fsio.write_tag(file, fsio.tags.gcamorph_meta, 12)
write_bytes(file, arr.metadata.get('warpfield_dtfmt', 0), dtype='>i4')
write_bytes(file, arr.metadata.get('gcamorph_spacing', 0.0), dtype='>i4')
write_bytes(file, arr.metadata.get('gcamorph_exp_k', 0.0), dtype='>f4')
write_bytes(file, arr.format, dtype='>i4')
write_bytes(file, arr.metadata.get('spacing', 0), dtype='>i4')
write_bytes(file, arr.metadata.get('exp_k', 0.0), dtype='>f4')

# write history tags
for hist in arr.metadata.get('history', []):
fsio.write_tag(file, fsio.tags.history, len(hist))
Expand Down Expand Up @@ -679,7 +706,6 @@ def save(self, arr, filename):
if arr.labels is None:
raise ValueError('overlay must have label lookup if saving as annotation')

#
unknown_mask = arr.data < 0

# make sure all indices exist in the label lookup
Expand Down
Loading

0 comments on commit 3c89ab8

Please sign in to comment.