From 39dbbfa7988ba96acf051c728743f64921f00ae1 Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 07:59:10 -0500 Subject: [PATCH 1/8] Access sf.transform.Warp as sf.Warp. --- surfa/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/surfa/__init__.py b/surfa/__init__.py index 8012380..eb050fa 100644 --- a/surfa/__init__.py +++ b/surfa/__init__.py @@ -13,6 +13,7 @@ from .core import LabelRecoder from .transform import Affine +from .transform import Warp from .transform import Space from .transform import ImageGeometry From ee81b653875200b610e8473ad1cb13d344efadae Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 08:06:00 -0500 Subject: [PATCH 2/8] Fix interpolation error messages. --- surfa/image/interp.pyx | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/surfa/image/interp.pyx b/surfa/image/interp.pyx index 718869c..2d9955a 100644 --- a/surfa/image/interp.pyx +++ b/surfa/image/interp.pyx @@ -37,13 +37,13 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0): raise ValueError('interpolation requires an affine transform and/or displacement field') if method not in ('linear', 'nearest'): - raise ValueError(f'interp method must be linear or nearest, but got {method}') + raise ValueError(f'interp method must be linear or nearest, got {method}') if not isinstance(source, np.ndarray): - raise ValueError(f'source data must a numpy array, but got input of type {source.__class__.__name__}') + raise ValueError(f'source data must be a numpy array, got {source.__class__.__name__}') if source.ndim != 4: - raise ValueError(f'source data must be 4D, but got input of shape {target_shape}') + raise ValueError(f'source data must be 4D, but got input of shape {source.shape}') target_shape = tuple(target_shape) if len(target_shape) != 3: @@ -53,7 +53,7 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0): use_affine = affine is not None if use_affine: if not isinstance(affine, np.ndarray): - raise ValueError(f'affine must a numpy array, but got input of type {source.__class__.__name__}') + raise ValueError(f'affine must be a numpy array, got {affine.__class__.__name__}') if not np.array_equal(affine.shape, (4, 4)): raise ValueError(f'affine must be 4x4, but got input of shape {affine.shape}') # only supports float32 affines for now @@ -63,9 +63,9 @@ def interpolate(source, target_shape, method, affine=None, disp=None, fill=0): use_disp = disp is not None if use_disp: if not isinstance(disp, np.ndarray): - raise ValueError(f'source data must a numpy array, but got input of type {source.__class__.__name__}') + raise ValueError(f'source data must be a numpy array, got {disp.__class__.__name__}') if not np.array_equal(disp.shape[:-1], target_shape): - raise ValueError(f'displacement field shape {disp.shape[:-1]} must match target shape {target_shape}') + raise ValueError(f'warp shape {disp.shape[:-1]} must match target shape {target_shape}') # TODO: figure out what would cause this if not disp.flags.c_contiguous and not disp.flags.f_contiguous: From 395713f31fbca182c66fcc09ad7673d4d3a476f2 Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 14:16:44 -0500 Subject: [PATCH 3/8] Derive Warp from FramedImage to inherit functionality. --- surfa/transform/warp.py | 166 ++++++++++++++++------------------------ 1 file changed, 67 insertions(+), 99 deletions(-) diff --git a/surfa/transform/warp.py b/surfa/transform/warp.py index 00f0c8d..727eecb 100644 --- a/surfa/transform/warp.py +++ b/surfa/transform/warp.py @@ -3,10 +3,12 @@ import numpy as np import surfa as sf +from surfa import transform +from surfa.image import FramedImage from surfa.image.interp import interpolate -class Warp: +class Warp(FramedImage): class Format: """ @@ -26,12 +28,8 @@ class Format: # # constructor - def __init__(self, data=None, source=None, target=None, - spacing=1, exp_k=0.0, format=Format.abs_crs): + def __init__(self, data, source=None, target=None, format=Format.abs_crs, **kwargs): """ - Class constructor. - When it is invoked without any parameters, load(mgzwarp) call is needed after the object is created. - Class variables: _data: deformation field, 4D numpy array (c, r, s, 3) The _data (width x height x depth x nframes) is indexed by atlas CRS. @@ -44,8 +42,6 @@ def __init__(self, data=None, source=None, target=None, _format: Warp.Format _source: ImageGeometry, source image _target: ImageGeometry, target image - _spacing: int (this is from m3z, not sure if it is really needed) - _exp_k: double (this is from m3z, not sure if it is really needed) Parameters ---------- @@ -57,21 +53,13 @@ def __init__(self, data=None, source=None, target=None, target geometry format : Format deformation field format - spacing : int - exp_k : double + **kwargs + Extra arguments provided to the FramedImage superclass. """ - - if (data is None and source is None and target is None): - return - elif (data is not None and source is not None and target is not None): - self._data = data - self._format = format - self._source = source - self._target = target - self._spacing = spacing - self._exp_k = exp_k - else: - raise ValueError('Warp constructor: input parameters error') + self.format = format + self.source = source + basedim = len(data.shape) - 1 + super().__init__(basedim, data, geometry=target, **kwargs) # @@ -82,10 +70,26 @@ def __call__(self, *args, **kwargs): """ return self.transform(*args, **kwargs) - + def new(self, data, source=None, target=None, format=None): + """ + Return a new instance of the warp with updated data. Geometries and format are + preserved unless specified. + """ + if source is None: + source = self.source + + if target is None: + target = self.target + + if format is None: + format = self.format + + return self.__class__(data, source, target, metadata=self.metadata) + # # Read input mgz warp file - def load(self, filename): + @staticmethod + def load(filename): """ Read input mgz warp file, set up deformation field, source/target geometry @@ -94,34 +98,23 @@ def load(self, filename): filename : string input mgz warp file """ - - mgzwarp = sf.load_volume(filename) + warp = sf.load_volume(filename) - # check if mgzwarp is a volume - if (not isinstance(mgzwarp, sf.image.framed.Volume)): - raise ValueError('Warp::load() - input is not a Volume') - - # check if input is a mgzwarp (intent FramedArrayIntents.warpmap) - if (mgzwarp.metadata['intent'] != sf.core.framed.FramedArrayIntents.warpmap): - raise ValueError('Warp::load() - input is not a mgzwarp Volume') - - self._data = mgzwarp.data - self._format = mgzwarp.metadata['warpfield_dtfmt'] - - # create ImageGeometry object self._source from mgzwarp.metadata['gcamorph_volgeom_src'] - self._source = sf.transform.geometry.volgeom_dict2image_geometry(mgzwarp.metadata['gcamorph_volgeom_src']) + if not isinstance(warp, sf.image.framed.Volume): + raise ValueError('input is not a Volume') - # create ImageGeometry object self._target from mgzwarp.metadata['gcamorph_volgeom_trg'] - self._target = sf.transform.geometry.volgeom_dict2image_geometry(mgzwarp.metadata['gcamorph_volgeom_trg']) + if warp.metadata['intent'] != sf.core.framed.FramedArrayIntents.warpmap: + raise ValueError('input is not a warp Volume') - # not sure if these two are necessary - self._spacing = mgzwarp.metadata['gcamorph_spacing'] - self._exp_k = mgzwarp.metadata['gcamorph_exp_k'] + format = warp.metadata['warpfield_dtfmt'] + source = transform.volgeom_dict2image_geometry(warp.metadata['gcamorph_volgeom_src']) + target = transform.volgeom_dict2image_geometry(warp.metadata['gcamorph_volgeom_trg']) + return super().__class__(warp.data, source, target, format, metadata=warp.metadata) # # output _data as mgz warp - def save(self, filename): + def save(self, filename, fmt=None): """ Output _data as mgz warp volume @@ -129,22 +122,13 @@ def save(self, filename): ---------- filename : string output mgz warp file + fmt : str + Optional file format to force. """ - - # create a volume from _data - mgzwarp = sf.image.framed.cast_image(self._data, fallback_geom=self._target) - - # set metadata - mgzwarp.metadata['intent'] = sf.core.framed.FramedArrayIntents.warpmap - mgzwarp.metadata['gcamorph_volgeom_src'] = sf.transform.geometry.image_geometry2volgeom_dict(self._source) - mgzwarp.metadata['gcamorph_volgeom_trg'] = sf.transform.geometry.image_geometry2volgeom_dict(self._target) - - mgzwarp.metadata['warpfield_dtfmt'] = self._format - mgzwarp.metadata['gcamorph_spacing'] = self._spacing - mgzwarp.metadata['gcamorph_exp_k'] = self._exp_k - - # output the volume as mgz warp - mgzwarp.save(filename, None, sf.core.framed.FramedArrayIntents.warpmap) + self.metadata['warpfield_dtfmt'] = self.format + self.metadata['gcamorph_volgeom_src'] = transform.image_geometry2volgeom_dict(self.source) + self.metadata['gcamorph_volgeom_trg'] = transform.image_geometry2volgeom_dict(self.target) + super().save(filename, fmt=fmt, intent=sf.core.framed.FramedArrayIntents.warpmap) # @@ -169,9 +153,9 @@ def convert(self, newformat=Format.abs_crs): return self._data # cast vox2world.matrix and world2vox.matrix to float32 - src_vox2ras = self._source.vox2world.matrix.astype('float32') - src_ras2vox = self._source.world2vox.matrix.astype('float32') - trg_vox2ras = self._target.vox2world.matrix.astype('float32') + src_vox2ras = self.source.vox2world.matrix.astype('float32') + src_ras2vox = self.source.world2vox.matrix.astype('float32') + trg_vox2ras = self.target.vox2world.matrix.astype('float32') # reshape self._data to (3, n) array, n = c * s * r transform = self._data.astype('float32') @@ -286,64 +270,48 @@ def transform(self, image, method='linear', fill=0): raise ValueError(f'deformation ({self._data.shape[-1]}D) does not match ' f'dimensionality of image ({image.basedim}D)') - """ - # get the image in the space of the deformation - #source_data = image.resample_like(self._target).framed_data - """ - source_data = image.framed_data # convert deformation field to disp_crs deformationfield = self.convert(self.Format.disp_crs) # do the interpolation, the function assumes disp_crs deformation field - interpolated = interpolate(source=source_data, - target_shape=self._target.shape, + interpolated = interpolate(source=image.framed_data, + target_shape=self.geom.shape, method=method, disp=deformationfield, fill=fill) - - deformed = image.new(interpolated, self._target) - - return deformed - - # - # _data getter and setter - @property - def data(self): - return self._data - @data.setter - def data(self, warp): - self._data = warp + return self.new(interpolated) # # _format getter and setter @property - def deformationformat(self): + def format(self): return self._format - @deformationformat.setter - def deformationformat(self, format): - self._format = format + @format.setter + def format(self, format): + self._format = format - # - # _source getter and setter @property def source(self): + """ + Source (or moving) image geometry. + """ return self._source - @source.setter - def source(self, geom): - self._source = geom + @source.setter + def source(self, value): + self._source = transform.cast_image_geometry(value, copy=True) - # - # _target getter and setter @property def target(self): - return self._target - @target.setter - def target(self, geom): - self._target = geom - + """ + Target (or fixed) image geometry. + """ + return self.geom + @target.setter + def target(self, value): + self.geom = value From 0828c17eb8e8aab8620e69f81ed3c8e6ea3a431d Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 16:15:02 -0500 Subject: [PATCH 4/8] Move Warp.load to sf.load_warp, remove duplicate metadata. --- surfa/__init__.py | 1 + surfa/io/__init__.py | 1 + surfa/io/framed.py | 69 +++++++++++++++------ surfa/io/utils.py | 117 +++++++++++++++++------------------- surfa/transform/__init__.py | 2 - surfa/transform/geometry.py | 72 ---------------------- surfa/transform/warp.py | 32 +--------- 7 files changed, 109 insertions(+), 185 deletions(-) diff --git a/surfa/__init__.py b/surfa/__init__.py index eb050fa..568b348 100644 --- a/surfa/__init__.py +++ b/surfa/__init__.py @@ -30,6 +30,7 @@ from .io import load_affine from .io import load_label_lookup from .io import load_mesh +from .io import load_warp from . import vis from . import freesurfer diff --git a/surfa/io/__init__.py b/surfa/io/__init__.py index 82b07e2..568b785 100644 --- a/surfa/io/__init__.py +++ b/surfa/io/__init__.py @@ -3,5 +3,6 @@ from .framed import load_volume from .framed import load_slice from .framed import load_overlay +from .framed import load_warp from .labels import load_label_lookup from .mesh import load_mesh diff --git a/surfa/io/framed.py b/surfa/io/framed.py index 19bd311..5e6a1a2 100644 --- a/surfa/io/framed.py +++ b/surfa/io/framed.py @@ -6,7 +6,7 @@ from surfa import Volume from surfa import Slice from surfa import Overlay -from surfa import ImageGeometry +from surfa import Warp from surfa.core.array import pad_vector_length from surfa.core.framed import FramedArray from surfa.core.framed import FramedArrayIntents @@ -17,8 +17,8 @@ from surfa.io.utils import write_int from surfa.io.utils import read_bytes from surfa.io.utils import write_bytes -from surfa.io.utils import read_volgeom -from surfa.io.utils import write_volgeom +from surfa.io.utils import read_geom +from surfa.io.utils import write_geom from surfa.io.utils import check_file_readability @@ -82,6 +82,25 @@ def load_overlay(filename, fmt=None): return load_framed_array(filename=filename, atype=Overlay, fmt=fmt) +def load_warp(filename, fmt=None): + """ + Load an image `Warp` from a 3D or 4D array file. + + Parameters + ---------- + filename : str + File path to read. + fmt : str, optional + Explicit file format. If None, we extrapolate from the file extension. + + Returns + ------- + Warp + Loaded warp. + """ + return load_framed_array(filename=filename, atype=Warp, fmt=fmt) + + def load_framed_array(filename, atype, fmt=None): """ Generic loader for `FramedArray` objects. @@ -171,6 +190,10 @@ def framed_array_from_4d(atype, data): # this code is a bit ugly - it does the job but should probably be cleaned up if atype == Volume: return atype(data) + if atype == Warp: + if data.ndim == 4 and data.shape[-1] == 2: + data = data.squeeze(-2) + return atype(data) # slice if data.ndim == 3: data = np.expand_dims(data, -1) @@ -323,17 +346,19 @@ def load(self, filename, atype): # gcamorph src & trg geoms (mgz warp) elif tag == fsio.tags.gcamorph_geom: - # read src vol geom - arr.metadata['gcamorph_volgeom_src'] = read_volgeom(file) + arr.source, valid, fname = read_geom(file) + arr.metadata['source-valid'] = valid + arr.metadata['source-fname'] = fname + + arr.target, valid, fname = read_geom(file) + arr.metadata['target-valid'] = valid + arr.metadata['target-fname'] = fname - # read trg vol geom - arr.metadata['gcamorph_volgeom_trg'] = read_volgeom(file) - # gcamorph meta (mgz warp: int int float) elif tag == fsio.tags.gcamorph_meta: - arr.metadata['warpfield_dtfmt'] = read_bytes(file, dtype='>i4') - arr.metadata['gcamorph_spacing'] = read_bytes(file, dtype='>i4') - arr.metadata['gcamorph_exp_k'] = read_bytes(file, dtype='>f4') + arr.format = read_bytes(file, dtype='>i4') + arr.metadata['spacing'] = read_bytes(file, dtype='>i4') + arr.metadata['exp_k'] = read_bytes(file, dtype='>f4') # skip everything else else: @@ -438,18 +463,24 @@ def save(self, arr, filename, intent=FramedArrayIntents.mri): write_bytes(file, arr.metadata.get('field-strength', 0.0), '>f4') # gcamorph geom and gcamorph meta for mgz warp - if (intent == FramedArrayIntents.warpmap): + if intent == FramedArrayIntents.warpmap: # gcamorph src & trg geoms (mgz warp) fsio.write_tag(file, fsio.tags.gcamorph_geom) - write_volgeom(file, arr.metadata['gcamorph_volgeom_src']) - write_volgeom(file, arr.metadata['gcamorph_volgeom_trg']) - + write_geom(file, + geom=arr.source, + valid=arr.metadata.get('source-valid', True), + fname=arr.metadata.get('source-fname', '')) + write_geom(file, + geom=arr.target, + valid=arr.metadata.get('target-valid', True), + fname=arr.metadata.get('target-fname', '')) + # gcamorph meta (mgz warp: int int float) fsio.write_tag(file, fsio.tags.gcamorph_meta, 12) - write_bytes(file, arr.metadata.get('warpfield_dtfmt', 0), dtype='>i4') - write_bytes(file, arr.metadata.get('gcamorph_spacing', 0.0), dtype='>i4') - write_bytes(file, arr.metadata.get('gcamorph_exp_k', 0.0), dtype='>f4') - + write_bytes(file, arr.format, dtype='>i4') + write_bytes(file, arr.metadata.get('spacing', 0), dtype='>i4') + write_bytes(file, arr.metadata.get('exp_k', 0.0), dtype='>f4') + # write history tags for hist in arr.metadata.get('history', []): fsio.write_tag(file, fsio.tags.history, len(hist)) diff --git a/surfa/io/utils.py b/surfa/io/utils.py index 78d42e1..8365bd4 100644 --- a/surfa/io/utils.py +++ b/surfa/io/utils.py @@ -1,6 +1,8 @@ import os import numpy as np +from surfa import ImageGeometry + def check_file_readability(filename): """ @@ -100,66 +102,57 @@ def write_bytes(file, value, dtype): file.write(np.asarray(value).astype(dtype, copy=False).tobytes()) -# read VOL_GEOM -# also see VOL_GEOM.read() in mri.h -def read_volgeom(file): - volgeom = dict( - valid = read_bytes(file, '>i4', 1), - - width = read_bytes(file, '>i4', 1), - height = read_bytes(file, '>i4', 1), - depth = read_bytes(file, '>i4', 1), - - xsize = read_bytes(file, '>f4', 1), - ysize = read_bytes(file, '>f4', 1), - zsize = read_bytes(file, '>f4', 1), - - x_r = read_bytes(file, '>f4', 1), - x_a = read_bytes(file, '>f4', 1), - x_s = read_bytes(file, '>f4', 1), - y_r = read_bytes(file, '>f4', 1), - y_a = read_bytes(file, '>f4', 1), - y_s = read_bytes(file, '>f4', 1), - z_r = read_bytes(file, '>f4', 1), - z_a = read_bytes(file, '>f4', 1), - z_s = read_bytes(file, '>f4', 1), - - c_r = read_bytes(file, '>f4', 1), - c_a = read_bytes(file, '>f4', 1), - c_s = read_bytes(file, '>f4', 1), - - fname = file.read(512).decode('utf-8').rstrip('\x00') - ) - return volgeom - - -# output VOL_GEOM -# also see VOL_GEOM.write() in mri.h -def write_volgeom(file, volgeom): - write_bytes(file, volgeom['valid'], '>i4') - - write_bytes(file, volgeom['width'], '>i4') - write_bytes(file, volgeom['height'], '>i4') - write_bytes(file, volgeom['depth'], '>i4') - - write_bytes(file, volgeom['xsize'], '>f4') - write_bytes(file, volgeom['ysize'], '>f4') - write_bytes(file, volgeom['zsize'], '>f4') - - write_bytes(file, volgeom['x_r'], '>f4') - write_bytes(file, volgeom['x_a'], '>f4') - write_bytes(file, volgeom['x_s'], '>f4') - write_bytes(file, volgeom['y_r'], '>f4') - write_bytes(file, volgeom['y_a'], '>f4') - write_bytes(file, volgeom['y_s'], '>f4') - write_bytes(file, volgeom['z_r'], '>f4') - write_bytes(file, volgeom['z_a'], '>f4') - write_bytes(file, volgeom['z_s'], '>f4') - - write_bytes(file, volgeom['c_r'], '>f4') - write_bytes(file, volgeom['c_a'], '>f4') - write_bytes(file, volgeom['c_s'], '>f4') - - # output 512 bytes padded with '/x00' - file.write(volgeom['fname'].ljust(512, '\x00').encode('utf-8')) +def read_geom(file): + """ + Read an image geometry from a binary file buffer. See VOL_GEOM.read() in mri.h. + + Parameters + ---------- + file : BufferedReader + Opened file buffer. + + Returns + ------- + ImageGeometry + Image geometry. + bool + True if the geometry is valid. + str + File name associated with the geometry. + """ + valid = bool(read_bytes(file, '>i4', 1)) + geom = ImageGeometry( + shape=read_bytes(file, '>i4', 3).astype(int), + voxsize=read_bytes(file, '>f4', 3), + rotation=read_bytes(file, '>f4', 9).reshape((3, 3), order='F'), + center=read_bytes(file, '>f4', 3), + ) + fname = file.read(512).decode('utf-8').rstrip('\x00') + return geom, valid, fname + + +def write_geom(file, geom, valid=True, fname=''): + """ + Write an image geometry to a binary file buffer. See VOL_GEOM.write() in mri.h. + + Parameters + ---------- + file : BufferedReader + Opened file buffer. + geom : ImageGeometry + Image geometry. + valid : bool + True if the geometry is valid. + fname : str + File name associated with the geometry. + """ + write_bytes(file, valid, '>i4') + + voxsize, rotation, center = geom.shearless_components() + write_bytes(file, geom.shape, '>i4') + write_bytes(file, voxsize, '>f4') + write_bytes(file, np.ravel(rotation, order='F'), '>f4') + write_bytes(file, center, '>f4') + # right-pad with '/x00' to 512 bytes + file.write(fname[:512].ljust(512, '\x00').encode('utf-8')) diff --git a/surfa/transform/__init__.py b/surfa/transform/__init__.py index 4534062..10061b5 100644 --- a/surfa/transform/__init__.py +++ b/surfa/transform/__init__.py @@ -11,7 +11,5 @@ from .geometry import ImageGeometry from .geometry import cast_image_geometry from .geometry import image_geometry_equal -from .geometry import image_geometry2volgeom_dict -from .geometry import volgeom_dict2image_geometry from .warp import Warp diff --git a/surfa/transform/geometry.py b/surfa/transform/geometry.py index a1c1f5e..15db214 100644 --- a/surfa/transform/geometry.py +++ b/surfa/transform/geometry.py @@ -544,75 +544,3 @@ def image_geometry_equal(a, b, tol=0.0): return False return True - - -# -# create volgeom dict from an ImageGeometry object -def image_geometry2volgeom_dict(imagegeometryObj): - """ - Create vol_geom dict from an ImageGeometry object - - Parameters - ---------- - imagegeometryObj : ImageGeometry - input ImageGeometry object - - Returns - ------- - volgeom : dict - vol_geom dict - """ - - volgeom = dict( - valid = 1, - width = imagegeometryObj.shape[0], - height = imagegeometryObj.shape[1], - depth = imagegeometryObj.shape[2], - - xsize = imagegeometryObj.voxsize[0], - ysize = imagegeometryObj.voxsize[1], - zsize = imagegeometryObj.voxsize[2], - - x_r = imagegeometryObj.rotation[:,0][0], - x_a = imagegeometryObj.rotation[:,0][1], - x_s = imagegeometryObj.rotation[:,0][2], - y_r = imagegeometryObj.rotation[:,1][0], - y_a = imagegeometryObj.rotation[:,1][1], - y_s = imagegeometryObj.rotation[:,1][2], - z_r = imagegeometryObj.rotation[:,2][0], - z_a = imagegeometryObj.rotation[:,2][1], - z_s = imagegeometryObj.rotation[:,2][2], - - c_r = imagegeometryObj.center[0], - c_a = imagegeometryObj.center[1], - c_s = imagegeometryObj.center[2], - - fname = '' - ) - return volgeom - - -# -# create an ImageGeometry object from volgeom dict -def volgeom_dict2image_geometry(volgeom): - """ - Create vol_geom dict from an ImageGeometry object - - Parameters - ---------- - volgeom : dict - volgeom dict - - Returns - ------- - imagegeometryObj : ImageGeometry - input ImageGeometry object - """ - - imagegeom = ImageGeometry( - shape = np.array([volgeom['width'], volgeom['height'], volgeom['depth']], dtype=int), - center = np.array([volgeom['c_r'], volgeom['c_a'], volgeom['c_s']]), - rotation = np.array([[volgeom['x_r'], volgeom['y_r'], volgeom['z_r']], [volgeom['x_a'], volgeom['y_a'], volgeom['z_a']], [volgeom['x_s'], volgeom['y_s'], volgeom['z_s']]]), - voxsize = np.array([volgeom['xsize'], volgeom['ysize'], volgeom['zsize']]) - ) - return imagegeom diff --git a/surfa/transform/warp.py b/surfa/transform/warp.py index 727eecb..6638c6d 100644 --- a/surfa/transform/warp.py +++ b/surfa/transform/warp.py @@ -86,48 +86,20 @@ def new(self, data, source=None, target=None, format=None): return self.__class__(data, source, target, metadata=self.metadata) - # - # Read input mgz warp file - @staticmethod - def load(filename): - """ - Read input mgz warp file, set up deformation field, source/target geometry - - Parameters - ---------- - filename : string - input mgz warp file - """ - warp = sf.load_volume(filename) - - if not isinstance(warp, sf.image.framed.Volume): - raise ValueError('input is not a Volume') - - if warp.metadata['intent'] != sf.core.framed.FramedArrayIntents.warpmap: - raise ValueError('input is not a warp Volume') - - format = warp.metadata['warpfield_dtfmt'] - source = transform.volgeom_dict2image_geometry(warp.metadata['gcamorph_volgeom_src']) - target = transform.volgeom_dict2image_geometry(warp.metadata['gcamorph_volgeom_trg']) - return super().__class__(warp.data, source, target, format, metadata=warp.metadata) - # # output _data as mgz warp def save(self, filename, fmt=None): """ - Output _data as mgz warp volume + Write warp to file. Parameters ---------- filename : string - output mgz warp file + Filename to write to. fmt : str Optional file format to force. """ - self.metadata['warpfield_dtfmt'] = self.format - self.metadata['gcamorph_volgeom_src'] = transform.image_geometry2volgeom_dict(self.source) - self.metadata['gcamorph_volgeom_trg'] = transform.image_geometry2volgeom_dict(self.target) super().save(filename, fmt=fmt, intent=sf.core.framed.FramedArrayIntents.warpmap) From d521dbae61ceb16d4fb369c5de5a4df4835f4c7c Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 18:21:32 -0500 Subject: [PATCH 5/8] Revert bug that attempts to create warp instead of image. --- surfa/transform/warp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surfa/transform/warp.py b/surfa/transform/warp.py index 6638c6d..19e41cb 100644 --- a/surfa/transform/warp.py +++ b/surfa/transform/warp.py @@ -253,7 +253,7 @@ def transform(self, image, method='linear', fill=0): disp=deformationfield, fill=fill) - return self.new(interpolated) + return image.new(interpolated, geometry=self.target) # From 6ebc11ae20fd060b373fc439b0cad83052b06efb Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 18:36:59 -0500 Subject: [PATCH 6/8] Clean up white space, a few comments, etc. --- surfa/image/framed.py | 39 +++++------ surfa/io/framed.py | 17 ++--- surfa/transform/affine.py | 45 ++++++------ surfa/transform/warp.py | 141 +++++++++++++++++--------------------- 4 files changed, 105 insertions(+), 137 deletions(-) diff --git a/surfa/image/framed.py b/surfa/image/framed.py index 36f9a6d..396c6d3 100644 --- a/surfa/image/framed.py +++ b/surfa/image/framed.py @@ -392,7 +392,7 @@ def transform(self, trf=None, method='linear', rotation='corner', resample=True, **Note on trf argument:** It accepts Affine/Warp object, or deformation fields (4D numpy array). Pass trf argument as a numpy array is deprecated and will be removed in the future. - It is assumed that the deformation fields represent a *displacement* vector field in voxel space. + It is assumed that the deformation fields represent a *displacement* vector field in voxel space. So under the hood, images will be moved into the space of the deformation field if the image geometries differ. Parameters @@ -420,28 +420,26 @@ def transform(self, trf=None, method='linear', rotation='corner', resample=True, 'contact andrew if you need this') image = self.copy() - transformer = trf - if isinstance(transformer, Affine): - return transformer.transform(image, method, rotation, resample, fill) - + if isinstance(trf, Affine): + return trf.transform(image, method, rotation, resample, fill) + from surfa.transform.warp import Warp - if isinstance(transformer, np.ndarray): + if isinstance(trf, np.ndarray): warnings.warn('The option to pass \'trf\' argument as a numpy array is deprecated. ' 'Pass \'trf\' as either an Affine or Warp object', DeprecationWarning, stacklevel=2) - + deformation = cast_image(trf, fallback_geom=self.geom) image = image.resample_like(deformation) - transformer = Warp(data=trf, - source=image.geom, - target=deformation.geom, - format=Warp.Format.disp_crs) - - if not isinstance(transformer, Warp): - raise ValueError("Pass \'trf\' as either an Affine or Warp object") - - return transformer.transform(image, method, fill) - + trf = Warp(data=trf, + source=image.geom, + target=deformation.geom, + format=Warp.Format.disp_crs) + + if isinstance(trf, Warp): + return trf.transform(image, method, fill) + + raise ValueError("Pass \'trf\' as either an Affine or Warp object") def reorient(self, orientation, copy=True): """ @@ -707,20 +705,19 @@ def barycenters(self, labels=None, space='image'): one, the barycenter array will be of shape $(F, L, D)$. """ if labels is not None: - # + if not np.issubdtype(self.dtype, np.integer): raise ValueError('expected int dtype for computing barycenters on 1D, ' f'but got dtype {self.dtype}') weights = np.ones(self.baseshape, dtype=np.float32) centers = [center_of_mass(weights, self.framed_data[..., i], labels) for i in range(self.nframes)] else: - # + centers = [center_of_mass(self.framed_data[..., i]) for i in range(self.nframes)] - # + centers = np.squeeze(centers) - # space = cast_space(space) if space != 'image': centers = self.geom.affine('image', space)(centers) diff --git a/surfa/io/framed.py b/surfa/io/framed.py index 5e6a1a2..e8866d3 100644 --- a/surfa/io/framed.py +++ b/surfa/io/framed.py @@ -31,8 +31,7 @@ def load_volume(filename, fmt=None): filename : str File path to read. fmt : str, optional - Explicit file format. If None (default), the format is extrapolated - from the file extension. + Explicit file format. If None, we extrapolate from the file extension. Returns ------- @@ -51,8 +50,7 @@ def load_slice(filename, fmt=None): filename : str File path to read. fmt : str, optional - Explicit file format. If None (default), the format is extrapolated - from the file extension. + Explicit file format. If None, we extrapolate from the file extension. Returns ------- @@ -71,8 +69,7 @@ def load_overlay(filename, fmt=None): filename : str File path to read. fmt : str, optional - Explicit file format. If None (default), the format is extrapolated - from the file extension. + Explicit file format. If None, we extrapolate from the file extension. Returns ------- @@ -112,13 +109,12 @@ def load_framed_array(filename, atype, fmt=None): atype : class Particular FramedArray subclass to read into. fmt : str, optional - Forced file format. If None (default), file format is extrapolated - from extension. + Explicit file format. If None, we extrapolate from the file extension. Returns ------- FramedArray - Loaded framed array. + Loaded framed array. """ check_file_readability(filename) @@ -309,7 +305,7 @@ def load(self, filename, atype): # it's also not required in the freesurfer definition, so we'll # use the read() function directly in case end-of-file is reached file.read(np.dtype('>f4').itemsize) - + # update image-specific information if isinstance(arr, FramedImage): arr.geom.update(**geom_params) @@ -710,7 +706,6 @@ def save(self, arr, filename): if arr.labels is None: raise ValueError('overlay must have label lookup if saving as annotation') - # unknown_mask = arr.data < 0 # make sure all indices exist in the label lookup diff --git a/surfa/transform/affine.py b/surfa/transform/affine.py index 08f03a2..bdb2cad 100644 --- a/surfa/transform/affine.py +++ b/surfa/transform/affine.py @@ -34,7 +34,7 @@ def __init__(self, matrix, source=None, target=None, space=None): self._target = None # set the actual values - self.matrix = matrix + self.matrix = matrix self.space = space self.source = source self.target = target @@ -69,7 +69,7 @@ def matrix(self): The (N, N) affine matrix array. """ return self._matrix - + @matrix.setter def matrix(self, mat): # check writeable @@ -118,7 +118,7 @@ def space(self): Coordinate space of the transform. """ return self._space - + @space.setter def space(self, value): self._check_writeability() @@ -130,7 +130,7 @@ def source(self): Source (or moving) image geometry. """ return self._source - + @source.setter def source(self, value): self._check_writeability() @@ -142,7 +142,7 @@ def target(self): Target (or fixed) image geometry. """ return self._target - + @target.setter def target(self, value): self._check_writeability() @@ -207,10 +207,10 @@ def transform(self, data, method='linear', rotation='corner', resample=True, fil Parameters ---------- - data : input data to transform - N-D point values, or image Volume + data : (..., N) float or Volume + Input coordinates or image to transform. method : {'linear', 'nearest'} - Image interpolation method if resample is enabled. + Image interpolation method if `resample` is enabled. rotation : {'corner', 'center'} Apply affine with rotation axis at the image corner or center. resample : bool @@ -224,23 +224,21 @@ def transform(self, data, method='linear', rotation='corner', resample=True, fil Returns ------- - (..., N) float - Transformed N-D point array if (input data is N-D point) - - transformed : Volume - Transformed image if (input data is an image Volume) + (..., N) float or Volume + Transformed N-D point array if the input is N-D point data, or transformed image if the + input is an image. """ if points is not None: data = points warnings.warn('The \'points\' argument to transform() is deprecated. Just use ' 'the first positional argument to specify set of points or an image to transform.', DeprecationWarning, stacklevel=2) - + # a common mistake is to use this function for transforming a mesh, # so run this check to help the user out a bit if ismesh(data): raise ValueError('use mesh.transform(affine) to apply an affine to a mesh') - + if isimage(data): return self.__transform_image(data, method, rotation, resample, fill) @@ -395,22 +393,20 @@ def __transform_image(self, image, method='linear', rotation='corner', resample= Parameters ---------- image : Volume - input image Volume + Input image Volume. Returns ------- - transformed : Volume - transformed image + Volume + Transformed image. """ - if image.basedim == 2: raise NotImplementedError('Affine.transform() is not yet implemented for 2D data') - + affine = self.copy() - + # if not resampling, just change the image vox2world matrix and return if not resample: - # TODO: if affine is missing geometry info, do we assume that the affine # is in world space or voxel space? let's do world for now if affine.source is not None and affine.target is not None: @@ -440,7 +436,7 @@ def __transform_image(self, image, method='linear', rotation='corner', resample= if affine.space is None: affine = affine.copy() affine.space = 'voxel' - # + affine = affine.convert(space='voxel', source=image) target_geom = affine.target elif affine.space is not None and affine.space != 'voxel': @@ -452,7 +448,7 @@ def __transform_image(self, image, method='linear', rotation='corner', resample= raise ValueError("rotation must be 'center' or 'corner'") elif rotation == 'center': affine = center_to_corner_rotation(affine, image.baseshape) - + # make sure the matrix is actually inverted since we want a target to # source voxel mapping for resampling matrix_data = affine.inv().matrix @@ -468,7 +464,6 @@ def __transform_image(self, image, method='linear', rotation='corner', resample= return image.new(interpolated, target_geom) - def affine_equal(a, b, matrix_only=False, tol=0.0): """ Test whether two affine transforms are equivalent. diff --git a/surfa/transform/warp.py b/surfa/transform/warp.py index 19e41cb..bf9557d 100644 --- a/surfa/transform/warp.py +++ b/surfa/transform/warp.py @@ -1,4 +1,3 @@ -import copy import warnings import numpy as np @@ -18,22 +17,18 @@ class Format: abs_ras - RAS coordinates in image space disp_ras - displacement RAS, delta = image_RAS - atlas_RAS """ - abs_crs = 0 disp_crs = 1 abs_ras = 2 disp_ras = 3 - - # - # constructor def __init__(self, data, source=None, target=None, format=Format.abs_crs, **kwargs): """ Class variables: _data: deformation field, 4D numpy array (c, r, s, 3) The _data (width x height x depth x nframes) is indexed by atlas CRS. - frame 0 - image voxel ABS coordinate C, image voxel DISP coordinate C, + frame 0 - image voxel ABS coordinate C, image voxel DISP coordinate C, RAS ABS coordinate X, or RAS DISP coordinate X frame 1 - image voxel ABS coordinate R, image voxel DISP coordinate R, RAS ABS coordinate Y, or RAS DISP coordinate Y @@ -46,23 +41,24 @@ def __init__(self, data, source=None, target=None, format=Format.abs_crs, **kwar Parameters ---------- data : 4D numpy array (c, r, s, 3) - dense deformation field + Dense deformation field. source : ImageGeometry - source geometry + Source geometry. target : ImageGeometry - target geometry + Target geometry. format : Format - deformation field format + Deformation field format. **kwargs Extra arguments provided to the FramedImage superclass. """ self.format = format self.source = source - basedim = len(data.shape) - 1 - super().__init__(basedim, data, geometry=target, **kwargs) + basedim = data.shape[-1] + if len(data.shape) != basedim + 1: + raise ValueError('invalid shape {data.shape} for {basedim}D warp') + super().__init__(basedim, data, geometry=target, **kwargs) - # def __call__(self, *args, **kwargs): """ Apply non-linear transform to an image. @@ -86,9 +82,6 @@ def new(self, data, source=None, target=None, format=None): return self.__class__(data, source, target, metadata=self.metadata) - - # - # output _data as mgz warp def save(self, filename, fmt=None): """ Write warp to file. @@ -102,149 +95,137 @@ def save(self, filename, fmt=None): """ super().save(filename, fmt=fmt, intent=sf.core.framed.FramedArrayIntents.warpmap) - - # - # change deformation field data format - # return new deformation field, self._data is not changed def convert(self, newformat=Format.abs_crs): """ - Change deformation field data format + Change deformation field format. Parameters ---------- newformat : Format - output deformation field format + Output deformation field format. Returns ------- data : 4D numpy array (c, r, s, 3) - converted deformation field with newformat + Converted deformation field. """ - - if (self._format == newformat): + compute_type = np.float32 + if self.format == newformat: return self._data # cast vox2world.matrix and world2vox.matrix to float32 - src_vox2ras = self.source.vox2world.matrix.astype('float32') - src_ras2vox = self.source.world2vox.matrix.astype('float32') - trg_vox2ras = self.target.vox2world.matrix.astype('float32') + src_vox2ras = self.source.vox2world.matrix.astype(compute_type) + src_ras2vox = self.source.world2vox.matrix.astype(compute_type) + trg_vox2ras = self.target.vox2world.matrix.astype(compute_type) # reshape self._data to (3, n) array, n = c * s * r - transform = self._data.astype('float32') + transform = self._data.astype(compute_type) transform = transform.reshape(-1, 3) # (n, 3) transform = transform.transpose() # (3, n) # target crs grid corresponding to the reshaped (3, n) array - trg_crs = (np.arange(x, dtype=np.float32) for x in self._data.shape[:3]) + trg_crs = (np.arange(x, dtype=compute_type) for x in self.baseshape) trg_crs = np.meshgrid(*trg_crs, indexing='ij') trg_crs = np.stack(trg_crs) trg_crs = trg_crs.reshape(3, -1) # target ras trg_ras = trg_vox2ras[:3, :3] @ trg_crs + trg_vox2ras[:3, 3:] - - if (self._format == self.Format.abs_crs): - # - if (newformat == self.Format.disp_crs): + + if self._format == Warp.Format.abs_crs: + + if newformat == Warp.Format.disp_crs: # abs_crs => disp_crs deformationfield = transform - trg_crs else: # abs_crs => abs_ras src_ras = src_vox2ras[:3, :3] @ transform + src_vox2ras[:3, 3:] - if (newformat == self.Format.abs_ras): + if newformat == Warp.Format.abs_ras: deformationfield = src_ras - elif (newformat == self.Format.disp_ras): + elif newformat == Warp.Format.disp_ras: # abs_ras => disp_ras deformationfield = src_ras - trg_ras - # - elif (self._format == self.Format.disp_crs): - # - # disp_crs => abs_crs + + elif self._format == Warp.Format.disp_crs: + + # disp_crs => abs_crs src_crs = transform + trg_crs - if (newformat == self.Format.abs_crs): + if newformat == Warp.Format.abs_crs: deformationfield = src_crs else: # abs_crs => abs_ras src_ras = src_vox2ras[:3, :3] @ src_crs + src_vox2ras[:3, 3:] - if (newformat == self.Format.abs_ras): + if newformat == Warp.Format.abs_ras: deformationfield = src_ras - elif (newformat == self.Format.disp_ras): + elif newformat == Warp.Format.disp_ras: # abs_ras => disp_ras deformationfield = src_ras - trg_ras - # - elif (self._format == self.Format.abs_ras): - # - if (newformat == self.Format.disp_ras): + + elif self._format == Warp.Format.abs_ras: + + if newformat == Warp.Format.disp_ras: # abs_ras => disp_ras deformationfield = transform - trg_ras else: - # abs_ras => abs_crs + # abs_ras => abs_crs src_crs = src_ras2vox[:3, :3] @ transform + src_ras2vox[:3, 3:] - if (newformat == self.Format.abs_crs): + if newformat == Warp.Format.abs_crs: deformationfield = src_crs - elif (newformat == self.Format.disp_crs): + elif newformat == Warp.Format.disp_crs: # abs_crs => disp_crs deformationfield = src_crs - trg_crs - # - elif (self._format == self.Format.disp_ras): - # + + elif self._format == Warp.Format.disp_ras: + # disp_ras => abs_ras src_ras = transform + trg_ras - if (newformat == self.Format.abs_ras): + if newformat == Warp.Format.abs_ras: deformationfield = src_ras else: # abs_ras => abs_crs - src_crs = src_ras2vox[:3, :3] @ src_ras + src_ras2vox[:3, 3:] - if (newformat == self.Format.abs_crs): + src_crs = src_ras2vox[:3, :3] @ src_ras + src_ras2vox[:3, 3:] + if newformat == Warp.Format.abs_crs: deformationfield = src_crs - elif (newformat == self.Format.disp_crs): + elif newformat == Warp.Format.disp_crs: # abs_crs => disp_crs deformationfield = src_crs - trg_crs - # # reshape deformationfield to [c, r, s] x 3 deformationfield = deformationfield.transpose() - deformationfield = deformationfield.reshape(*self._data.shape[:3], 3) - - return deformationfield + deformationfield = deformationfield.reshape(self.shape) + return deformationfield - # - # apply _data on given image using Cython interpolation in image/interp.pyx - # return transformed image def transform(self, image, method='linear', fill=0): """ - Apply dense deformation field to input image volume + Apply dense deformation field to input image. Parameters ---------- image : Volume - input image Volume + Input image. method : {'linear', 'nearest'} - Image interpolation method + Interpolation method. fill : scalar Fill value for out-of-bounds voxels. Returns ------- - deformed : Volume - deformed image + Volume + Transfomred image. """ - - # check if image is a Volume - if (not isinstance(image, sf.image.framed.Volume)): + if not isinstance(image, sf.Volume): raise ValueError('Warp.transform() - input is not a Volume') if image.basedim == 2: raise NotImplementedError('Warp.transform() is not yet implemented for 2D data') - - if self._data.shape[-1] != image.basedim: - raise ValueError(f'deformation ({self._data.shape[-1]}D) does not match ' - f'dimensionality of image ({image.basedim}D)') + if self.nframes != image.basedim: + raise ValueError(f'deformation ({self.nframes}D) does not match ' + f'dimensionality of image ({image.basedim}D)') # convert deformation field to disp_crs - deformationfield = self.convert(self.Format.disp_crs) + deformationfield = self.convert(Warp.Format.disp_crs) # do the interpolation, the function assumes disp_crs deformation field interpolated = interpolate(source=image.framed_data, @@ -255,9 +236,6 @@ def transform(self, image, method='linear', fill=0): return image.new(interpolated, geometry=self.target) - - # - # _format getter and setter @property def format(self): return self._format @@ -286,4 +264,7 @@ def target(self): @target.setter def target(self, value): + """ + Set target (or fixed) image geometry. Invokes parent setter. + """ self.geom = value From 9d123bf68ef5e415fb1d49280f9d8f4b4d1980ed Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 20:10:27 -0500 Subject: [PATCH 7/8] Return `Warp` instead of raw array when converting. --- surfa/transform/warp.py | 48 ++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/surfa/transform/warp.py b/surfa/transform/warp.py index bf9557d..7a81794 100644 --- a/surfa/transform/warp.py +++ b/surfa/transform/warp.py @@ -95,23 +95,25 @@ def save(self, filename, fmt=None): """ super().save(filename, fmt=fmt, intent=sf.core.framed.FramedArrayIntents.warpmap) - def convert(self, newformat=Format.abs_crs): + def convert(self, format=Format.disp_crs, copy=True): """ Change deformation field format. Parameters ---------- - newformat : Format + format : Format Output deformation field format. + copy : bool + Return copy of object if format already satisfied. Returns ------- - data : 4D numpy array (c, r, s, 3) + Warp Converted deformation field. """ compute_type = np.float32 - if self.format == newformat: - return self._data + if self.format == format: + return self.copy() if copy else self # cast vox2world.matrix and world2vox.matrix to float32 src_vox2ras = self.source.vox2world.matrix.astype(compute_type) @@ -134,15 +136,15 @@ def convert(self, newformat=Format.abs_crs): if self._format == Warp.Format.abs_crs: - if newformat == Warp.Format.disp_crs: + if format == Warp.Format.disp_crs: # abs_crs => disp_crs deformationfield = transform - trg_crs else: # abs_crs => abs_ras src_ras = src_vox2ras[:3, :3] @ transform + src_vox2ras[:3, 3:] - if newformat == Warp.Format.abs_ras: + if format == Warp.Format.abs_ras: deformationfield = src_ras - elif newformat == Warp.Format.disp_ras: + elif format == Warp.Format.disp_ras: # abs_ras => disp_ras deformationfield = src_ras - trg_ras @@ -150,28 +152,28 @@ def convert(self, newformat=Format.abs_crs): # disp_crs => abs_crs src_crs = transform + trg_crs - if newformat == Warp.Format.abs_crs: + if format == Warp.Format.abs_crs: deformationfield = src_crs else: # abs_crs => abs_ras src_ras = src_vox2ras[:3, :3] @ src_crs + src_vox2ras[:3, 3:] - if newformat == Warp.Format.abs_ras: + if format == Warp.Format.abs_ras: deformationfield = src_ras - elif newformat == Warp.Format.disp_ras: + elif format == Warp.Format.disp_ras: # abs_ras => disp_ras deformationfield = src_ras - trg_ras elif self._format == Warp.Format.abs_ras: - if newformat == Warp.Format.disp_ras: + if format == Warp.Format.disp_ras: # abs_ras => disp_ras deformationfield = transform - trg_ras else: # abs_ras => abs_crs src_crs = src_ras2vox[:3, :3] @ transform + src_ras2vox[:3, 3:] - if newformat == Warp.Format.abs_crs: + if format == Warp.Format.abs_crs: deformationfield = src_crs - elif newformat == Warp.Format.disp_crs: + elif format == Warp.Format.disp_crs: # abs_crs => disp_crs deformationfield = src_crs - trg_crs @@ -179,14 +181,14 @@ def convert(self, newformat=Format.abs_crs): # disp_ras => abs_ras src_ras = transform + trg_ras - if newformat == Warp.Format.abs_ras: + if format == Warp.Format.abs_ras: deformationfield = src_ras else: # abs_ras => abs_crs src_crs = src_ras2vox[:3, :3] @ src_ras + src_ras2vox[:3, 3:] - if newformat == Warp.Format.abs_crs: + if format == Warp.Format.abs_crs: deformationfield = src_crs - elif newformat == Warp.Format.disp_crs: + elif format == Warp.Format.disp_crs: # abs_crs => disp_crs deformationfield = src_crs - trg_crs @@ -194,7 +196,7 @@ def convert(self, newformat=Format.abs_crs): deformationfield = deformationfield.transpose() deformationfield = deformationfield.reshape(self.shape) - return deformationfield + return self.new(deformationfield, format=format) def transform(self, image, method='linear', fill=0): """ @@ -224,14 +226,12 @@ def transform(self, image, method='linear', fill=0): raise ValueError(f'deformation ({self.nframes}D) does not match ' f'dimensionality of image ({image.basedim}D)') - # convert deformation field to disp_crs - deformationfield = self.convert(Warp.Format.disp_crs) - - # do the interpolation, the function assumes disp_crs deformation field + # interpolation assumes disp_crs format + disp = self.convert(Warp.Format.disp_crs, copy=False) interpolated = interpolate(source=image.framed_data, - target_shape=self.geom.shape, + target_shape=self.baseshape, method=method, - disp=deformationfield, + disp=disp.framed_data, fill=fill) return image.new(interpolated, geometry=self.target) From 76a815e983e5c2ea4084139dd1784bf75de4c966 Mon Sep 17 00:00:00 2001 From: Malte Hoffmann Date: Sat, 17 Feb 2024 22:28:26 -0500 Subject: [PATCH 8/8] Correctly set format in `Warp.new` and change default. --- surfa/io/framed.py | 2 +- surfa/transform/warp.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/surfa/io/framed.py b/surfa/io/framed.py index e8866d3..518c1a5 100644 --- a/surfa/io/framed.py +++ b/surfa/io/framed.py @@ -181,7 +181,7 @@ def framed_array_from_4d(atype, data): Returns ------- FramedArray - Squeezed framed array. + Squeezed framed array. """ # this code is a bit ugly - it does the job but should probably be cleaned up if atype == Volume: diff --git a/surfa/transform/warp.py b/surfa/transform/warp.py index 7a81794..6f0506b 100644 --- a/surfa/transform/warp.py +++ b/surfa/transform/warp.py @@ -23,7 +23,7 @@ class Format: disp_ras = 3 - def __init__(self, data, source=None, target=None, format=Format.abs_crs, **kwargs): + def __init__(self, data, source=None, target=None, format=Format.disp_crs, **kwargs): """ Class variables: _data: deformation field, 4D numpy array (c, r, s, 3) @@ -80,7 +80,7 @@ def new(self, data, source=None, target=None, format=None): if format is None: format = self.format - return self.__class__(data, source, target, metadata=self.metadata) + return self.__class__(data, source, target, format=format, metadata=self.metadata) def save(self, filename, fmt=None): """