Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async_local backend and allow using an existing dview for local and async_local backends #311

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions mesmerize_core/algorithms/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import caiman as cm
from contextlib import contextmanager
from ipyparallel import DirectView
from multiprocessing.pool import Pool
import os
import psutil
from typing import Union, Optional, Generator

# Handle to a parallel-processing cluster: either a multiprocessing Pool
# or an ipyparallel DirectView.
Cluster = Union[Pool, DirectView]

def get_n_processes(dview: Optional[Cluster]) -> int:
    """
    Report how many worker processes back the given cluster handle.

    A multiprocessing Pool reports its `_processes` attribute (when it has
    one), an ipyparallel DirectView reports its length, and anything else
    (including None) counts as a single process.
    """
    pool_with_count = isinstance(dview, Pool) and hasattr(dview, "_processes")
    if pool_with_count:
        return dview._processes
    if isinstance(dview, DirectView):
        return len(dview)
    # Not a recognised cluster handle (or no cluster at all)
    return 1


@contextmanager
def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], None, None]:
    """
    Context manager that passes through an existing 'dview' or
    opens up a multiprocessing server if none is passed in.
    If a server was opened, closes it upon exit; a 'dview' that was
    passed in is left untouched.

    When a server has to be opened, the number of processes is read from the
    MESMERIZE_N_PROCESSES environment variable. If that variable is unset or
    is not a valid integer, one fewer than the detected CPU count is used
    (but never fewer than 1).

    Yields a `(dview, n_processes)` tuple.
    Usage: `with ensure_server(dview) as (dview, n_processes):`
    """
    if dview is not None:
        # Caller owns this cluster: hand it back as-is and do not stop it.
        yield dview, get_n_processes(dview)
        return

    # no cluster passed in, so open one
    n_processes = None
    env_value = os.environ.get("MESMERIZE_N_PROCESSES")
    if env_value is not None:
        try:
            n_processes = int(env_value)
        except ValueError:
            # Unparseable value: fall through to the CPU-count default.
            n_processes = None

    if n_processes is None:
        # psutil.cpu_count() may return None when the count cannot be
        # determined, and a single-core machine would otherwise give 0
        # workers, so clamp the default to at least one process.
        n_processes = max(1, (psutil.cpu_count() or 1) - 1)

    # Start cluster for parallel processing
    _, dview, n_processes = cm.cluster.setup_cluster(
        backend="multiprocessing", n_processes=n_processes, single_thread=False
    )
    try:
        yield dview, n_processes
    finally:
        # Always shut down the server we started, even if the body raised.
        cm.stop_server(dview=dview)
172 changes: 77 additions & 95 deletions mesmerize_core/algorithms/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
import caiman as cm
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf.params import CNMFParams
import psutil
import numpy as np
import traceback
from pathlib import Path, PurePosixPath
from shutil import move as move_file
import os
import time

# prevent circular import
if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess
from mesmerize_core import set_parent_raw_data_path, load_batch
from mesmerize_core.utils import IS_WINDOWS
from mesmerize_core.algorithms._utils import ensure_server
else: # when running with local backend
from ..batch_utils import set_parent_raw_data_path, load_batch
from ..utils import IS_WINDOWS
from ._utils import ensure_server


def run_algo(batch_path, uuid, data_path: str = None):
def run_algo(batch_path, uuid, data_path: str = None, dview=None):
algo_start = time.time()
set_parent_raw_data_path(data_path)

Expand All @@ -41,102 +41,84 @@ def run_algo(batch_path, uuid, data_path: str = None):
f"Starting CNMF item:\n{item}\nWith params:{params}"
)

# adapted from current demo notebook
if "MESMERIZE_N_PROCESSES" in os.environ.keys():
try:
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
except:
n_processes = psutil.cpu_count() - 1
else:
n_processes = psutil.cpu_count() - 1
# Start cluster for parallel processing
c, dview, n_processes = cm.cluster.setup_cluster(
backend="local", n_processes=n_processes, single_thread=False
)
with ensure_server(dview) as (dview, n_processes):

# merge cnmf and eval kwargs into one dict
cnmf_params = CNMFParams(params_dict=params["main"])
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
try:
fname_new = cm.save_memmap(
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
)
# merge cnmf and eval kwargs into one dict
cnmf_params = CNMFParams(params_dict=params["main"])
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
try:
fname_new = cm.save_memmap(
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
)

print("making memmap")
print("making memmap")

Yr, dims, T = cm.load_memmap(fname_new)

images = np.reshape(Yr.T, [T] + list(dims), order="F")

proj_paths = dict()
for proj_type in ["mean", "std", "max"]:
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
proj_paths[proj_type] = output_dir.joinpath(
f"{uuid}_{proj_type}_projection.npy"
)
np.save(str(proj_paths[proj_type]), p_img)

print("performing CNMF")
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)

print("fitting images")
cnm = cnm.fit(images)
#
if "refit" in params.keys():
if params["refit"] is True:
print("refitting")
cnm = cnm.refit(images, dview=dview)

print("performing eval")
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

output_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()

cnm.save(str(output_path))

Cn = cm.local_correlations(images.transpose(1, 2, 0))
Cn[np.isnan(Cn)] = 0

corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy").resolve()
np.save(str(corr_img_path), Cn, allow_pickle=False)

# output dict for dataframe row (pd.Series)
d = dict()

cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
if IS_WINDOWS:
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
move_file(fname_new, cnmf_memmap_path)

# save paths as relative path strings with forward slashes
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
cnmf_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
for proj_type in proj_paths.keys():
d[f"{proj_type}-projection-path"] = str(PurePosixPath(proj_paths[proj_type].relative_to(
output_dir.parent
)))

d.update(
{
"cnmf-hdf5-path": cnmf_hdf5_path,
"cnmf-memmap-path": cnmf_memmap_path,
"corr-img-path": corr_img_path,
"success": True,
"traceback": None,
}
)

Yr, dims, T = cm.load_memmap(fname_new)
images = np.reshape(Yr.T, [T] + list(dims), order="F")
except:
d = {"success": False, "traceback": traceback.format_exc()}

proj_paths = dict()
for proj_type in ["mean", "std", "max"]:
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
proj_paths[proj_type] = output_dir.joinpath(
f"{uuid}_{proj_type}_projection.npy"
)
np.save(str(proj_paths[proj_type]), p_img)

# in fname new load in memmap order C
cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
backend="local", n_processes=None, single_thread=False
)
Comment on lines -78 to -82
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure what this part is for; it doesn't really work with the changes here, so I just deleted it, but we can try to do something else if it's important.


print("performing CNMF")
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)

print("fitting images")
cnm = cnm.fit(images)
#
if "refit" in params.keys():
if params["refit"] is True:
print("refitting")
cnm = cnm.refit(images, dview=dview)

print("performing eval")
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

output_path = output_dir.joinpath(f"{uuid}.hdf5").resolve()

cnm.save(str(output_path))

Cn = cm.local_correlations(images.transpose(1, 2, 0))
Cn[np.isnan(Cn)] = 0

corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy").resolve()
np.save(str(corr_img_path), Cn, allow_pickle=False)

# output dict for dataframe row (pd.Series)
d = dict()

cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
if IS_WINDOWS:
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
move_file(fname_new, cnmf_memmap_path)

# save paths as relative path strings with forward slashes
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
cnmf_memmap_path = str(PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent)))
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
for proj_type in proj_paths.keys():
d[f"{proj_type}-projection-path"] = str(PurePosixPath(proj_paths[proj_type].relative_to(
output_dir.parent
)))

d.update(
{
"cnmf-hdf5-path": cnmf_hdf5_path,
"cnmf-memmap-path": cnmf_memmap_path,
"corr-img-path": corr_img_path,
"success": True,
"traceback": None,
}
)

except:
d = {"success": False, "traceback": traceback.format_exc()}

cm.stop_server(dview=dview)

runtime = round(time.time() - algo_start, 2)
df.caiman.update_item_with_results(uuid, d, runtime)

Expand Down
Loading
Loading