Parallelize reconstructions with submitit (#477)
* depend on submitit

* bare-bones submission

* simple status handling

* refactor `monitor_jobs`

* parsing

* simple resource estimates

* add ram-multiplier

* back to array jobs

* correct elapsed timer

* more informative messages

* message handling

* better local debugging

* refactor debug messages

* fix tests

* drop 3.12

* refactor monitoring

* ignore logs

* fix tests

* fix test

* print the first failure and # of successful jobs

* max 50 jobs at a time

* monitor when more jobs than terminal lines

* cpu_request = num_processes

* improved headline

* add `-rx` option to `reconstruct`

* better color handling
talonchandler authored Sep 17, 2024
1 parent 9f0a37a commit 6c70732
Showing 11 changed files with 253 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytests.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11"]

steps:
- name: Checkout repo
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
@@ -21,7 +21,7 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v3
1 change: 1 addition & 0 deletions .gitignore
@@ -148,3 +148,4 @@ recOrder/_version.py

# example data
/examples/data_temp/*
/logs/*
75 changes: 66 additions & 9 deletions recOrder/cli/apply_inverse_transfer_function.py
@@ -7,6 +7,7 @@
import numpy as np
import torch
import torch.multiprocessing as mp
import submitit
from iohub import open_ome_zarr

from recOrder.cli import apply_inverse_models
@@ -16,6 +17,7 @@
output_dirpath,
processes_option,
transfer_function_dirpath,
ram_multiplier,
)
from recOrder.cli.printing import echo_headline, echo_settings
from recOrder.cli.settings import ReconstructionSettings
@@ -24,6 +26,7 @@
create_empty_hcs_zarr,
)
from recOrder.io import utils
from recOrder.cli.monitor import monitor_jobs


def _check_background_consistency(
@@ -289,6 +292,7 @@ def apply_inverse_transfer_function_cli(
config_filepath: Path,
output_dirpath: Path,
num_processes: int = 1,
ram_multiplier: float = 1.0,
) -> None:
output_metadata = get_reconstruction_output_metadata(
input_position_dirpaths[0], config_filepath
@@ -303,15 +307,65 @@
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

for input_position_dirpath in input_position_dirpaths:
apply_inverse_transfer_function_single_position(
input_position_dirpath,
transfer_function_dirpath,
config_filepath,
output_dirpath / Path(*input_position_dirpath.parts[-3:]),
num_processes,
output_metadata["channel_names"],
)
# Estimate resources
with open_ome_zarr(input_position_dirpaths[0]) as input_dataset:
T, C, Z, Y, X = input_dataset["0"].shape

settings = utils.yaml_to_model(config_filepath, ReconstructionSettings)
gb_ram_request = 0
gb_per_element = 4 / 2**30 # bytes_per_float32 / bytes_per_gb
voxel_resource_multiplier = 4
fourier_resource_multiplier = 32
input_memory = Z * Y * X * gb_per_element
if settings.birefringence is not None:
gb_ram_request += input_memory * voxel_resource_multiplier
if settings.phase is not None:
gb_ram_request += input_memory * fourier_resource_multiplier
if settings.fluorescence is not None:
gb_ram_request += input_memory * fourier_resource_multiplier

gb_ram_request = np.ceil(
np.max([1, ram_multiplier * gb_ram_request])
).astype(int)
cpu_request = np.min([32, num_processes])
num_jobs = len(input_position_dirpaths)

# Prepare and submit jobs
echo_headline(
f"Preparing {num_jobs} job{'s, each with' if num_jobs > 1 else ' with'} "
f"{cpu_request} CPU{'s' if cpu_request > 1 else ''} and "
f"{gb_ram_request} GB of memory per CPU."
)
executor = submitit.AutoExecutor(folder="logs")

executor.update_parameters(
slurm_array_parallelism=np.min([50, num_jobs]),
slurm_mem_per_cpu=f"{gb_ram_request}G",
slurm_cpus_per_task=cpu_request,
slurm_time=60,
slurm_partition="cpu",
# more slurm_*** resource parameters here
)

jobs = []
with executor.batch():
for input_position_dirpath in input_position_dirpaths:
jobs.append(
executor.submit(
apply_inverse_transfer_function_single_position,
input_position_dirpath,
transfer_function_dirpath,
config_filepath,
output_dirpath / Path(*input_position_dirpath.parts[-3:]),
num_processes,
output_metadata["channel_names"],
)
)
echo_headline(
f"{num_jobs} job{'s' if num_jobs > 1 else ''} submitted {'locally' if executor.cluster == 'local' else 'via ' + executor.cluster}."
)

monitor_jobs(jobs, input_position_dirpaths)


@click.command()
@@ -320,12 +374,14 @@ def apply_inverse_transfer_function_cli(
@config_filepath()
@output_dirpath()
@processes_option(default=1)
@ram_multiplier()
def apply_inv_tf(
input_position_dirpaths: list[Path],
transfer_function_dirpath: Path,
config_filepath: Path,
output_dirpath: Path,
num_processes,
ram_multiplier: float = 1.0,
) -> None:
"""
Apply an inverse transfer function to a dataset using a configuration file.
@@ -345,4 +401,5 @@ def apply_inv_tf(
config_filepath,
output_dirpath,
num_processes,
ram_multiplier,
)
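For a sense of scale, the new memory request grows with the input volume and the reconstruction type. A minimal standalone sketch of the same arithmetic as the block above, using a hypothetical 100 x 2048 x 2048 float32 volume with only a phase reconstruction enabled:

import numpy as np

# Hypothetical volume size; the real values come from the input zarr's (T, C, Z, Y, X) shape.
Z, Y, X = 100, 2048, 2048

gb_per_element = 4 / 2**30                 # bytes_per_float32 / bytes_per_gb
input_memory = Z * Y * X * gb_per_element  # ~1.56 GB for this volume

fourier_resource_multiplier = 32           # applied for phase and fluorescence reconstructions
ram_multiplier = 1.0                       # the new -rx CLI option scales this estimate

gb_ram_request = np.ceil(
    np.max([1, ram_multiplier * input_memory * fourier_resource_multiplier])
).astype(int)
print(gb_ram_request)                      # 50 (GB per CPU requested from SLURM)

With the default -rx of 1.0 this hypothetical volume would request 50 GB per CPU; passing -rx 0.5 would halve the estimate to 25 GB.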
153 changes: 153 additions & 0 deletions recOrder/cli/monitor.py
@@ -0,0 +1,153 @@
from pathlib import Path

import time
import numpy as np
import shutil
import submitit
import sys


def _move_cursor_up(n_lines):
sys.stdout.write("\033[F" * n_lines)


def _print_status(jobs, position_dirpaths, elapsed_list, print_indices=None):

columns = [15, 30, 40, 50]

# header
sys.stdout.write(
"\033[K" # clear line
"\033[96mID" # cyan
f"\033[{columns[0]}G WELL "
f"\033[{columns[1]}G STATUS "
f"\033[{columns[2]}G NODE "
f"\033[{columns[2]}G ELAPSED\n"
)

if print_indices is None:
print_indices = range(len(jobs))

complete_count = 0
for i, (job, position_dirpath) in enumerate(zip(jobs, position_dirpaths)):
try:
node_name = job.get_info()["NodeList"] # slowest, so do this first
except Exception:
node_name = "SUBMITTED"

if job.state == "COMPLETED":
color = "\033[32m" # green
complete_count += 1
elif job.state == "RUNNING":
color = "\033[93m" # yellow
elapsed_list[i] += 1 # inexact timing
else:
color = "\033[91m" # red

if i in print_indices:
sys.stdout.write(
f"\033[K" # clear line
f"{color}{job.job_id}"
f"\033[{columns[0]}G {'/'.join(position_dirpath.parts[-3:])}"
f"\033[{columns[1]}G {job.state}"
f"\033[{columns[2]}G {node_name}"
f"\033[{columns[3]}G {elapsed_list[i]} s\n"
)
sys.stdout.flush()
print(
f"\033[32m{complete_count}/{len(jobs)} jobs complete. "
"<ctrl+z> to move monitor to background. "
"<ctrl+c> twice to cancel jobs."
)

return elapsed_list


def _get_jobs_to_print(jobs, num_to_print):
job_indices_to_print = []

# if number of jobs is smaller than terminal size, print all
if len(jobs) <= num_to_print:
return list(range(len(jobs)))

# prioritize incomplete jobs
for i, job in enumerate(jobs):
if not job.done():
job_indices_to_print.append(i)
if len(job_indices_to_print) == num_to_print:
return job_indices_to_print

# fill in the rest with complete jobs
for i, job in enumerate(jobs):
job_indices_to_print.append(i)
if len(job_indices_to_print) == num_to_print:
return job_indices_to_print

# shouldn't reach here
return job_indices_to_print


def monitor_jobs(jobs: list[submitit.Job], position_dirpaths: list[Path]):
"""Displays the status of a list of submitit jobs with corresponding paths.
Parameters
----------
jobs : list[submitit.Job]
List of submitit jobs
position_dirpaths : list[Path]
List of corresponding position paths
"""
NON_JOB_LINES = 3

if not len(jobs) == len(position_dirpaths):
raise ValueError(
"The number of jobs and position_dirpaths should be the same."
)

elapsed_list = [0] * len(jobs) # timer for each job

# print all jobs once if terminal is too small
if shutil.get_terminal_size().lines - NON_JOB_LINES < len(jobs):
_print_status(jobs, position_dirpaths, elapsed_list)

# main monitor loop
try:
while not all(job.done() for job in jobs):
terminal_lines = shutil.get_terminal_size().lines
num_jobs_to_print = np.min(
[terminal_lines - NON_JOB_LINES, len(jobs)]
)

job_indices_to_print = _get_jobs_to_print(jobs, num_jobs_to_print)

elapsed_list = _print_status(
jobs,
position_dirpaths,
elapsed_list,
job_indices_to_print,
)

time.sleep(1)
_move_cursor_up(num_jobs_to_print + 2)

# Print final status
time.sleep(1)
_print_status(jobs, position_dirpaths, elapsed_list)

# cancel jobs if ctrl+c
except KeyboardInterrupt:
for job in jobs:
job.cancel()
print("All jobs cancelled.\033[97m")

# Print STDOUT and STDERR for first incomplete job
incomplete_count = 0
for job in jobs:
if not job.done():
if incomplete_count == 0:
print("\033[32mSTDOUT")
print(job.stdout())
print("\033[91mSTDERR")
print(job.stderr())
incomplete_count += 1

print("\033[97m") # print white
22 changes: 18 additions & 4 deletions recOrder/cli/parsing.py
@@ -12,14 +12,15 @@
def _validate_and_process_paths(
ctx: click.Context, opt: click.Option, value: str
) -> list[Path]:
# Sort and validate the input paths
# Sort and validate the input paths, expanding plates into lists of positions
input_paths = [Path(path) for path in natsorted(value)]
for path in input_paths:
with open_ome_zarr(path, mode="r") as dataset:
if isinstance(dataset, Plate):
raise ValueError(
"Please supply a list of positions instead of an HCS plate. Likely fix: replace 'input.zarr' with 'input.zarr/*/*/*' or 'input.zarr/0/0/0'"
)
plate_path = input_paths.pop()
for position in dataset.positions():
input_paths.append(plate_path / position[0])

return input_paths


@@ -105,3 +106,16 @@ def decorator(f: Callable) -> Callable:
)(f)

return decorator


def ram_multiplier() -> Callable:
def decorator(f: Callable) -> Callable:
return click.option(
"--ram-multiplier",
"-rx",
default=1.0,
type=float,
help="SLURM RAM multiplier.",
)(f)

return decorator
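The new helper follows the same pattern as the existing option decorators: it returns a decorator that attaches a --ram-multiplier / -rx float option (default 1.0) to a click command. A small illustrative command, not part of the package, showing how the flag is consumed:

import click

from recOrder.cli.parsing import ram_multiplier

@click.command()
@ram_multiplier()
def show_multiplier(ram_multiplier: float):
    # e.g. `show_multiplier -rx 2.0` would double the estimated SLURM memory request
    click.echo(f"RAM multiplier: {ram_multiplier}")

if __name__ == "__main__":
    show_multiplier()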
9 changes: 8 additions & 1 deletion recOrder/cli/reconstruct.py
@@ -13,6 +13,7 @@
input_position_dirpaths,
output_dirpath,
processes_option,
ram_multiplier,
)


@@ -21,8 +22,13 @@
@config_filepath()
@output_dirpath()
@processes_option(default=1)
@ram_multiplier()
def reconstruct(
input_position_dirpaths, config_filepath, output_dirpath, num_processes
input_position_dirpaths,
config_filepath,
output_dirpath,
num_processes,
ram_multiplier,
):
"""
Reconstruct a dataset using a configuration file. This is a
@@ -58,4 +64,5 @@ def reconstruct(
config_filepath,
output_dirpath,
num_processes,
ram_multiplier,
)
1 change: 1 addition & 0 deletions recOrder/cli/utils.py
@@ -103,3 +103,4 @@ def apply_inverse_to_zyx_and_save(
t_idx, output_channel_indices
] = reconstruction_czyx
click.echo(f"Finished Writing.. t={t_idx}")

3 changes: 2 additions & 1 deletion recOrder/tests/cli_tests/test_reconstruct.py
@@ -221,6 +221,7 @@ def test_cli_apply_inv_tf_mock(tmp_input_path_zarr):
Path(tmp_config_yml),
Path(result_path),
1,
1,
)
assert result_inv.exit_code == 0

@@ -255,7 +256,7 @@ def test_cli_apply_inv_tf_output(tmp_input_path_zarr, capsys):

assert result_path.exists()
captured = capsys.readouterr()
assert "Reconstructing" in captured.out
assert "submitted" in captured.out

# Check scale transformations pass through
assert input_scale == result_dataset.scale