diff --git a/pyproject.toml b/pyproject.toml index d4ce1a8..c14fd99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,8 @@ dependencies = [ 'numcodecs==0.11.0', 'dask-image==2024.5.3', 'xarray_multiscale==1.1.0', - 'bokeh==2.4.2', 'pims==0.6.1', - 'dask[distributed]==2024.5.1', + 'dask[distributed]==2024.5.2', 'ome-zarr==0.8.2', 'imagecodecs[all]==2023.3.16', 'natsort==8.4.0', @@ -79,7 +78,8 @@ exclude_lines = [ "import", "pragma: no cover" ] -fail_under = 100 +fail_under = 75 +show_missing = true [tool.isort] line_length = 79 diff --git a/src/aind_smartspim_data_transformation/__init__.py b/src/aind_smartspim_data_transformation/__init__.py index 80137cd..1a003d5 100644 --- a/src/aind_smartspim_data_transformation/__init__.py +++ b/src/aind_smartspim_data_transformation/__init__.py @@ -1,3 +1,3 @@ """Init package""" -__version__ = "0.0.6" +__version__ = "0.0.7" diff --git a/src/aind_smartspim_data_transformation/_shared/__init__.py b/src/aind_smartspim_data_transformation/_shared/__init__.py new file mode 100644 index 0000000..8755617 --- /dev/null +++ b/src/aind_smartspim_data_transformation/_shared/__init__.py @@ -0,0 +1,3 @@ +""" +Shared constants +""" diff --git a/src/aind_smartspim_data_transformation/_shared/types.py b/src/aind_smartspim_data_transformation/_shared/types.py new file mode 100644 index 0000000..8371ae8 --- /dev/null +++ b/src/aind_smartspim_data_transformation/_shared/types.py @@ -0,0 +1,12 @@ +""" +Defines all the types used in the module +""" + +from pathlib import Path +from typing import Union + +import dask.array as da +import numpy as np + +PathLike = Union[Path, str] +ArrayLike = Union[da.Array, np.ndarray] diff --git a/src/aind_smartspim_data_transformation/compress/dask_utils.py b/src/aind_smartspim_data_transformation/compress/dask_utils.py index bb74731..1e555be 100644 --- a/src/aind_smartspim_data_transformation/compress/dask_utils.py +++ b/src/aind_smartspim_data_transformation/compress/dask_utils.py @@ -44,7 +44,7 @@ def log_dashboard_address( user = os.getenv("USER") # noqa: F841 LOGGER.info( f"To access the dashboard, run the following in " - "a terminal: ssh -L {port}:{host}:{port} {user}@" + f"a terminal: ssh -L {port}:{host}:{port} {user}@" f"{login_node_address} " ) diff --git a/src/aind_smartspim_data_transformation/compress/png_to_zarr.py b/src/aind_smartspim_data_transformation/compress/png_to_zarr.py index 8581e18..4cff245 100644 --- a/src/aind_smartspim_data_transformation/compress/png_to_zarr.py +++ b/src/aind_smartspim_data_transformation/compress/png_to_zarr.py @@ -22,7 +22,7 @@ from dask import config as da_cfg from dask.array.core import Array from dask.base import tokenize -from dask.distributed import Client, LocalCluster, performance_report +from dask.distributed import Client, LocalCluster # , performance_report # from distributed import wait from numcodecs import blosc @@ -31,15 +31,15 @@ from ome_zarr.writer import write_multiscales_metadata from skimage.io import imread as sk_imread -from aind_smartspim_data_transformation.compress.zarr_utilities import ( +from aind_smartspim_data_transformation._shared.types import ( ArrayLike, PathLike, - pad_array_n_d, ) from aind_smartspim_data_transformation.compress.zarr_writer import ( BlockedArrayWriter, ) from aind_smartspim_data_transformation.io import PngReader +from aind_smartspim_data_transformation.io.utils import pad_array_n_d def set_dask_config(dask_folder: str): @@ -646,66 +646,62 @@ def smartspim_channel_zarr_writer( metadata=_get_pyramid_metadata(), ) - performance_report_path = f"{output_path}/report_{stack_name}.html" + # performance_report_path = f"{output_path}/report_{stack_name}.html" start_time = time.time() # Writing zarr and performance report - with performance_report(filename=performance_report_path): - logger.info(f"{'='*40}Writing channel {channel_name}{'='*40}") + # with performance_report(filename=performance_report_path): + logger.info(f"{'='*40}Writing channel {channel_name}{'='*40}") - # Writing zarr - block_shape = list( - BlockedArrayWriter.get_block_shape( - arr=image_data, target_size_mb=12800 # 51200, - ) + # Writing zarr + block_shape = list( + BlockedArrayWriter.get_block_shape( + arr=image_data, target_size_mb=12800 # 51200, ) + ) - # Formatting to 5D block shape - block_shape = ([1] * (5 - len(block_shape))) + block_shape - written_pyramid = [] - pyramid_group = None - - # Writing multiple levels - for level in range(n_lvls): - if not level: - array_to_write = image_data - - else: - # It's faster to write the scale and then read it back - # to compute the next scale - previous_scale = da.from_zarr( - pyramid_group, pyramid_group.chunks - ) - new_scale_factor = ( - [1] * (len(previous_scale.shape) - len(scale_factor)) - ) + scale_factor - - previous_scale_pyramid, _ = compute_pyramid( - data=previous_scale, - scale_axis=new_scale_factor, - chunks=image_data.chunksize, - n_lvls=2, - ) - array_to_write = previous_scale_pyramid[-1] - - logger.info(f"[level {level}]: pyramid level: {array_to_write}") - - # Create the scale dataset - pyramid_group = new_channel_group.create_dataset( - name=level, - shape=array_to_write.shape, - chunks=array_to_write.chunksize, - dtype=array_to_write.dtype, - compressor=writing_options, - dimension_separator="/", - overwrite=True, + # Formatting to 5D block shape + block_shape = ([1] * (5 - len(block_shape))) + block_shape + written_pyramid = [] + pyramid_group = None + + # Writing multiple levels + for level in range(n_lvls): + if not level: + array_to_write = image_data + + else: + # It's faster to write the scale and then read it back + # to compute the next scale + previous_scale = da.from_zarr(pyramid_group, pyramid_group.chunks) + new_scale_factor = ( + [1] * (len(previous_scale.shape) - len(scale_factor)) + ) + scale_factor + + previous_scale_pyramid, _ = compute_pyramid( + data=previous_scale, + scale_axis=new_scale_factor, + chunks=image_data.chunksize, + n_lvls=2, ) + array_to_write = previous_scale_pyramid[-1] + + logger.info(f"[level {level}]: pyramid level: {array_to_write}") + + # Create the scale dataset + pyramid_group = new_channel_group.create_dataset( + name=level, + shape=array_to_write.shape, + chunks=array_to_write.chunksize, + dtype=array_to_write.dtype, + compressor=writing_options, + dimension_separator="/", + overwrite=True, + ) - # Block Zarr Writer - BlockedArrayWriter.store( - array_to_write, pyramid_group, block_shape - ) - written_pyramid.append(array_to_write) + # Block Zarr Writer + BlockedArrayWriter.store(array_to_write, pyramid_group, block_shape) + written_pyramid.append(array_to_write) end_time = time.time() logger.info(f"Time to write the dataset: {end_time - start_time}") diff --git a/src/aind_smartspim_data_transformation/compress/zarr_utilities.py b/src/aind_smartspim_data_transformation/compress/zarr_utilities.py deleted file mode 100644 index 0fa9620..0000000 --- a/src/aind_smartspim_data_transformation/compress/zarr_utilities.py +++ /dev/null @@ -1,640 +0,0 @@ -""" -Module to the zarr utilities -""" - -import multiprocessing -import os -import time -from pathlib import Path -from typing import Optional, Tuple, Union - -import dask -import dask.array as da -import numpy as np -import pims -from dask.array import concatenate, pad -from dask.array.core import Array -from dask.base import tokenize -from natsort import natsorted -from numcodecs import blosc -from skimage.io import imread as sk_imread - -PathLike = Union[str, Path] -ArrayLike = Union[dask.array.core.Array, np.ndarray] -blosc.use_threads = False - - -def add_leading_dim(data: ArrayLike) -> ArrayLike: - """ - Adds a leading dimension - - Parameters - ------------------------ - - data: ArrayLike - Input array that will have the - leading dimension - - Returns - ------------------------ - - ArrayLike: - Array with the new dimension in front. - """ - return data[None, ...] - - -def pad_array_n_d(arr: ArrayLike, dim: int = 5) -> ArrayLike: - """ - Pads a daks array to be in a 5D shape. - - Parameters - ------------------------ - - arr: ArrayLike - Dask/numpy array that contains image data. - dim: int - Number of dimensions that the array will be padded - - Returns - ------------------------ - ArrayLike: - Padded dask/numpy array. - """ - if dim > 5: - raise ValueError("Padding more than 5 dimensions is not supported.") - - while arr.ndim < dim: - arr = arr[np.newaxis, ...] - return arr - - -def read_image_directory_structure(folder_dir: PathLike) -> dict: - """ - Creates a dictionary representation of all the images - saved by folder/col_N/row_N/images_N.[file_extention] - - Parameters - ------------------------ - folder_dir:PathLike - Path to the folder where the images are stored - - Returns - ------------------------ - dict: - Dictionary with the image representation where: - {channel_1: ... {channel_n: {col_1: ... - col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } - """ - - directory_structure = {} - folder_dir = Path(folder_dir) - - channel_paths = natsorted( - [ - folder_dir.joinpath(folder) - for folder in os.listdir(folder_dir) - if os.path.isdir(folder_dir.joinpath(folder)) - ] - ) - - for channel_idx in range(len(channel_paths)): - directory_structure[channel_paths[channel_idx]] = {} - - cols = natsorted(os.listdir(channel_paths[channel_idx])) - - for col in cols: - possible_col = channel_paths[channel_idx].joinpath(col) - - if os.path.isdir(possible_col): - directory_structure[channel_paths[channel_idx]][col] = {} - - rows = natsorted(os.listdir(possible_col)) - - for row in rows: - possible_row = ( - channel_paths[channel_idx].joinpath(col).joinpath(row) - ) - - if os.path.isdir(possible_row): - directory_structure[channel_paths[channel_idx]][col][ - row - ] = natsorted(os.listdir(possible_row)) - - return directory_structure - - -def lazy_tiff_reader( - filename: PathLike, - shape: Optional[Tuple[int]] = None, - dtype: Optional[type] = None, -): - """ - Creates a dask array to read an image located in a specific path. - - Parameters - ------------------------ - - filename: PathLike - Path to the image - - Returns - ------------------------ - - dask.array.core.Array - Array representing the image data - """ - name = "imread-%s" % tokenize(filename, map(os.path.getmtime, filename)) - - if dtype is None or shape is None: - try: - with pims.open(filename) as imgs: - dtype = np.dtype(imgs.pixel_type) - shape = (1,) + (len(imgs),) + imgs.frame_shape - except: - return None - - key = [(name,) + (0,) * len(shape)] - value = [(add_leading_dim, (sk_imread, filename))] - dask_arr = dict(zip(key, value)) - chunks = tuple((d,) for d in shape) - - return Array(dask_arr, name, chunks, dtype) - - -def fix_image_diff_dims( - new_arr: ArrayLike, chunksize: Tuple[int], len_chunks: int, work_axis: int -) -> ArrayLike: - """ - Fixes the array dimension to match the shape of - the chunksize. - - Parameters - ------------------------ - - new_arr: ArrayLike - Array to be fixed - - chunksize: Tuple[int] - Chunksize of the original array - - len_chunks: int - Length of the chunksize. Used as a - parameter to avoid computing it - multiple times - - work_axis: int - Axis to concatenate. If the different - axis matches this one, there is no need - to fix the array dimension - - Returns - ------------------------ - - ArrayLike - Array with the new dimensions - """ - - zeros_dim = [] - diff_dim = -1 - c = 0 - - for chunk_idx in range(len_chunks): - new_chunk_dim = new_arr.chunksize[chunk_idx] - - if new_chunk_dim != chunksize[chunk_idx]: - c += 1 - diff_dim = chunk_idx - - zeros_dim.append(abs(chunksize[chunk_idx] - new_chunk_dim)) - - if c > 1: - raise ValueError("Block has two different dimensions") - else: - if (diff_dim - len_chunks) == work_axis: - return new_arr - - n_pad = tuple(tuple((0, dim)) for dim in zeros_dim) - new_arr = pad( - new_arr, pad_width=n_pad, mode="constant", constant_values=0 - ).rechunk(chunksize) - - return new_arr - - -def concatenate_dask_arrays( - arr_1: ArrayLike, arr_2: ArrayLike, axis: int -) -> ArrayLike: - """ - Concatenates two arrays in a given - dimension - - Parameters - ------------------------ - - arr_1: ArrayLike - Array 1 that will be concatenated - - arr_2: ArrayLike - Array 2 that will be concatenated - - axis: int - Concatenation axis - - Returns - ------------------------ - - ArrayLike - Concatenated array that contains - arr_1 and arr_2 - """ - - shape_arr_1 = arr_1.shape - shape_arr_2 = arr_2.shape - - if shape_arr_1 != shape_arr_2: - slices = [] - dims = len(shape_arr_1) - - for shape_dim_idx in range(dims): - if shape_arr_1[shape_dim_idx] > shape_arr_2[shape_dim_idx] and ( - shape_dim_idx - dims != axis - ): - raise ValueError( - f""" - Array 1 {shape_arr_1} must have - a smaller shape than array 2 {shape_arr_2} - except for the axis dimension {shape_dim_idx} - {dims} {shape_dim_idx - dims} {axis} - """ - ) - - if shape_arr_1[shape_dim_idx] != shape_arr_2[shape_dim_idx]: - slices.append(slice(0, shape_arr_1[shape_dim_idx])) - - else: - slices.append(slice(None)) - - slices = tuple(slices) - arr_2 = arr_2[slices] - - try: - res = concatenate([arr_1, arr_2], axis=axis) - except ValueError: - raise ValueError( - f""" - Unable to cancat arrays - Shape 1: - {shape_arr_1} shape 2: {shape_arr_2} - """ - ) - - return res - - -def read_chunked_stitched_image_per_channel( - directory_structure: dict, - channel_name: str, - start_slice: int, - end_slice: int, -) -> ArrayLike: - """ - Creates a dask array of the whole image volume - based on image chunks preserving the chunksize. - - Parameters - ------------------ - - directory_structure:dict - dictionary to store paths of images with the following structure: - {channel_1: ... {channel_n: {col_1: ... - col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } - - channel_name : str - Channel name to reconstruct the image volume - - start_slice: int - When using multiprocessing, this is - the start slice the worker will use for - the array concatenation - - end_slice: int - When using multiprocessing, this is - the final slice the worker will use for - the array concatenation - - Returns - ------------------------ - - ArrayLike - Array with the image volume - """ - concat_z_3d_blocks = concat_horizontals = horizontal = None - - # Getting col structure - rows = list(directory_structure.values())[0] - rows_paths = list(rows.keys()) - first = True - - for slice_pos in range(start_slice, end_slice): - idx_col = 0 - idx_row = 0 - - concat_horizontals = None - - for row_name in rows_paths: - idx_row = 0 - horizontal = [] - shape = None - dtype = None - column_names = list( - directory_structure[channel_name][row_name].keys() - ) - n_cols = len(column_names) - - # check_shape = (1, 256, 256, 256) - - for column_name_idx in range(n_cols): - valid_image = True - column_name = column_names[column_name_idx] - last_col = column_name_idx == n_cols - 1 - - if last_col: - shape = None - dtype = None - - try: - slice_name = directory_structure[channel_name][row_name][ - column_name - ][slice_pos] - - filepath = str( - channel_name.joinpath(row_name) - .joinpath(column_name) - .joinpath(slice_name) - ) - - new_arr = lazy_tiff_reader( - filepath - ) # , dtype=dtype, shape=shape) - - if (shape is None or dtype is None) and new_arr: - shape = new_arr.shape - dtype = new_arr.dtype - last_col = False - - except ValueError: - print( - f"{filepath} -> No valid image in ", - slice_name, - slice_pos, - ) - - if valid_image: - horizontal.append(new_arr) - - idx_row += 1 - - # Concatenating horizontally lazy images - print("Before fixing: ", horizontal) - for i in range(len(horizontal)): - if horizontal[i] is not None: - shape = horizontal[i].shape - dtype = horizontal[i].dtype - break - - for i in range(len(horizontal)): - if ( - horizontal[i] is None - and shape is not None - and dtype is not None - ): - horizontal[i] = da.zeros(shape, dtype=dtype) - - print("Before concat: ", horizontal) - horizontal_concat = concatenate(horizontal, axis=-1) - - if not idx_col: - concat_horizontals = horizontal_concat - else: - concat_horizontals = concatenate_dask_arrays( - arr_1=concat_horizontals, arr_2=horizontal_concat, axis=-2 - ) - - idx_col += 1 - - if first: - concat_z_3d_blocks = concat_horizontals - first = False - - else: - concat_z_3d_blocks = concatenate_dask_arrays( - arr_1=concat_z_3d_blocks, arr_2=concat_horizontals, axis=-3 - ) - - return concat_z_3d_blocks, [start_slice, end_slice] - - -def _read_chunked_stitched_image_per_channel(args_dict: dict): - """ - Function used to be dispatched to workers - by using multiprocessing - """ - return read_chunked_stitched_image_per_channel(**args_dict) - - -def channel_parallel_reading( - directory_structure: dict, - channel_idx: int, - workers: Optional[int] = 0, - chunks: Optional[int] = 1, - ensure_parallel: Optional[bool] = True, -) -> ArrayLike: - """ - Creates a dask array of the whole image channel volume based - on image chunks preserving the chunksize and using - multiprocessing. - - Parameters - ------------------------ - - directory_structure: dict - dictionary to store paths of images with the following structure: - {channel_1: ... {channel_n: {col_1: ... - col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } - - channel_name : str - Channel name to reconstruct the image volume - - sample_img: ArrayLike - Image used as guide for the chunksize - - workers: Optional[int] - Number of workers that will be used - for reading the chunked image. - Default value 0, it means that the - available number of cores will be used. - - chunks: Optional[int] - Chunksize of the 3D chunked images. - - ensure_parallel: Optional[bool] - True if we want to read the images in - parallel. False, otherwise. - - Returns - ------------------------ - - ArrayLike - Array with the image channel volume - """ - if workers == 0: - workers = multiprocessing.cpu_count() - - cols = list(directory_structure.values())[0] - n_images = len(list(list(cols.values())[0].values())[0]) - # print(f"n_images: {n_images}") - - channel_paths = list(directory_structure.keys()) - dask_array = None - - if n_images < workers and ensure_parallel: - workers = n_images - - if n_images < workers or not ensure_parallel: - dask_array = read_chunked_stitched_image_per_channel( - directory_structure=directory_structure, - channel_name=channel_paths[channel_idx], - start_slice=0, - end_slice=n_images, - )[0] - print(f"No need for parallel reading... {dask_array}") - - else: - images_per_worker = n_images // workers - print( - f"Setting workers to {workers}" - f"- {images_per_worker} - total images: {n_images}" - ) - - # Getting 5 dim image TCZYX - args = [] - start_slice = 0 - end_slice = images_per_worker - - for idx_worker in range(workers): - arg_dict = { - "directory_structure": directory_structure, - "channel_name": channel_paths[channel_idx], - "start_slice": start_slice, - "end_slice": end_slice, - } - - args.append(arg_dict) - - if idx_worker + 1 == workers - 1: - start_slice = end_slice - end_slice = n_images - else: - start_slice = end_slice - end_slice += images_per_worker - - res = [] - with multiprocessing.Pool(workers) as pool: - results = pool.imap( - _read_chunked_stitched_image_per_channel, - args, - chunksize=chunks, - ) - - for pos in results: - res.append(pos) - - for res_idx in range(len(res)): - if not res_idx: - dask_array = res[res_idx][0] - else: - dask_array = concatenate( - [dask_array, res[res_idx][0]], axis=-3 - ) - - print(f"Slides: {res[res_idx][1]}") - - return dask_array - - -def parallel_read_chunked_stitched_multichannel_image( - directory_structure: dict, - workers: Optional[int] = 0, - ensure_parallel: Optional[bool] = True, - divide_channels: Optional[bool] = True, -) -> ArrayLike: - """ - Creates a dask array of the whole image volume based - on image chunks preserving the chunksize and using - multiprocessing. - - Parameters - ------------------------ - - directory_structure: dict - dictionary to store paths of images with the following structure: - {channel_1: ... {channel_n: {col_1: ... - col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } - - sample_img: ArrayLike - Image used as guide for the chunksize - - workers: Optional[int] - Number of workers that will be used - for reading the chunked image. - Default value 0, it means that the - available number of cores will be used. - - ensure_parallel: Optional[bool] - True if we want to read the images in - parallel. False, otherwise. - - Returns - ------------------------ - - ArrayLike - Array with the image channel volume - """ - - multichannel_image = None - - channel_paths = list(directory_structure.keys()) - - multichannels = [] - read_channels = {} - print(f"Channel in directory structure: {channel_paths}") - - for channel_idx in range(len(channel_paths)): - print(f"Reading images from {channel_paths[channel_idx]}") - start_time = time.time() - read_chunked_channel = channel_parallel_reading( - directory_structure, - channel_idx, - workers=workers, - ensure_parallel=ensure_parallel, - ) - end_time = time.time() - - print(f"Time reading single channel image: {end_time - start_time}") - - # Padding to 4D if necessary - ch_name = Path(channel_paths[channel_idx]).name - - read_chunked_channel = pad_array_n_d(read_chunked_channel, 4) - multichannels.append(read_chunked_channel) - read_channels[ch_name] = read_chunked_channel - - if divide_channels: - return read_channels - - if len(multichannels) > 1: - multichannel_image = concatenate(multichannels, axis=0) - else: - multichannel_image = multichannels[0] - - return multichannel_image diff --git a/src/aind_smartspim_data_transformation/compress/zarr_writer.py b/src/aind_smartspim_data_transformation/compress/zarr_writer.py index 96628b9..7041c9b 100644 --- a/src/aind_smartspim_data_transformation/compress/zarr_writer.py +++ b/src/aind_smartspim_data_transformation/compress/zarr_writer.py @@ -23,7 +23,7 @@ def _get_size(shape: Tuple[int, ...], itemsize: int) -> int: """ if any(s <= 0 for s in shape): raise ValueError("shape must be > 0 in all dimensions") - return np.product(shape) * itemsize + return np.prod(shape) * itemsize def _closer_to_target( diff --git a/src/aind_smartspim_data_transformation/io/utils.py b/src/aind_smartspim_data_transformation/io/utils.py index 9349dcb..a75ba60 100644 --- a/src/aind_smartspim_data_transformation/io/utils.py +++ b/src/aind_smartspim_data_transformation/io/utils.py @@ -4,12 +4,11 @@ import json import os -from typing import Optional, Union +from typing import Optional -import dask.array as da import numpy as np -ArrayLike = Union[da.Array, np.ndarray] +from aind_smartspim_data_transformation._shared.types import ArrayLike def add_leading_dim(data: ArrayLike): @@ -29,6 +28,31 @@ def add_leading_dim(data: ArrayLike): return data[None, ...] +def pad_array_n_d(arr: ArrayLike, dim: int = 5) -> ArrayLike: + """ + Pads a daks array to be in a 5D shape. + + Parameters + ------------------------ + + arr: ArrayLike + Dask/numpy array that contains image data. + dim: int + Number of dimensions that the array will be padded + + Returns + ------------------------ + ArrayLike: + Padded dask/numpy array. + """ + if dim > 5: + raise ValueError("Padding more than 5 dimensions is not supported.") + + while arr.ndim < dim: + arr = arr[np.newaxis, ...] + return arr + + def extract_data( arr: ArrayLike, last_dimensions: Optional[int] = None ) -> ArrayLike: diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index ecb24dd..1cb1076 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -2,7 +2,6 @@ import logging import os -import sys from datetime import datetime from pathlib import Path from typing import Iterator, List, Optional @@ -11,7 +10,6 @@ BasicJobSettings, GenericEtl, JobResponse, - get_parser, ) from numcodecs.blosc import Blosc from pydantic import Field @@ -98,13 +96,8 @@ def _get_compressor(self) -> Blosc: """ if self.job_settings.compressor_name == CompressorName.BLOSC: return Blosc(**self.job_settings.compressor_kwargs) - else: - # TODO: This is validated during the construction of JobSettings, - # so we can probably just remove this exception. - raise Exception( - f"Unknown compressor. Please select one of " - f"{[c for c in CompressorName]}" - ) + + return None @staticmethod def _compress_and_write_channels( @@ -135,6 +128,7 @@ def _compress_and_write_channels( n_workers=n_workers, processes=True, ) + print(f"Instantiated client: {client}") try: for delayed_arr, output_path, stack_name in read_channel_stacks: @@ -147,7 +141,6 @@ def _compress_and_write_channels( n_lvls=4, channel_name=output_path.stem, stack_name=stack_name, - client=client, logger=logging, writing_options=compressor, ) @@ -211,28 +204,3 @@ def run_job(self) -> JobResponse: message=f"Job finished in: {job_end_time-job_start_time}", data=None, ) - - -def main(): - """Main function""" - sys_args = sys.argv[1:] - parser = get_parser() - cli_args = parser.parse_args(sys_args) - if cli_args.job_settings is not None: - job_settings = SmartspimJobSettings.model_validate_json( - cli_args.job_settings - ) - elif cli_args.config_file is not None: - job_settings = SmartspimJobSettings.from_config_file( - cli_args.config_file - ) - else: - # Construct settings from env vars - job_settings = SmartspimJobSettings() - job = SmartspimCompressionJob(job_settings=job_settings) - job_response = job.run_job() - logging.info(job_response.model_dump_json()) - - -if __name__ == "__main__": - main() diff --git a/tests/compress/__init__.py b/tests/compress/__init__.py new file mode 100644 index 0000000..e513071 --- /dev/null +++ b/tests/compress/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the compression scripts""" diff --git a/tests/compress/test_dask_utils.py b/tests/compress/test_dask_utils.py new file mode 100644 index 0000000..772a6ac --- /dev/null +++ b/tests/compress/test_dask_utils.py @@ -0,0 +1,322 @@ +"""Tests dask utils""" + +import unittest +from unittest.mock import MagicMock, patch + +from distributed import Client + +from aind_smartspim_data_transformation.compress import dask_utils + + +class DaskUtilsTest(unittest.TestCase): + """Class for testing the zarr writer""" + + def test_get_local_deployment(self): + """Tests getting a deployment""" + deployment = dask_utils.get_deployment() + + self.assertEqual(dask_utils.Deployment.LOCAL.value, deployment) + + @patch.dict("os.environ", {"SLURM_JOBID": "000"}) + def test_get_allen_deploymet(self): + """Tests getting a deployment on the Allen HPC""" + deployment = dask_utils.get_deployment() + + self.assertEqual(dask_utils.Deployment.SLURM.value, deployment) + + @patch("aind_smartspim_data_transformation.compress.dask_utils.get_client") + def test_get_local_client(self, mock_client: MagicMock): + """Tests getting a local client""" + mock_client.return_value = (Client, 0) + + deployment = dask_utils.get_deployment() + client, _ = dask_utils.get_client( + deployment=deployment, + worker_options=0, + n_workers=1, + processes=True, + ) + + self.assertEqual(client, Client) + + def test_get_client_fail(self): + """Tests fail getting a local client""" + + with self.assertRaises(NotImplementedError): + dask_utils.get_client( + deployment="UnknownDeployment", + worker_options=0, + n_workers=1, + processes=True, + ) + + @patch.dict("os.environ", {"SLURM_JOBID": "000"}) + @patch("distributed.Client") + def test_get_slurm_client_mpi_failure(self, mock_client: MagicMock): + """Tests getting a slurm client""" + mock_client.return_value = (Client, 0) + + deployment = dask_utils.get_deployment() + + with self.assertRaises(ImportError): + dask_utils.get_client( + deployment=deployment, + worker_options=0, + n_workers=1, + processes=True, + ) + + @patch.dict("os.environ", {"SLURM_JOBID": "000"}) + @patch( + "aind_smartspim_data_transformation.compress.dask_utils.DASK_MPI_INSTALLED", + new=True, + ) + @patch.dict("sys.modules", {"dask_mpi": None}) + @patch("aind_smartspim_data_transformation.compress.dask_utils.get_client") + @patch("distributed.Client") + def test_get_slurm_client_mpi( + self, mock_client: MagicMock, mock_get_client: MagicMock + ): + """Tests getting a slurm client""" + mock_client.return_value = (Client, 0) + mock_get_client.return_value = Client + + deployment = dask_utils.get_deployment() + slurm_client = dask_utils.get_client( + deployment=deployment, + worker_options=None, + n_workers=1, + processes=True, + ) + + self.assertEqual(slurm_client, Client) + mock_get_client.assert_called_once_with( + deployment=deployment, + worker_options=None, + n_workers=1, + processes=True, + ) + + @patch("requests.delete") + def test_cancel_slurm_job_success(self, mock_requests_delete: MagicMock): + """ + Tests cancelling a slurm job successfully + """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_requests_delete.return_value = mock_response + + job_id = "123" + api_url = "http://myhost:80/api/slurm/v0.0.36" + headers = {"Authorization": "Bearer token"} + + response = dask_utils.cancel_slurm_job(job_id, api_url, headers) + + self.assertEqual(response.status_code, mock_response.status_code) + mock_requests_delete.assert_called_once_with( + f"{api_url}/job/{job_id}", headers=headers + ) + + @patch("requests.delete") + def test_cancel_slurm_job_failure(self, mock_requests_delete: MagicMock): + """ + Tests cancelling slurm job with + mock job failure + """ + mock_response = MagicMock() + mock_response.status_code = 500 + mock_requests_delete.return_value = mock_response + + job_id = "123" + api_url = "http://myhost:80/api/slurm/v0.0.36" + headers = {"Authorization": "Bearer token"} + + response = dask_utils.cancel_slurm_job(job_id, api_url, headers) + self.assertEqual(response.status_code, mock_response.status_code) + + mock_requests_delete.assert_called_once_with( + f"{api_url}/job/{job_id}", headers=headers + ) + + @patch.dict( + "os.environ", + { + "SLURM_JOBID": "000", + "HPC_HOST": "example.com", + "HPC_PORT": "80", + "HPC_API_ENDPOINT": "api", + "HPC_USERNAME": "username", + "HPC_PASSWORD": "password", + "HPC_TOKEN": "token", + }, + ) + @patch("os.getenv") + @patch( + "aind_smartspim_data_transformation.compress.dask_utils.cancel_slurm_job" + ) + def test_cleanup_slurm_with_env_vars( + self, mock_cancel_slurm_job: MagicMock, mock_getenv: MagicMock + ): + """ + Cleaning up slurm job with + environment variables + """ + mock_getenv.side_effect = lambda x: { + "SLURM_JOBID": "123", + "HPC_HOST": "example.com", + "HPC_PORT": "80", + "HPC_API_ENDPOINT": "api", + "HPC_USERNAME": "username", + "HPC_PASSWORD": "password", + "HPC_TOKEN": "token", + }.get(x) + + # Set up mock response for cancel_slurm_job + mock_response = MagicMock() + mock_response.status_code = 200 + mock_cancel_slurm_job.return_value = mock_response + + dask_utils._cleanup(deployment=dask_utils.Deployment.SLURM.value) + + mock_cancel_slurm_job.assert_called_once_with( + "123", + "http://example.com:80/api", + { + "X-SLURM-USER-NAME": "username", + "X-SLURM-USER-PASSWORD": "password", + "X-SLURM-USER-TOKEN": "token", + }, + ) + + @patch.dict( + "os.environ", + { + "SLURM_JOBID": "000", + "HPC_HOST": "example.com", + "HPC_PORT": "80", + "HPC_API_ENDPOINT": "api", + "HPC_USERNAME": "username", + "HPC_PASSWORD": "password", + "HPC_TOKEN": "token", + }, + ) + @patch("os.getenv") + @patch( + "aind_smartspim_data_transformation.compress.dask_utils.cancel_slurm_job" + ) + @patch("aind_smartspim_data_transformation.compress.dask_utils.logging") + def test_cleanup_slurm_with_env_vars_failed( + self, + mock_logging: MagicMock, + mock_cancel_slurm_job: MagicMock, + mock_getenv: MagicMock, + ): + """ + Tests failure cleaning up slurm job with + environment variables + """ + mock_getenv.side_effect = lambda x: { + "SLURM_JOBID": "123", + "HPC_HOST": "example.com", + "HPC_PORT": "80", + "HPC_API_ENDPOINT": "api", + "HPC_USERNAME": "username", + "HPC_PASSWORD": "password", + "HPC_TOKEN": "token", + }.get(x) + + # Set up mock response for cancel_slurm_job + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.text = "test" + mock_cancel_slurm_job.return_value = mock_response + + dask_utils._cleanup(deployment=dask_utils.Deployment.SLURM.value) + + mock_cancel_slurm_job.assert_called_once_with( + "123", + "http://example.com:80/api", + { + "X-SLURM-USER-NAME": "username", + "X-SLURM-USER-PASSWORD": "password", + "X-SLURM-USER-TOKEN": "token", + }, + ) + mock_logging.error.assert_called_once_with( + "Failed to cancel SLURM job 123: test" + ) + + @patch.dict("os.environ", {"SLURM_JOBID": "000"}) + @patch("aind_smartspim_data_transformation.compress.dask_utils.logging") + @patch("os.getenv") + def test_cleanup_slurm_without_env_vars( + self, mock_getenv: MagicMock, mock_logging: MagicMock + ): + """ + Tests cleaning up slurm without + environment variables + """ + mock_getenv.side_effect = lambda x: { + "SLURM_JOBID": "123", + "HPC_HOST": "example.com", + "HPC_PORT": "80", + "HPC_API_ENDPOINT": "api", + "HPC_USERNAME": "username", + "HPC_PASSWORD": "password", + "HPC_TOKEN": "token", + }.get(x) + + dask_utils._cleanup(deployment=dask_utils.Deployment.SLURM.value) + mock_logging.error.assert_called_once_with( + "Failed to get SLURM env vars to cleanup: 'HPC_HOST'" + ) + + @patch("os.getenv") + @patch("aind_smartspim_data_transformation.compress.dask_utils.logging") + def test_cleanup_local( + self, mock_logging: MagicMock, mock_getenv: MagicMock + ): + """ + Tests cleaning up a local cluster + """ + mock_getenv.return_value = None + + dask_utils._cleanup(deployment=dask_utils.Deployment.LOCAL.value) + + mock_logging.info.assert_not_called() + + @patch("os.getenv") + @patch( + "aind_smartspim_data_transformation.compress.dask_utils.LOGGER.info" + ) + @patch("distributed.Client") + def test_log_dashboard_address( + self, + mock_Client: MagicMock, + mock_logger_info: MagicMock, + mock_getenv: MagicMock, + ): + """ + Tests log dashboard address + """ + mock_getenv.return_value = "testuser" + + mock_client = MagicMock() + mock_Client.return_value = mock_client + + mock_client.scheduler_info.return_value = { + "services": {"dashboard": 8787} + } + + mock_client.run_on_scheduler.return_value = "scheduler-host" + + dask_utils.log_dashboard_address(client=mock_client) + + mock_logger_info.assert_called_once_with( + "To access the dashboard, run the following in " + "a terminal: ssh -L 8787:scheduler-host:8787 testuser@hpc-login " + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/compress/test_zarr_writer.py b/tests/compress/test_zarr_writer.py new file mode 100644 index 0000000..c0c768d --- /dev/null +++ b/tests/compress/test_zarr_writer.py @@ -0,0 +1,61 @@ +""" +Tests for the zarr writer +""" + +import unittest + +import dask.array as da +import numpy as np + +from aind_smartspim_data_transformation.compress import zarr_writer + + +class ZarrWriterTest(unittest.TestCase): + """Class for testing the zarr writer""" + + def test_get_size(self): + """ + Tests get size method + """ + test_arr = da.zeros((2, 2), dtype=np.uint8) + + expected_result = np.prod(test_arr.shape) * test_arr.itemsize + + arr_size = zarr_writer._get_size( + shape=test_arr.shape, itemsize=test_arr.itemsize + ) + self.assertEqual(expected_result, arr_size) + + def test_get_size_fail(self): + """ + Tests get size failure + """ + test_arr = da.zeros((0, 2), dtype=np.uint8) + + with self.assertRaises(ValueError): + zarr_writer._get_size( + shape=test_arr.shape, itemsize=test_arr.itemsize + ) + + def test_closer_to_target(self): + """Tests closer to target function""" + test_arr_1 = da.zeros((10, 10), dtype=np.uint8) + test_arr_2 = da.zeros((20, 20), dtype=np.uint8) + target_bytes = 4 + + shape_close_target = zarr_writer._closer_to_target( + shape1=test_arr_1.shape, + shape2=test_arr_2.shape, + target_bytes=target_bytes, + itemsize=test_arr_1.itemsize, + ) + + shape_close_target_2 = zarr_writer._closer_to_target( + shape1=test_arr_2.shape, + shape2=test_arr_1.shape, + target_bytes=target_bytes, + itemsize=test_arr_1.itemsize, + ) + + self.assertEqual(test_arr_1.shape, shape_close_target) + self.assertEqual(test_arr_1.shape, shape_close_target_2) diff --git a/tests/io/__init__.py b/tests/io/__init__.py new file mode 100644 index 0000000..b0ace37 --- /dev/null +++ b/tests/io/__init__.py @@ -0,0 +1 @@ +"""Unit tests of io package""" diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py new file mode 100644 index 0000000..c715e1f --- /dev/null +++ b/tests/io/test_utils.py @@ -0,0 +1,84 @@ +""" +Unit tests of io utilities +""" + +import os +import unittest +from pathlib import Path + +import dask.array as da +import numpy as np + +from aind_smartspim_data_transformation.io import utils + + +class IoUtilitiesTest(unittest.TestCase): + """Class for testing the io utilities""" + + def setUp(self): + """Setting up temporary folder directory""" + current_path = Path(os.path.abspath(__file__)).parent + self.test_local_json_path = current_path.joinpath( + "../resources/local_json.json" + ) + + def test_add_leading_dim(self): + """ + Tests that a new dimension is added + to the array. + """ + test_arr = da.zeros((2, 2), dtype=np.uint8) + transformed_arr = utils.add_leading_dim(data=test_arr) + + self.assertEqual(test_arr.ndim + 1, transformed_arr.ndim) + + def test_extract_data(self): + """ + Tests the array data is extracted + when there are expanded dimensions. + """ + test_arr = da.zeros((1, 1, 1, 2, 2), dtype=np.uint8) + transformed_arr_no_lead = utils.extract_data(arr=test_arr) + transformed_arr_with_lead = utils.extract_data( + arr=test_arr, last_dimensions=3 + ) + + self.assertEqual(2, transformed_arr_no_lead.ndim) + self.assertEqual(test_arr.shape[-2:], transformed_arr_no_lead.shape) + + self.assertEqual(3, transformed_arr_with_lead.ndim) + self.assertEqual(test_arr.shape[-3:], transformed_arr_with_lead.shape) + + def test_extract_data_fail(self): + """ + Tests failure of extract data + """ + test_arr = da.zeros((2, 2), dtype=np.uint8) + + with self.assertRaises(ValueError): + utils.extract_data(arr=test_arr, last_dimensions=3) + + def test_pad_array(self): + """Tests padding an array""" + test_arr = da.zeros((2, 2), dtype=np.uint8) + padded_test_arr = utils.pad_array_n_d(arr=test_arr) + + self.assertEqual(5, padded_test_arr.ndim) + + padded_test_arr = utils.pad_array_n_d(arr=test_arr, dim=-1) + self.assertEqual(test_arr.ndim, padded_test_arr.ndim) + + def test_pad_array_fail(self): + """Tests padding an array""" + test_arr = da.zeros((2, 2), dtype=np.uint8) + + with self.assertRaises(ValueError): + utils.pad_array_n_d(arr=test_arr, dim=6) + + def test_read_json_as_dict(self): + """ + Tests successful reading of a dictionary + """ + expected_result = {"some_key": "some_value"} + result = utils.read_json_as_dict(self.test_local_json_path) + self.assertEqual(expected_result, result) diff --git a/tests/resources/local_json.json b/tests/resources/local_json.json new file mode 100644 index 0000000..fecb592 --- /dev/null +++ b/tests/resources/local_json.json @@ -0,0 +1,3 @@ +{ + "some_key": "some_value" +} \ No newline at end of file diff --git a/tests/test_smartspim_job.py b/tests/test_smartspim_job.py index fab1289..5c60f46 100644 --- a/tests/test_smartspim_job.py +++ b/tests/test_smartspim_job.py @@ -1,6 +1,8 @@ """Tests for the SmartSPIM data transfer""" import os +import shutil +import tempfile import unittest from pathlib import Path @@ -22,9 +24,12 @@ class SmartspimCompressionTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """Setup basic job settings and job that can be used across tests""" + # Folder to test the zarr writing from PNGs + cls.temp_folder = tempfile.mkdtemp(prefix="unittest_") + basic_job_settings = SmartspimJobSettings( input_source=DATA_DIR, - output_directory=Path("output_dir"), + output_directory=Path(cls.temp_folder), ) cls.basic_job_settings = basic_job_settings cls.basic_job = SmartspimCompressionJob( @@ -63,7 +68,7 @@ def test_compressor(self): current_blosc = Blosc(**self.basic_job.job_settings.compressor_kwargs) self.assertEqual(compressor, current_blosc) - def test_failing_compressor(self): + def test_getting_compressor_fail(self): """Test failed compression with Blosc""" with self.assertRaises(Exception): @@ -77,9 +82,15 @@ def test_failing_compressor(self): failed_basic_job_settings = failed_basic_job_settings SmartspimCompressionJob(job_settings=failed_basic_job_settings) - def test_compress_and_write_channels(self): + def test_run_job(self): """Tests SmartSPIM compression and zarr writing""" - pass + self.basic_job.run_job() + + @classmethod + def tearDownClass(cls) -> None: + """Tear down class method to clean up""" + if os.path.exists(cls.temp_folder): + shutil.rmtree(cls.temp_folder, ignore_errors=True) if __name__ == "__main__":