diff --git a/pystackreg/pystackreg.py b/pystackreg/pystackreg.py index 1fcbe8e..7056941 100644 --- a/pystackreg/pystackreg.py +++ b/pystackreg/pystackreg.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- -from . import turboreg # type: ignore +import warnings +from multiprocessing import Pool + import numpy as np from tqdm import tqdm -import warnings + +from . import turboreg # type: ignore def simple_slice(arr, inds, axis): @@ -122,6 +125,13 @@ def is_registered(self): :return: True if a transformation matrix was already calculated """ return self._is_registered + + def reg_help(self, x): + ''' + Helper function for multiprocessing parallelization + ''' + ref, img, i = x + return self.register(ref,img),i def register(self, ref, mov): """ @@ -147,6 +157,7 @@ def register(self, ref, mov): return self.get_matrix() + def transform(self, mov, tmat=None): """ Transform an image according to a previous registration. @@ -309,6 +320,7 @@ def register_stack( verbose=False, progress_callback=None, suppress_axis_warning=False, + processes=1, ): """ Register a stack of images (movie). @@ -358,6 +370,11 @@ def register_stack( warning when the detected time axis is not equal to the supplied axis. Set this option to True to suppress this warning. + :type processes: int, optional + :param processes: + Set the number of processes that will be used when transforming the + stack, only used in the case were reference is 'mean' or 'first' + :rtype: ndarray(img.shape[axis], 3, 3) :return: The transformation matrix for each image in the stack """ @@ -410,6 +427,16 @@ def register_stack( iterable = range(idx_start, img.shape[axis]) + if reference in ['mean','first'] and processes>1: + with Pool(processes=processes) as pool: + pool_iter = pool.imap_unordered(self.reg_help, [[ref, simple_slice(img, i , axis),i] for i in iterable]) + if verbose: + pool_iter = tqdm(pool_iter,total=len(iterable)) + for i in pool_iter: + self._tmats[i[1], :, :] = i[0] + return self._tmats + + if verbose: iterable = tqdm(iterable) @@ -483,6 +510,7 @@ def register_transform_stack( moving_average=1, verbose=False, progress_callback=None, + processes=1, ): """ Register and transform stack of images (movie). @@ -522,11 +550,17 @@ def register_transform_stack( A function that is called after every iteration. This function should accept the keyword arguments current_iteration:int and end_iteration:int. + :type processes: int, optional + :param processes: + Set the number of processes that will be used when transforming the + stack, only used in the case were reference is 'mean' or 'first' + :rtype: ndarray(Ni..., Nj..., Nk...) :return: The transformed stack """ self.register_stack( - img, reference, n_frames, axis, moving_average, verbose, progress_callback + img, reference, n_frames, axis, moving_average, verbose, progress_callback, + processes=processes ) return self.transform_stack(img, axis)