Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parallelize register/transform stack #17

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions pystackreg/pystackreg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# -*- coding: utf-8 -*-
from . import turboreg # type: ignore
import warnings
from multiprocessing import Pool

import numpy as np
from tqdm import tqdm
import warnings

from . import turboreg # type: ignore


def simple_slice(arr, inds, axis):
Expand Down Expand Up @@ -122,6 +125,13 @@ def is_registered(self):
:return: True if a transformation matrix was already calculated
"""
return self._is_registered

def reg_help(self, x):
'''
Helper function for multiprocessing parallelization
'''
ref, img, i = x
return self.register(ref,img),i

def register(self, ref, mov):
"""
Expand All @@ -147,6 +157,7 @@ def register(self, ref, mov):

return self.get_matrix()


def transform(self, mov, tmat=None):
"""
Transform an image according to a previous registration.
Expand Down Expand Up @@ -309,6 +320,7 @@ def register_stack(
verbose=False,
progress_callback=None,
suppress_axis_warning=False,
processes=1,
):
"""
Register a stack of images (movie).
Expand Down Expand Up @@ -358,6 +370,11 @@ def register_stack(
warning when the detected time axis is not equal to the supplied axis.
Set this option to True to suppress this warning.

:type processes: int, optional
:param processes:
Set the number of processes that will be used when transforming the
stack, only used in the case were reference is 'mean' or 'first'

:rtype: ndarray(img.shape[axis], 3, 3)
:return: The transformation matrix for each image in the stack
"""
Expand Down Expand Up @@ -410,6 +427,16 @@ def register_stack(

iterable = range(idx_start, img.shape[axis])

if reference in ['mean','first'] and processes>1:
with Pool(processes=processes) as pool:
pool_iter = pool.imap_unordered(self.reg_help, [[ref, simple_slice(img, i , axis),i] for i in iterable])
if verbose:
pool_iter = tqdm(pool_iter,total=len(iterable))
for i in pool_iter:
self._tmats[i[1], :, :] = i[0]
return self._tmats


if verbose:
iterable = tqdm(iterable)

Expand Down Expand Up @@ -483,6 +510,7 @@ def register_transform_stack(
moving_average=1,
verbose=False,
progress_callback=None,
processes=1,
):
"""
Register and transform stack of images (movie).
Expand Down Expand Up @@ -522,11 +550,17 @@ def register_transform_stack(
A function that is called after every iteration. This function should accept
the keyword arguments current_iteration:int and end_iteration:int.

:type processes: int, optional
:param processes:
Set the number of processes that will be used when transforming the
stack, only used in the case were reference is 'mean' or 'first'

:rtype: ndarray(Ni..., Nj..., Nk...)
:return: The transformed stack
"""
self.register_stack(
img, reference, n_frames, axis, moving_average, verbose, progress_callback
img, reference, n_frames, axis, moving_average, verbose, progress_callback,
processes=processes
)

return self.transform_stack(img, axis)