From de3202c044d006de5a57aefdc3e073bcfd7df9af Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Wed, 29 May 2024 20:02:47 +0000 Subject: [PATCH 1/9] updating image data transformations to convert from pngs to zarr --- pyproject.toml | 13 +- scripts/singularity_build.def | 12 + .../__init__.py | 2 +- .../io/__init__.py | 7 + .../io/_io.py | 643 ++++++++++++++ .../io/utils.py | 112 +++ .../models.py | 11 + .../png_to_zarr.py | 794 ++++++++++++++++++ .../smartspim_job.py | 236 ++++++ .../zarr_utilities.py | 621 ++++++++++++++ .../zarr_writer.py | 218 +++++ 11 files changed, 2667 insertions(+), 2 deletions(-) create mode 100644 scripts/singularity_build.def create mode 100644 src/aind_smartspim_data_transformation/io/__init__.py create mode 100644 src/aind_smartspim_data_transformation/io/_io.py create mode 100644 src/aind_smartspim_data_transformation/io/utils.py create mode 100644 src/aind_smartspim_data_transformation/models.py create mode 100644 src/aind_smartspim_data_transformation/png_to_zarr.py create mode 100644 src/aind_smartspim_data_transformation/smartspim_job.py create mode 100644 src/aind_smartspim_data_transformation/zarr_utilities.py create mode 100644 src/aind_smartspim_data_transformation/zarr_writer.py diff --git a/pyproject.toml b/pyproject.toml index decdb68..d4ce1a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "aind-smartspim-data-transformation" description = "Generated from aind-library-template" license = {text = "MIT"} -requires-python = ">=3.7" +requires-python = ">=3.9" authors = [ {name = "Allen Institute for Neural Dynamics"} ] @@ -17,6 +17,17 @@ readme = "README.md" dynamic = ["version"] dependencies = [ + 'aind-data-transformation>=0.0.18', + 'zarr==2.18.2', + 'numcodecs==0.11.0', + 'dask-image==2024.5.3', + 'xarray_multiscale==1.1.0', + 'bokeh==2.4.2', + 'pims==0.6.1', + 'dask[distributed]==2024.5.1', + 'ome-zarr==0.8.2', + 'imagecodecs[all]==2023.3.16', + 
'natsort==8.4.0', ] [project.optional-dependencies] diff --git a/scripts/singularity_build.def b/scripts/singularity_build.def new file mode 100644 index 0000000..b2f9e1a --- /dev/null +++ b/scripts/singularity_build.def @@ -0,0 +1,12 @@ +Bootstrap: docker +From: python:3.10-bullseye +Stage: build + +%setup + # Copy project directory into container + cp -R . ${SINGULARITY_ROOTFS}/aind-smartspim-data-transformation + +%post + cd ${SINGULARITY_ROOTFS}/aind-smartspim-data-transformation + pip install . --no-cache-dir + rm -rf ${SINGULARITY_ROOTFS}/aind-smartspim-data-transformation \ No newline at end of file diff --git a/src/aind_smartspim_data_transformation/__init__.py b/src/aind_smartspim_data_transformation/__init__.py index d0a8547..831e353 100644 --- a/src/aind_smartspim_data_transformation/__init__.py +++ b/src/aind_smartspim_data_transformation/__init__.py @@ -1,2 +1,2 @@ """Init package""" -__version__ = "0.0.0" +__version__ = "0.0.1" diff --git a/src/aind_smartspim_data_transformation/io/__init__.py b/src/aind_smartspim_data_transformation/io/__init__.py new file mode 100644 index 0000000..baff90e --- /dev/null +++ b/src/aind_smartspim_data_transformation/io/__init__.py @@ -0,0 +1,7 @@ +""" +Input and output operations +""" + +# flake8: noqa: F403 +from ._io import * +from .utils import * diff --git a/src/aind_smartspim_data_transformation/io/_io.py b/src/aind_smartspim_data_transformation/io/_io.py new file mode 100644 index 0000000..6b0d216 --- /dev/null +++ b/src/aind_smartspim_data_transformation/io/_io.py @@ -0,0 +1,643 @@ +""" +Module that defines base Image Reader class +and the available metrics +""" + +import os +from abc import ABC, abstractmethod, abstractproperty +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import dask.array as da +import imageio as iio +import numpy as np +import pims +import tifffile +import zarr +from dask.array.core import Array +from dask.base import tokenize +from dask_image.imread 
import imread as daimread +from skimage.io import imread as sk_imread + +from .utils import add_leading_dim, read_json_as_dict + +""" +File that defines the constants used +in the package +""" + +from pathlib import Path +from typing import Union + +import dask.array as da +import numpy as np + +# IO types +PathLike = Union[str, Path] +ArrayLike = Union[da.Array, np.ndarray] + +class ImageReader(ABC): + """ + Abstract class to create image readers + classes + """ + + def __init__(self, data_path: PathLike) -> None: + """ + Class constructor of image reader. + + Parameters + ------------------------ + data_path: PathLike + Path where the image is located + + """ + + self.__data_path = data_path + super().__init__() + + @abstractmethod + def as_dask_array(self, chunk_size: Optional[Any] = None) -> da.Array: + """ + Abstract method to return the image as a dask array. + + Parameters + ------------------------ + chunk_size: Optional[Any] + If provided, the image will be rechunked to the desired + chunksize + + Returns + ------------------------ + da.Array + Dask array with the image + + """ + pass + + @abstractmethod + def as_numpy_array(self) -> np.ndarray: + """ + Abstract method to return the image as a numpy array. + + Returns + ------------------------ + np.ndarray + Numpy array with the image + + """ + pass + + @abstractmethod + def metadata(self) -> Dict: + """ + Abstract method that return the image metadata. + + Returns + ------- + Dict + Dictionary with image metadata + """ + pass + + @abstractmethod + def close_handler(self) -> None: + """ + Abstract method to close the image hander when it's necessary. + + """ + pass + + @abstractmethod + def indexing(self, xv: np.array, yv: np.array) -> ArrayLike: + """ + Abstract method to index arrays + """ + pass + + @abstractproperty + def shape(self) -> Tuple: + """ + Abstract method to return the shape of the image. 
+ + Returns + ------------------------ + Tuple + Tuple with the shape of the image + + """ + pass + + @abstractproperty + def chunks(self) -> Tuple: + """ + Abstract method to return the chunks of the image if it's possible. + + Returns + ------------------------ + Tuple + Tuple with the chunks of the image + + """ + pass + + @property + def data_path(self) -> PathLike: + """ + Getter to return the path where the image is located. + + Returns + ------------------------ + PathLike + Path of the image + + """ + return self.__data_path + + @data_path.setter + def data_path(self, new_data_path: PathLike) -> None: + """ + Setter of the path attribute where the image is located. + + Parameters + ------------------------ + new_data_path: PathLike + New path of the image + + """ + self.__data_path = new_data_path + + +class OMEZarrReader(ImageReader): + """ + OMEZarr reader class + """ + + def __init__( + self, + data_path: PathLike, + multiscale: Optional[str] = "0", + ) -> None: + """ + Class constructor of image OMEZarr reader. + + Parameters + ------------------------ + data_path: PathLike + Path where the image is located + + multiscale: Optional[str] + Desired multiscale to read from the image. 
Default: "0" + which is supposed to be the highest resolution + + """ + + # Adding multiscale to path + if isinstance(data_path, str): + data_path = f"{data_path}/{multiscale}" + + else: + data_path = data_path.joinpath(str(multiscale)) + + super().__init__(data_path=data_path) + self.lazy_image = da.from_zarr(self.data_path) + + def indexing(self, xv: np.array, yv: np.array) -> ArrayLike: + """ + Indexes arrays using X and Y locations + generated by a meshgrid + + Parameters + ---------- + xv: np.array + X locations generated by a meshgrid + + yv: np.array + y locations generated by a meshgrid + + Returns + ---------- + Data in provided locations + """ + + return self.lazy_image.vindex(xv, yv) + + def as_dask_array(self, chunk_size: Optional[Any] = None) -> da.array: + """ + Method to return the image as a dask array. + + Parameters + ------------------------ + chunk_size: Optional[Any] + If provided, the image will be rechunked to the desired + chunksize + + Returns + ------------------------ + da.Array + Dask array with the image + + """ + + if chunk_size: + return self.lazy_image.rechunk(chunks=chunk_size) + + return self.lazy_image + + def as_numpy_array(self): + """ + Method to return the image as a numpy array. + + Returns + ------------------------ + np.ndarray + Numpy array with the image + + """ + return zarr.open(self.data_path, "r")[:] + + def metadata(self) -> Dict: + """ + Returns the image metadata. + + Returns + ------- + Dict + Dictionary with image metadata + """ + metadata = {} + # Removing multiscale to path + if isinstance(self.data_path, str): + data_path = Path(self.data_path) + + zattrs_metadata = "" + zarray_metadata = "" + # Checking inside and outside of folder due to dimension separator "." 
or "/" + for path in [data_path, data_path.parent]: + if path.joinpath(".zattrs").exists(): + zattrs_metadata = path.joinpath(".zattrs") + + if path.joinpath(".zarray").exists(): + zarray_metadata = path.joinpath(".zarray") + + print(f"Reading metadata from {zattrs_metadata} and {zarray_metadata}") + + metadata[".zattrs"] = read_json_as_dict(zattrs_metadata) + metadata[".zarray"] = read_json_as_dict(zarray_metadata) + + return metadata + + def close_handler(self) -> None: + """ + Method to close the image hander when it's necessary. + """ + pass + + @property + def shape(self): + """ + Method to return the shape of the image. + + Returns + ------------------------ + Tuple + Tuple with the shape of the image + + """ + return zarr.open(self.data_path, "r").shape + + @property + def chunks(self): + """ + Method to return the chunks of the image. + + Returns + ------------------------ + Tuple + Tuple with the chunks of the image + + """ + return zarr.open(self.data_path, "r").chunks + + +class TiffReader(ImageReader): + """ + TiffReader class + """ + + def __init__(self, data_path: PathLike) -> None: + """ + Class constructor of image Tiff reader. + + Parameters + ------------------------ + data_path: PathLike + Path where the image is located + + """ + super().__init__(data_path) + self.tiff = tifffile.TiffFile(self.data_path) + + def indexing(self, xv: np.array, yv: np.array) -> ArrayLike: + """ + Indexes arrays using X and Y locations + generated by a meshgrid + + Parameters + ---------- + xv: np.array + X locations generated by a meshgrid + + yv: np.array + y locations generated by a meshgrid + + Returns + ---------- + Data in provided locations + """ + + return self.tiff.asarray()[xv, yv] + + def as_dask_array( + self, shape: Optional[Any] = None, dtype: Optional[Any] = None + ) -> da.Array: + """ + Method to return the image as a dask array. 
+ + Parameters + ------------------------ + shape: Optional[Any] + Shape of the image + + Returns + ------------------------ + da.Array + Dask array with the image + + """ + data_path = str(self.data_path) + name = "imread-%s" % tokenize( + data_path, map(os.path.getmtime, data_path) + ) + + if shape is None or dtype is None: + with pims.open(data_path) as imgs: + shape = (1,) + (len(imgs),) + imgs.frame_shape + dtype = np.dtype(imgs.pixel_type) + + key = [(name,) + (0,) * len(shape)] + value = [(add_leading_dim, (sk_imread, data_path))] + dask_arr = dict(zip(key, value)) + chunk_size = tuple((d,) for d in shape) + + return Array(dask_arr, name, chunk_size, dtype) + + def metadata(self) -> Dict: + """ + Returns the image metadata. + + Returns + ------- + Dict + Dictionary with image metadata + """ + metadata = {} + with pims.open(data_path) as imgs: + metadata["shape"] = (1,) + (len(imgs),) + imgs.frame_shape + metadata["dtype"] = np.dtype(imgs.pixel_type) + + return metadata + + def as_numpy_array(self) -> np.ndarray: + """ + Abstract method to return the image as a numpy array. + + Returns + ------------------------ + np.ndarray + Numpy array with the image + + """ + return self.tiff.asarray() + + @property + def shape(self) -> Tuple: + """ + Abstract method to return the shape of the image. + + Returns + ------------------------ + Tuple + Tuple with the shape of the image + + """ + with pims.open(str(self.data_path)) as imgs: + shape = (len(imgs),) + imgs.frame_shape + + return shape + + @property + def chunks(self) -> Tuple: + """ + Abstract method to return the chunks of the image if it's possible. 
+ + Returns + ------------------------ + Tuple + Tuple with the chunks of the image + + """ + return self.tiff.aszarr().chunks + + def close_handler(self) -> None: + """ + Closes image handler + """ + if self.tiff is not None: + self.tiff.close() + self.tiff = None + + def __del__(self) -> None: + """Overriding destructor to safely close image""" + self.close_handler() + + +class PngReader(ImageReader): + """ + PngReader class + """ + + def __init__(self, data_path: PathLike) -> None: + """ + Class constructor of image PNG reader. + + Parameters + ------------------------ + data_path: PathLike + Path where the image is located + + """ + super().__init__(data_path) + + def indexing(self, xv: np.array, yv: np.array) -> ArrayLike: + """ + Indexes arrays using X and Y locations + generated by a meshgrid + + Parameters + ---------- + xv: np.array + X locations generated by a meshgrid + + yv: np.array + y locations generated by a meshgrid + + Returns + ---------- + Data in provided locations + """ + return self.as_numpy_array[xv, yv] + + def as_dask_array(self, chunk_size: Optional[Any] = None) -> da.Array: + """ + Method to return the image as a dask array. + + Parameters + ------------------------ + chunk_size: Optional[Any] + If provided, the image will be rechunked to the desired + chunksize + + Returns + ------------------------ + da.Array + Dask array with the image + + """ + return daimread(self.data_path, arraytype="numpy") + + def as_numpy_array(self) -> np.ndarray: + """ + Abstract method to return the image as a numpy array. + + Returns + ------------------------ + np.ndarray + Numpy array with the image + + """ + return np.array(iio.imread(self.data_path)) + + @property + def shape(self) -> Tuple: + """ + Abstract method to return the shape of the image. 
+ + Returns + ------------------------ + Tuple + Tuple with the shape of the image + + """ + with pims.open(str(self.data_path)) as imgs: + shape = (len(imgs),) + imgs.frame_shape + + return shape + + def metadata(self) -> Dict: + """ + Returns the image metadata. + + Returns + ------- + Dict + Dictionary with image metadata + """ + metadata = {} + with pims.open(data_path) as imgs: + metadata["shape"] = (len(imgs),) + imgs.frame_shape + + return metadata + + @property + def chunks(self) -> Tuple: + """ + Abstract method to return the chunks of the image if it's possible. + + Returns + ------------------------ + Tuple + Tuple with the chunks of the image + + """ + return self.as_dask_array().chunksize + + def close_handler(self) -> None: + """ + Closes image handler + """ + pass + + def __del__(self) -> None: + """Overriding destructor to safely close image""" + self.close_handler() + + +class ImageReaderFactory: + """ + Image reader factory class + """ + + def __init__(self): + """ + Class to create the image reader factory. + """ + self.__extensions = [".zarr", ".tif", ".tiff", ".png"] + self.factory = { + ".zarr": OMEZarrReader, + ".tif": TiffReader, + ".tiff": TiffReader, + ".png": PngReader, + } + + @property + def extensions(self) -> List: + """ + Method to return the allowed format extensions of the images. + + Returns + ------------------------ + List + List with the allowed image format extensions + + """ + return self.__extensions + + def create( + self, data_path: PathLike, parse_path: Optional[bool] = True, **kwargs + ) -> ImageReader: + """ + Method to create the image reader based on the format. + + Parameters + ---------- + data_path: PathLike + Path where the data is located + + parse_path: Optional[bool] + If True, parses the path with the pathlib.Path object. + Not useful when we're trying to access data in S3. 
+ + Returns + ------- + List + List with the allowed image format extensions + + """ + path_cast = Path(data_path) if parse_path else data_path + ext = Path(data_path).suffix + + if ext not in self.__extensions: + raise NotImplementedError(f"File type {ext} not supported") + + return self.factory[ext](path_cast, **kwargs) \ No newline at end of file diff --git a/src/aind_smartspim_data_transformation/io/utils.py b/src/aind_smartspim_data_transformation/io/utils.py new file mode 100644 index 0000000..0925bb4 --- /dev/null +++ b/src/aind_smartspim_data_transformation/io/utils.py @@ -0,0 +1,112 @@ +""" +Utility functions for image readers +""" + +import json +import os +from typing import Optional + +from pathlib import Path +from typing import Union + +import dask.array as da +import numpy as np + +ArrayLike = Union[da.Array, np.ndarray] + + +def add_leading_dim(data: ArrayLike): + """ + Adds a new dimension to existing data. + Parameters + ------------------------ + arr: ArrayLike + Dask/numpy array that contains image data. + + Returns + ------------------------ + ArrayLike: + Padded dask/numpy array. + """ + + return data[None, ...] + + +def extract_data( + arr: ArrayLike, last_dimensions: Optional[int] = None +) -> ArrayLike: + """ + Extracts n dimensional data (numpy array or dask array) + given expanded dimensions. + e.g., (1, 1, 1, 1600, 2000) -> (1600, 2000) + e.g., (1, 1600, 2000) -> (1600, 2000) + e.g., (1, 1, 2, 1600, 2000) -> (2, 1600, 2000) + + Parameters + ------------------------ + arr: ArrayLike + Numpy or dask array with image data. It is assumed + that the last dimensions of the array contain + the information about the image. 
+ + last_dimensions: Optional[int] + If given, it selects the number of dimensions given + stating from the end + of the array + e.g., arr=(1, 1, 1600, 2000) last_dimensions=3 -> (1, 1600, 2000) + e.g., arr=(1, 1, 1600, 2000) last_dimensions=1 -> (2000) + + Raises + ------------------------ + ValueError: + Whenever the last dimensions value is higher + than the array dimensions. + + Returns + ------------------------ + ArrayLike: + Reshaped array with the selected indices. + """ + + if last_dimensions is not None: + if last_dimensions > arr.ndim: + raise ValueError( + "Last dimensions should be lower than array dimensions" + ) + + else: + last_dimensions = len(arr.shape) - arr.shape.count(1) + + dynamic_indices = [slice(None)] * arr.ndim + + for idx in range(arr.ndim - last_dimensions): + dynamic_indices[idx] = 0 + + return arr[tuple(dynamic_indices)] + + +def read_json_as_dict(filepath: str) -> dict: + """ + Reads a json as dictionary. + + Parameters + ------------------------ + + filepath: PathLike + Path where the json is located. + + Returns + ------------------------ + + dict: + Dictionary with the data the json has. 
+ + """ + + dictionary = {} + + if os.path.exists(filepath): + with open(filepath) as json_file: + dictionary = json.load(json_file) + + return dictionary \ No newline at end of file diff --git a/src/aind_smartspim_data_transformation/models.py b/src/aind_smartspim_data_transformation/models.py new file mode 100644 index 0000000..19ecd03 --- /dev/null +++ b/src/aind_smartspim_data_transformation/models.py @@ -0,0 +1,11 @@ +"""Helpful models used in the ephys compression job""" + +from enum import Enum + +from typing import List +from numcodecs import Blosc + +class CompressorName(str, Enum): + """Enum for compression algorithms a user can select""" + + BLOSC = Blosc.codec_id diff --git a/src/aind_smartspim_data_transformation/png_to_zarr.py b/src/aind_smartspim_data_transformation/png_to_zarr.py new file mode 100644 index 0000000..15172e9 --- /dev/null +++ b/src/aind_smartspim_data_transformation/png_to_zarr.py @@ -0,0 +1,794 @@ +""" +SmartSPIM Zarr writer. It takes an input path +where 3D fused chunked files are located, +reconstructs the volume as a dask array and +writes it in OME-Zarr format +""" + +import logging +import multiprocessing +import os +import time +from typing import Dict, Hashable, List, Optional, Sequence, Tuple, Union, cast + +import dask +import dask.array as da +# import matplotlib.pyplot as plt +import numpy as np +import pims +import xarray_multiscale +import zarr +from dask.array.core import Array +from dask.base import tokenize +from dask.distributed import Client, LocalCluster, performance_report +# from distributed import wait +from numcodecs import blosc +from ome_zarr.format import CurrentFormat +from ome_zarr.io import parse_url +from ome_zarr.writer import write_multiscales_metadata +from skimage.io import imread as sk_imread + +from aind_smartspim_data_transformation.zarr_writer import BlockedArrayWriter +from aind_smartspim_data_transformation.zarr_utilities import * +from pathlib import Path +from 
aind_smartspim_data_transformation.io import PngReader +from datetime import datetime +from dask import config as da_cfg + +def set_dask_config(dask_folder: str): + """ + Sets dask configuration + + Parameters + ---------- + dask_folder: str + Path to the temporary directory and local directory + of workers in dask. + """ + # Setting dask configuration + da_cfg.set( + { + "temporary-directory": dask_folder, + "local_directory": dask_folder, + # "tcp-timeout": "300s", + "array.chunk-size": "128MiB", + "distributed.worker.memory.target": 0.90, # 0.85, + "distributed.worker.memory.spill": 0.92, # False,# + "distributed.worker.memory.pause": 0.95, # False,# + "distributed.worker.memory.terminate": 0.98, + } + ) + +def _build_ome( + data_shape: Tuple[int, ...], + image_name: str, + channel_names: Optional[List[str]] = None, + channel_colors: Optional[List[int]] = None, + channel_minmax: Optional[List[Tuple[float, float]]] = None, + channel_startend: Optional[List[Tuple[float, float]]] = None, +) -> Dict: + """ + Create the necessary metadata for an OME tiff image + + Parameters + ---------- + data_shape: A 5-d tuple, assumed to be TCZYX order + image_name: The name of the image + channel_names: The names for each channel + channel_colors: List of all channel colors + channel_minmax: List of all (min, max) pairs of channel pixel + ranges (min value of darkest pixel, max value of brightest) + channel_startend: List of all pairs for rendering where start is + a pixel value of darkness and end where a pixel value is + saturated + + Returns + ------- + Dict: An "omero" metadata object suitable for writing to ome-zarr + """ + if channel_names is None: + channel_names = [f"Channel:{image_name}:{i}" for i in range(data_shape[1])] + if channel_colors is None: + channel_colors = [i for i in range(data_shape[1])] + if channel_minmax is None: + channel_minmax = [(0.0, 1.0) for _ in range(data_shape[1])] + if channel_startend is None: + channel_startend = channel_minmax + + ch = [] 
+ for i in range(data_shape[1]): + ch.append( + { + "active": True, + "coefficient": 1, + "color": f"{channel_colors[i]:06x}", + "family": "linear", + "inverted": False, + "label": channel_names[i], + "window": { + "end": float(channel_startend[i][1]), + "max": float(channel_minmax[i][1]), + "min": float(channel_minmax[i][0]), + "start": float(channel_startend[i][0]), + }, + } + ) + + omero = { + "id": 1, # ID in OMERO + "name": image_name, # Name as shown in the UI + "version": "0.4", # Current version + "channels": ch, + "rdefs": { + "defaultT": 0, # First timepoint to show the user + "defaultZ": data_shape[2] // 2, # First Z section to show the user + "model": "color", # "color" or "greyscale" + }, + } + return omero + + +def _compute_scales( + scale_num_levels: int, + scale_factor: Tuple[float, float, float], + pixelsizes: Tuple[float, float, float], + chunks: Tuple[int, int, int, int, int], + data_shape: Tuple[int, int, int, int, int], + translation: Optional[List[float]] = None, +) -> Tuple[List, List]: + """Generate the list of coordinate transformations and associated chunk options. 
+ + Parameters + ---------- + scale_num_levels: the number of downsampling levels + scale_factor: a tuple of scale factors in each spatial dimension (Z, Y, X) + pixelsizes: a list of pixel sizes in each spatial dimension (Z, Y, X) + chunks: a 5D tuple of integers with size of each chunk dimension (T, C, Z, Y, X) + data_shape: a 5D tuple of the full resolution image's shape + translation: a 5 element list specifying the offset in physical units in each dimension + + Returns + ------- + A tuple of the coordinate transforms and chunk options + """ + transforms = [ + [ + # the voxel size for the first scale level + { + "type": "scale", + "scale": [ + 1.0, + 1.0, + pixelsizes[0], + pixelsizes[1], + pixelsizes[2], + ], + } + ] + ] + if translation is not None: + transforms[0].append({"type": "translation", "translation": translation}) + chunk_sizes = [] + lastz = data_shape[2] + lasty = data_shape[3] + lastx = data_shape[4] + opts = dict( + chunks=( + 1, + 1, + min(lastz, chunks[2]), + min(lasty, chunks[3]), + min(lastx, chunks[4]), + ) + ) + chunk_sizes.append(opts) + if scale_num_levels > 1: + for i in range(scale_num_levels - 1): + last_transform = transforms[-1][0] + last_scale = cast(List, last_transform["scale"]) + transforms.append( + [ + { + "type": "scale", + "scale": [ + 1.0, + 1.0, + last_scale[2] * scale_factor[0], + last_scale[3] * scale_factor[1], + last_scale[4] * scale_factor[2], + ], + } + ] + ) + if translation is not None: + transforms[-1].append( + {"type": "translation", "translation": translation} + ) + lastz = int(np.ceil(lastz / scale_factor[0])) + lasty = int(np.ceil(lasty / scale_factor[1])) + lastx = int(np.ceil(lastx / scale_factor[2])) + opts = dict( + chunks=( + 1, + 1, + min(lastz, chunks[2]), + min(lasty, chunks[3]), + min(lastx, chunks[4]), + ) + ) + chunk_sizes.append(opts) + + return transforms, chunk_sizes + + +def _get_axes_5d( + time_unit: str = "millisecond", space_unit: str = "micrometer" +) -> List[Dict]: + """Generate the list of 
axes. + + Parameters + ---------- + time_unit: the time unit string, e.g., "millisecond" + space_unit: the space unit string, e.g., "micrometer" + + Returns + ------- + A list of dictionaries for each axis + """ + axes_5d = [ + {"name": "t", "type": "time", "unit": f"{time_unit}"}, + {"name": "c", "type": "channel"}, + {"name": "z", "type": "space", "unit": f"{space_unit}"}, + {"name": "y", "type": "space", "unit": f"{space_unit}"}, + {"name": "x", "type": "space", "unit": f"{space_unit}"}, + ] + return axes_5d + + +def write_ome_ngff_metadata( + group: zarr.Group, + arr: da.Array, + image_name: str, + n_lvls: int, + scale_factors: tuple, + voxel_size: tuple, + channel_names: List[str] = None, + channel_colors: List[str] = None, + channel_minmax: List[float] = None, + channel_startend: List[float] = None, + metadata: dict = None, +): + """ + Write OME-NGFF metadata to a Zarr group. + + Parameters + ---------- + group : zarr.Group + The output Zarr group. + arr : array-like + The input array. + image_name : str + The name of the image. + n_lvls : int + The number of pyramid levels. + scale_factors : tuple + The scale factors for downsampling along each dimension. + voxel_size : tuple + The voxel size along each dimension. + channel_names: List[str] + List of channel names to add to the OMENGFF metadata + channel_colors: List[str] + List of channel colors to visualize the data + chanel_minmax: List[float] + List of channel min and max values based on the + image dtype + channel_startend: List[float] + List of the channel start and end metadata. This is + used for visualization. 
The start and end range might be + different from the min max and it is usually inside the + range + metadata: dict + Extra metadata to write in the OME-NGFF metadata + """ + if metadata is None: + metadata = {} + fmt = CurrentFormat() + + # Building the OMERO metadata + ome_json = _build_ome( + arr.shape, + image_name, + channel_names=channel_names, + channel_colors=channel_colors, + channel_minmax=channel_minmax, + channel_startend=channel_startend, + ) + group.attrs["omero"] = ome_json + axes_5d = _get_axes_5d() + coordinate_transformations, chunk_opts = _compute_scales( + n_lvls, scale_factors, voxel_size, arr.chunksize, arr.shape, None + ) + fmt.validate_coordinate_transformations( + arr.ndim, n_lvls, coordinate_transformations + ) + # Setting coordinate transfomations + datasets = [{"path": str(i)} for i in range(n_lvls)] + if coordinate_transformations is not None: + for dataset, transform in zip(datasets, coordinate_transformations): + dataset["coordinateTransformations"] = transform + + # Writing the multiscale metadata + write_multiscales_metadata(group, datasets, fmt, axes_5d, **metadata) + + +def create_smartspim_opts(codec: str, compression_level: int) -> dict: + """ + Creates SmartSPIM options for writing + the OMEZarr. 
+ + Parameters + ---------- + codec: str + Image codec used to write the image + + compression_level: int + Compression level for the image + + Returns + ------- + dict + Dictionary with the blosc compression + to write the SmartSPIM image + """ + return { + "compressor": blosc.Blosc( + cname=codec, clevel=compression_level, shuffle=blosc.SHUFFLE + ) + } + + +def _get_pyramid_metadata(): + """ + Gets the image pyramid metadata + using xarray_multiscale package + """ + return { + "metadata": { + "description": "Downscaling implementation based on the windowed mean of the original array", + "method": "xarray_multiscale.reducers.windowed_mean", + "version": str(xarray_multiscale.__version__), + "args": "[false]", + # No extra parameters were used different + # from the orig. array and scales + "kwargs": {}, + } + } + + +def compute_pyramid( + data: dask.array.core.Array, + n_lvls: int, + scale_axis: Tuple[int], + chunks: Union[str, Sequence[int], Dict[Hashable, int]] = "auto", +) -> List[dask.array.core.Array]: + """ + Computes the pyramid levels given an input full resolution image data + + Parameters + ------------------------ + + data: dask.array.core.Array + Dask array of the image data + + n_lvls: int + Number of downsampling levels + that will be applied to the original image + + scale_axis: Tuple[int] + Scaling applied to each axis + + chunks: Union[str, Sequence[int], Dict[Hashable, int]] + chunksize that will be applied to the multiscales + Default: "auto" + + Returns + ------------------------ + + Tuple[List[dask.array.core.Array], Dict]: + List with the downsampled image(s) and dictionary + with image metadata + """ + + metadata = _get_pyramid_metadata() + + pyramid = xarray_multiscale.multiscale( + array=data, + reduction=xarray_multiscale.reducers.windowed_mean, # func + scale_factors=scale_axis, # scale factors + preserve_dtype=True, + chunks=chunks, + )[:n_lvls] + + return [pyramid_level.data for pyramid_level in pyramid], metadata + + +def 
wavelength_to_hex(wavelength: int) -> int: + """ + Converts wavelength to corresponding color hex value. + + Parameters + ------------------------ + wavelength: int + Integer value representing wavelength. + + Returns + ------------------------ + int: + Hex value color. + """ + # Each wavelength key is the upper bound to a wavelgnth band. + # Wavelengths range from 380-750nm. + # Color map wavelength/hex pairs are generated by sampling along a CIE diagram arc. + color_map = { + 460: 0x690AFE, # Purple + 470: 0x3F2EFE, # Blue-Purple + 480: 0x4B90FE, # Blue + 490: 0x59D5F8, # Blue-Green + 500: 0x5DF8D6, # Green + 520: 0x5AFEB8, # Green + 540: 0x58FEA1, # Green + 560: 0x51FF1E, # Green + 565: 0xBBFB01, # Green-Yellow + 575: 0xE9EC02, # Yellow + 580: 0xF5C503, # Yellow-Orange + 590: 0xF39107, # Orange + 600: 0xF15211, # Orange-Red + 620: 0xF0121E, # Red + 750: 0xF00050, + } # Pink + + for ub, hex_val in color_map.items(): + if wavelength < ub: # Exclusive + return hex_val + return hex_val + + +def add_leading_dim(data: ArrayLike): + """ + Adds a leading dimension + + Parameters + ------------------------ + + data: ArrayLike + Input array that will have the + leading dimension + + Returns + ------------------------ + + ArrayLike: + Array with the new dimension in front. + """ + return data[None, ...] + + +def lazy_tiff_reader( + filename: str, + shape: Optional[Tuple[int]] = None, + dtype: Optional[type] = None, +): + """ + Creates a dask array to read an image located in a specific path. 
+ + Parameters + ------------------------ + + filename: PathLike + Path to the image + + shape: Optional[Tuple[int]] + Optional shape provided to the + reader to avoid accessing the + metadata + + dtype: Optional[type] + Optional array type to the reader + to avoid accessing the metadata + + Returns + ------------------------ + + dask.array.core.Array + Array representing the image data + """ + name = "imread-%s" % tokenize(filename, map(os.path.getmtime, filename)) + + if dtype is None or shape is None: + with pims.open(filename) as imgs: + dtype = np.dtype(imgs.pixel_type) + shape = (1,) + (len(imgs),) + imgs.frame_shape + + key = [(name,) + (0,) * len(shape)] + value = [(add_leading_dim, (sk_imread, filename))] + dask_arr = dict(zip(key, value)) + chunks = tuple((d,) for d in shape) + + return Array(dask_arr, name, chunks, dtype) + + +def smartspim_channel_zarr_writer( + image_data: ArrayLike, + output_path: PathLike, + voxel_size: List[float], + final_chunksize: List[int], + scale_factor: List[int], + n_lvls: int, + channel_name: str, + logger: logging.Logger, + stack_name: str, + writing_options, + client +): + """ + Writes a fused SmartSPIM channel in OMEZarr + format. This channel was read as a lazy array. + + Parameters + ---------- + image_data: ArrayLike + Lazily read SmartSPIM channel data + + output_path: PathLike + Path where we want to write the OMEZarr + channel + + voxel_size: List[float] + Voxel size representing the dataset + + final_chunksize: List[int] + Final chunksize we want to use to write + the final dataset + + codec: str + Image codec for writing the Zarr + + compression_level: int + Compression level + + scale_factor: List[int] + Scale factor per axis. The dimensionality + is organized as ZYX.
+ + n_lvls: int + Number of levels on the pyramid (multiresolution) + for better visualization + + channel_name: str + Channel name we are currently writing + + logger: logging.Logger + Logger object + + """ + # Getting channel color + tmp_channel_name = channel_name.replace(".zarr", "") + em_wav = int(tmp_channel_name.split("_")[-1]) + channel_colors = [wavelength_to_hex(em_wav)] + + # Rechunking dask array + image_data = image_data.rechunk(final_chunksize) + image_data = pad_array_n_d(arr=image_data) + + print(f"About to write {image_data} in {output_path}") + + # Creating Zarr dataset + store = parse_url(path=output_path, mode="w").store + root_group = zarr.group(store=store) + + # Using 1 thread since is in single machine. + # Avoiding the use of multithreaded due to GIL + + if np.issubdtype(image_data.dtype, np.integer): + np_info_func = np.iinfo + + else: + # Floating point + np_info_func = np.finfo + + # Getting min max metadata for the dtype + channel_minmax = [ + (np_info_func(image_data.dtype).min, np_info_func(image_data.dtype).max) + for _ in range(image_data.shape[1]) + ] + + # Setting values for SmartSPIM + # Ideally we would use da.percentile(image_data, (0.1, 95)) + # However, it would take so much time and resources and it is + # not used that much on neuroglancer + channel_startend = [(0.0, 350.0) for _ in range(image_data.shape[1])] + + new_channel_group = root_group.create_group(name=stack_name, overwrite=True) + + # Writing OME-NGFF metadata + write_ome_ngff_metadata( + group=new_channel_group, + arr=image_data, + image_name=channel_name, + n_lvls=n_lvls, + scale_factors=scale_factor, + voxel_size=voxel_size, + channel_names=[channel_name], + channel_colors=channel_colors, + channel_minmax=channel_minmax, + channel_startend=channel_startend, + metadata=_get_pyramid_metadata(), + ) + + performance_report_path = f"{output_path}/report_{stack_name}.html" + + start_time = time.time() + # Writing zarr and performance report + with 
performance_report(filename=performance_report_path): + logger.info(f"{'='*40}Writing channel {channel_name}{'='*40}") + + # Writing zarr + block_shape = list( + BlockedArrayWriter.get_block_shape( + arr=image_data, target_size_mb=12800 # 51200, + ) + ) + + # Formatting to 5D block shape + block_shape = ([1] * (5 - len(block_shape))) + block_shape + written_pyramid = [] + + # Writing multiple levels + for level in range(n_lvls): + if not level: + array_to_write = image_data + + else: + # It's faster to write the scale and then read it back + # to compute the next scale + previous_scale = da.from_zarr(pyramid_group, pyramid_group.chunks) + new_scale_factor = ( + [1] * (len(previous_scale.shape) - len(scale_factor)) + ) + scale_factor + + previous_scale_pyramid, _ = compute_pyramid( + data=previous_scale, + scale_axis=new_scale_factor, + chunks=image_data.chunksize, + n_lvls=2, + ) + array_to_write = previous_scale_pyramid[-1] + + logger.info(f"[level {level}]: pyramid level: {array_to_write}") + + # Create the scale dataset + pyramid_group = new_channel_group.create_dataset( + name=level, + shape=array_to_write.shape, + chunks=array_to_write.chunksize, + dtype=array_to_write.dtype, + compressor=writing_options, + dimension_separator="/", + overwrite=True, + ) + + # Block Zarr Writer + BlockedArrayWriter.store(array_to_write, pyramid_group, block_shape) + written_pyramid.append(array_to_write) + + end_time = time.time() + logger.info(f"Time to write the dataset: {end_time - start_time}") + logger.info(f"Written pyramid: {written_pyramid}") + + +def convert_stacks_to_ome_zarr(channel_path, logger, output_path): + # channel_path must end with Ex_{wav}_Em_{wav} + + # Setting up local cluster + n_workers = multiprocessing.cpu_count() + logger.info(f"Setting {n_workers} workers") + threads_per_worker = 1 + + # Instantiating local cluster for parallel writing + cluster = LocalCluster( + n_workers=n_workers, + threads_per_worker=threads_per_worker, + processes=True, + 
memory_limit="auto", + ) + + client = Client(cluster) + + cols = os.listdir(channel_path) + for col in cols: + curr_col = channel_path.joinpath(col) + for row in os.listdir(curr_col): + curr_row = curr_col.joinpath(row) + delayed_stack = PngReader(data_path=f"{curr_row}/*.png").as_dask_array() + print(f"Writing curr stack {curr_row}: {delayed_stack}") + + smartspim_channel_zarr_writer( + image_data=delayed_stack, + output_path=Path(output_path).joinpath(f"{channel_path.stem}"), + voxel_size=[2.0, 1.8, 1.8], + final_chunksize=(128, 128, 128), + scale_factor=[2, 2, 2], + codec="zstd", + compression_level=0, + n_lvls=4, + channel_name=channel_path.stem, + logger=logger, + stack_name=f"{col}_{row.split('_')[-1]}.zarr", + client=client + ) + + client.shutdown() + + +def create_logger(output_log_path: PathLike) -> logging.Logger: + """ + Creates a logger that generates + output logs to a specific path. + + Parameters + ------------ + output_log_path: PathLike + Path where the log is going + to be stored + + Returns + ----------- + logging.Logger + Created logger pointing to + the file path. 
+ """ + CURR_DATE_TIME = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + LOGS_FILE = f"{output_log_path}/fusion_log_{CURR_DATE_TIME}.log" + + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(levelname)s : %(message)s", + datefmt="%Y-%m-%d %H:%M", + handlers=[ + logging.StreamHandler(), + logging.FileHandler(LOGS_FILE, "a"), + ], + force=True, + ) + + logging.disable("DEBUG") + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + + return logger + +def main(): + + data_folder = Path(os.path.abspath("../data")) + results_folder = Path(os.path.abspath("../results")) + scratch_folder = Path(os.path.abspath("../scratch")) + + set_dask_config(dask_folder=scratch_folder) + + logger = create_logger(output_log_path=results_folder) + + BASE_PATH = data_folder.joinpath( + "SmartSPIM_692911_2023-10-23_11-27-30/SmartSPIM" + ) + + channels = os.listdir(str(BASE_PATH)) + + for channel in channels: + convert_stacks_to_ome_zarr( + channel_path=BASE_PATH.joinpath(channel), + logger=logger, + output_path=results_folder + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py new file mode 100644 index 0000000..9293420 --- /dev/null +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -0,0 +1,236 @@ +"""Module to handle ephys data compression""" + +import logging +import os +import platform +import shutil +import sys +from datetime import datetime +from pathlib import Path +from typing import Iterator, Literal, Optional, Union, List + +import numpy as np +from aind_data_transformation.core import ( + BasicJobSettings, + GenericEtl, + JobResponse, + get_parser, +) +from numcodecs.blosc import Blosc +from pydantic import Field + +from aind_smartspim_data_transformation.models import ( + CompressorName, +) + +from aind_smartspim_data_transformation.io import PngReader +from dask.distributed import 
Client, LocalCluster +from png_to_zarr import smartspim_channel_zarr_writer + +class SmartspimJobSettings(BasicJobSettings): + """SmartspimCompressionJob settings.""" + + # Compress settings + random_seed: Optional[int] = 0 + compress_write_output_format: Literal["zarr"] = Field( + default="zarr", + description=( + "Output format for compression. Currently, only zarr supported." + ), + title="Write Output Format", + ) + compressor_name: CompressorName = Field( + default=CompressorName.BLOSC, + description="Type of compressor to use.", + title="Compressor Name.", + ) + # It will be safer if these kwargs fields were objects with known schemas + compressor_kwargs: dict = Field( + default={"cname": "zstd", "clevel": 3, 'shuffle': Blosc.SHUFFLE}, + description="Arguments to be used for the compressor.", + title="Compressor Kwargs", + ) + compress_job_save_kwargs: dict = Field( + default={"n_jobs": -1}, # -1 to use all available cpu cores. + description="Arguments for recording save method.", + title="Compress Job Save Kwargs", + ) + chunk_size: int = Field( + default=128, + description="Image chunk size", + title="Image Chunk Size", + ) + + +class SmartspimCompressionJob(GenericEtl[SmartspimJobSettings]): + """Main class to handle smartspim data compression""" + + def _get_delayed_channel_stack( + self, + channel_paths: List[str], + output_dir: str + ) -> Iterator[dict]: + """ + Reads a stack of PNG images into a delayed zarr dataset. + + Returns: + Iterator[dict] + A generator that returns delayed PNG stacks. 
+ + """ + for channel_path in channel_paths: + + cols = os.listdir(channel_path) + for col in cols: + curr_col = channel_path.joinpath(col) + for row in os.listdir(curr_col): + curr_row = curr_col.joinpath(row) + delayed_stack = PngReader(data_path=f"{curr_row}/*.png").as_dask_array() + stack_name = f"{col}_{row.split('_')[-1]}.zarr" + stack_output_path = Path(f"{output_dir}/{channel_path.stem}") + + yield ( + delayed_stack, + stack_output_path, + stack_name + ) + + + def _get_compressor(self) -> Blosc: + """ + Utility method to construct a compressor class. + Returns + ------- + Blosc + Either an instantiated Blosc or WavPack class. + + """ + if self.job_settings.compressor_name == CompressorName.BLOSC: + return Blosc(**self.job_settings.compressor_kwargs) + else: + # TODO: This is validated during the construction of JobSettings, + # so we can probably just remove this exception. + raise Exception( + f"Unknown compressor. Please select one of " + f"{[c for c in CompressorName]}" + ) + + @staticmethod + def _compress_and_write_channels( + read_channel_stacks: Iterator[dict], + compressor: Blosc, + job_kwargs: dict, + output_format: str = "zarr", + ) -> None: + + if job_kwargs["n_jobs"] == -1: + job_kwargs["n_jobs"] = os.cpu_count() + + n_workers = job_kwargs["n_jobs"] + threads_per_worker = 1 + + # Instantiating local cluster for parallel writing + cluster = LocalCluster( + n_workers=n_workers, + threads_per_worker=threads_per_worker, + processes=True, + memory_limit="auto", + ) + + client = Client(cluster) + i = 0 + for delayed_arr, output_path, stack_name in read_channel_stacks: + if i == 2: + break + + i += 1 + smartspim_channel_zarr_writer( + image_data=delayed_arr, + output_path=output_path, + voxel_size=[2.0, 1.8, 1.8], + final_chunksize=(128, 128, 128), + scale_factor=[2, 2, 2], + n_lvls=4, + channel_name=output_path.stem, + stack_name=stack_name, + client=client, + logger=logging, + writing_options=compressor, + ) + + def _compress_raw_data(self) -> None: 
+ """Compresses smartspim data""" + + # Clip the data + logging.info("Converting PNG to OMEZarr. This may take some minutes.") + output_compressed_data = ( + self.job_settings.output_directory / "SPIM" + ) + + channel_paths = [ + Path(self.job_settings.input_source).joinpath(folder) + for folder in os.listdir( + self.job_settings.input_source + ) + ] + + # Get channel stack iterators and delayed arrays + read_delayed_channel_stacks = self._get_delayed_channel_stack( + channel_paths=channel_paths, + output_dir=output_compressed_data, + ) + + # Getting compressors + compressor = self._get_compressor() + + # Writing compressed stacks + self._compress_and_write_channels( + read_channel_stacks=read_delayed_channel_stacks, + compressor=compressor, + output_format=self.job_settings.compress_write_output_format, + job_kwargs=self.job_settings.compress_job_save_kwargs + ) + logging.info("Finished compressing source data.") + + def run_job(self) -> JobResponse: + """ + Main public method to run the compression job + Returns + ------- + JobResponse + Information about the job that can be used for metadata downstream. 
+ + """ + job_start_time = datetime.now() + self._compress_raw_data() + job_end_time = datetime.now() + return JobResponse( + status_code=200, + message=f"Job finished in: {job_end_time-job_start_time}", + data=None, + ) + + +def main(): + """ Main function """ + sys_args = sys.argv[1:] + parser = get_parser() + cli_args = parser.parse_args(sys_args) + if cli_args.job_settings is not None: + job_settings = SmartspimJobSettings.model_validate_json( + cli_args.job_settings + ) + elif cli_args.config_file is not None: + job_settings = SmartspimJobSettings.from_config_file(cli_args.config_file) + else: + # Construct settings from env vars + job_settings = SmartspimJobSettings( + input_source="/data/SmartSPIM_714635_2024-03-18_10-47-48/SmartSPIM", + output_directory="/scratch/" + ) + job = SmartspimCompressionJob(job_settings=job_settings) + job_response = job.run_job() + logging.info(job_response.model_dump_json()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/aind_smartspim_data_transformation/zarr_utilities.py b/src/aind_smartspim_data_transformation/zarr_utilities.py new file mode 100644 index 0000000..5f599ad --- /dev/null +++ b/src/aind_smartspim_data_transformation/zarr_utilities.py @@ -0,0 +1,621 @@ +""" +Module to the zarr utilities +""" + +import multiprocessing +import os +import time +from pathlib import Path +from typing import Optional, Tuple, Union + +import dask +import numpy as np +import pims +from dask.array import concatenate, pad +from dask.array.core import Array +from dask.base import tokenize +from natsort import natsorted +from numcodecs import blosc +from skimage.io import imread as sk_imread +import dask.array as da + +PathLike = Union[str, Path] +ArrayLike = Union[dask.array.core.Array, np.ndarray] +blosc.use_threads = False + + +def add_leading_dim(data: ArrayLike) -> ArrayLike: + """ + Adds a leading dimension + + Parameters + ------------------------ + + data: ArrayLike + Input array that will have the 
+ leading dimension + + Returns + ------------------------ + + ArrayLike: + Array with the new dimension in front. + """ + return data[None, ...] + + +def pad_array_n_d(arr: ArrayLike, dim: int = 5) -> ArrayLike: + """ + Pads a dask array to be in a 5D shape. + + Parameters + ------------------------ + + arr: ArrayLike + Dask/numpy array that contains image data. + dim: int + Number of dimensions that the array will be padded to + + Returns + ------------------------ + ArrayLike: + Padded dask/numpy array. + """ + if dim > 5: + raise ValueError("Padding more than 5 dimensions is not supported.") + + while arr.ndim < dim: + arr = arr[np.newaxis, ...] + return arr + + +def read_image_directory_structure(folder_dir: PathLike) -> dict: + """ + Creates a dictionary representation of all the images + saved by folder/col_N/row_N/images_N.[file_extension] + + Parameters + ------------------------ + folder_dir:PathLike + Path to the folder where the images are stored + + Returns + ------------------------ + dict: + Dictionary with the image representation where: + {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ...
row_n: [image_0, ..., image_n]} } } } + """ + + directory_structure = {} + folder_dir = Path(folder_dir) + + channel_paths = natsorted( + [ + folder_dir.joinpath(folder) + for folder in os.listdir(folder_dir) + if os.path.isdir(folder_dir.joinpath(folder)) + ] + ) + + for channel_idx in range(len(channel_paths)): + directory_structure[channel_paths[channel_idx]] = {} + + cols = natsorted(os.listdir(channel_paths[channel_idx])) + + for col in cols: + possible_col = channel_paths[channel_idx].joinpath(col) + + if os.path.isdir(possible_col): + directory_structure[channel_paths[channel_idx]][col] = {} + + rows = natsorted(os.listdir(possible_col)) + + for row in rows: + possible_row = ( + channel_paths[channel_idx].joinpath(col).joinpath(row) + ) + + if os.path.isdir(possible_row): + directory_structure[channel_paths[channel_idx]][col][ + row + ] = natsorted(os.listdir(possible_row)) + + return directory_structure + + +def lazy_tiff_reader( + filename: PathLike, + shape: Optional[Tuple[int]] = None, + dtype: Optional[type] = None, +): + """ + Creates a dask array to read an image located in a specific path. + + Parameters + ------------------------ + + filename: PathLike + Path to the image + + Returns + ------------------------ + + dask.array.core.Array + Array representing the image data + """ + name = "imread-%s" % tokenize(filename, map(os.path.getmtime, filename)) + + if dtype is None or shape is None: + try: + with pims.open(filename) as imgs: + dtype = np.dtype(imgs.pixel_type) + shape = (1,) + (len(imgs),) + imgs.frame_shape + except: + return None + + key = [(name,) + (0,) * len(shape)] + value = [(add_leading_dim, (sk_imread, filename))] + dask_arr = dict(zip(key, value)) + chunks = tuple((d,) for d in shape) + + return Array(dask_arr, name, chunks, dtype) + + +def fix_image_diff_dims( + new_arr: ArrayLike, chunksize: Tuple[int], len_chunks: int, work_axis: int +) -> ArrayLike: + """ + Fixes the array dimension to match the shape of + the chunksize. 
+ + Parameters + ------------------------ + + new_arr: ArrayLike + Array to be fixed + + chunksize: Tuple[int] + Chunksize of the original array + + len_chunks: int + Length of the chunksize. Used as a + parameter to avoid computing it + multiple times + + work_axis: int + Axis to concatenate. If the different + axis matches this one, there is no need + to fix the array dimension + + Returns + ------------------------ + + ArrayLike + Array with the new dimensions + """ + + zeros_dim = [] + diff_dim = -1 + c = 0 + + for chunk_idx in range(len_chunks): + new_chunk_dim = new_arr.chunksize[chunk_idx] + + if new_chunk_dim != chunksize[chunk_idx]: + c += 1 + diff_dim = chunk_idx + + zeros_dim.append(abs(chunksize[chunk_idx] - new_chunk_dim)) + + if c > 1: + raise ValueError("Block has two different dimensions") + else: + if (diff_dim - len_chunks) == work_axis: + return new_arr + + n_pad = tuple(tuple((0, dim)) for dim in zeros_dim) + new_arr = pad( + new_arr, pad_width=n_pad, mode="constant", constant_values=0 + ).rechunk(chunksize) + + return new_arr + + +def concatenate_dask_arrays(arr_1: ArrayLike, arr_2: ArrayLike, axis: int) -> ArrayLike: + """ + Concatenates two arrays in a given + dimension + + Parameters + ------------------------ + + arr_1: ArrayLike + Array 1 that will be concatenated + + arr_2: ArrayLike + Array 2 that will be concatenated + + axis: int + Concatenation axis + + Returns + ------------------------ + + ArrayLike + Concatenated array that contains + arr_1 and arr_2 + """ + + shape_arr_1 = arr_1.shape + shape_arr_2 = arr_2.shape + + if shape_arr_1 != shape_arr_2: + slices = [] + dims = len(shape_arr_1) + + for shape_dim_idx in range(dims): + if shape_arr_1[shape_dim_idx] > shape_arr_2[shape_dim_idx] and ( + shape_dim_idx - dims != axis + ): + raise ValueError( + f""" + Array 1 {shape_arr_1} must have + a smaller shape than array 2 {shape_arr_2} + except for the axis dimension {shape_dim_idx} + {dims} {shape_dim_idx - dims} {axis} + """ + ) + + if 
shape_arr_1[shape_dim_idx] != shape_arr_2[shape_dim_idx]: + slices.append(slice(0, shape_arr_1[shape_dim_idx])) + + else: + slices.append(slice(None)) + + slices = tuple(slices) + arr_2 = arr_2[slices] + + try: + res = concatenate([arr_1, arr_2], axis=axis) + except ValueError: + raise ValueError( + f""" + Unable to cancat arrays - Shape 1: + {shape_arr_1} shape 2: {shape_arr_2} + """ + ) + + return res + + +def read_chunked_stitched_image_per_channel( + directory_structure: dict, + channel_name: str, + start_slice: int, + end_slice: int, +) -> ArrayLike: + """ + Creates a dask array of the whole image volume + based on image chunks preserving the chunksize. + + Parameters + ------------------ + + directory_structure:dict + dictionary to store paths of images with the following structure: + {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } + + channel_name : str + Channel name to reconstruct the image volume + + start_slice: int + When using multiprocessing, this is + the start slice the worker will use for + the array concatenation + + end_slice: int + When using multiprocessing, this is + the final slice the worker will use for + the array concatenation + + Returns + ------------------------ + + ArrayLike + Array with the image volume + """ + concat_z_3d_blocks = concat_horizontals = horizontal = None + + # Getting col structure + rows = list(directory_structure.values())[0] + rows_paths = list(rows.keys()) + first = True + + for slice_pos in range(start_slice, end_slice): + idx_col = 0 + idx_row = 0 + + concat_horizontals = None + + for row_name in rows_paths: + idx_row = 0 + horizontal = [] + shape = None + dtype = None + column_names = list(directory_structure[channel_name][row_name].keys()) + n_cols = len(column_names) + + check_shape = (1, 256,256,256) + + for column_name_idx in range(n_cols): + valid_image = True + column_name = column_names[column_name_idx] + last_col = column_name_idx == n_cols - 1 + + if 
last_col: + shape = None + dtype = None + + try: + slice_name = directory_structure[channel_name][row_name][ + column_name + ][slice_pos] + + filepath = str( + channel_name.joinpath(row_name) + .joinpath(column_name) + .joinpath(slice_name) + ) + + new_arr = lazy_tiff_reader(filepath)#, dtype=dtype, shape=shape) + + if (shape is None or dtype is None) and new_arr: + shape = new_arr.shape + dtype = new_arr.dtype + last_col = False + + except ValueError: + print(f"{filepath} -> No valid image in ", slice_name, slice_pos) +# valid_image = False +# new_arr = da.zeros(check_shape, dtype=new_arr.dtype) + + if valid_image: + horizontal.append(new_arr) + + idx_row += 1 + + # Concatenating horizontally lazy images + print("Before fixing: ", horizontal) + for i in range(len(horizontal)): + if horizontal[i] is not None: + shape = horizontal[i].shape + dtype = horizontal[i].dtype + break + + for i in range(len(horizontal)): + if horizontal[i] is None and shape is not None and dtype is not None: + horizontal[i] = da.zeros(shape, dtype=dtype) + + print("Before concat: ", horizontal) + horizontal_concat = concatenate(horizontal, axis=-1) + + if not idx_col: + concat_horizontals = horizontal_concat + else: + concat_horizontals = concatenate_dask_arrays( + arr_1=concat_horizontals, arr_2=horizontal_concat, axis=-2 + ) + + idx_col += 1 + + if first: + concat_z_3d_blocks = concat_horizontals + first = False + + else: + concat_z_3d_blocks = concatenate_dask_arrays( + arr_1=concat_z_3d_blocks, arr_2=concat_horizontals, axis=-3 + ) + + return concat_z_3d_blocks, [start_slice, end_slice] + + +def _read_chunked_stitched_image_per_channel(args_dict: dict): + """ + Function used to be dispatched to workers + by using multiprocessing + """ + return read_chunked_stitched_image_per_channel(**args_dict) + + +def channel_parallel_reading( + directory_structure: dict, + channel_idx: int, + workers: Optional[int] = 0, + chunks: Optional[int] = 1, + ensure_parallel: Optional[bool] = True, +) -> 
ArrayLike: + """ + Creates a dask array of the whole image channel volume based + on image chunks preserving the chunksize and using + multiprocessing. + + Parameters + ------------------------ + + directory_structure: dict + dictionary to store paths of images with the following structure: + {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } + + channel_name : str + Channel name to reconstruct the image volume + + sample_img: ArrayLike + Image used as guide for the chunksize + + workers: Optional[int] + Number of workers that will be used + for reading the chunked image. + Default value 0, it means that the + available number of cores will be used. + + chunks: Optional[int] + Chunksize of the 3D chunked images. + + ensure_parallel: Optional[bool] + True if we want to read the images in + parallel. False, otherwise. + + Returns + ------------------------ + + ArrayLike + Array with the image channel volume + """ + if workers == 0: + workers = multiprocessing.cpu_count() + + cols = list(directory_structure.values())[0] + n_images = len(list(list(cols.values())[0].values())[0]) + # print(f"n_images: {n_images}") + + channel_paths = list(directory_structure.keys()) + dask_array = None + + if n_images < workers and ensure_parallel: + workers = n_images + + if n_images < workers or not ensure_parallel: + dask_array = read_chunked_stitched_image_per_channel( + directory_structure=directory_structure, + channel_name=channel_paths[channel_idx], + start_slice=0, + end_slice=n_images, + )[0] + print(f"No need for parallel reading... 
{dask_array}") + + else: + images_per_worker = n_images // workers + print( + f"Setting workers to {workers} - {images_per_worker} - total images: {n_images}" + ) + + # Getting 5 dim image TCZYX + args = [] + start_slice = 0 + end_slice = images_per_worker + + for idx_worker in range(workers): + arg_dict = { + "directory_structure": directory_structure, + "channel_name": channel_paths[channel_idx], + "start_slice": start_slice, + "end_slice": end_slice, + } + + args.append(arg_dict) + + if idx_worker + 1 == workers - 1: + start_slice = end_slice + end_slice = n_images + else: + start_slice = end_slice + end_slice += images_per_worker + + res = [] + with multiprocessing.Pool(workers) as pool: + results = pool.imap( + _read_chunked_stitched_image_per_channel, + args, + chunksize=chunks, + ) + + for pos in results: + res.append(pos) + + for res_idx in range(len(res)): + if not res_idx: + dask_array = res[res_idx][0] + else: + dask_array = concatenate([dask_array, res[res_idx][0]], axis=-3) + + print(f"Slides: {res[res_idx][1]}") + + return dask_array + + +def parallel_read_chunked_stitched_multichannel_image( + directory_structure: dict, + workers: Optional[int] = 0, + ensure_parallel: Optional[bool] = True, + divide_channels: Optional[bool] = True, +) -> ArrayLike: + """ + Creates a dask array of the whole image volume based + on image chunks preserving the chunksize and using + multiprocessing. + + Parameters + ------------------------ + + directory_structure: dict + dictionary to store paths of images with the following structure: + {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } + + sample_img: ArrayLike + Image used as guide for the chunksize + + workers: Optional[int] + Number of workers that will be used + for reading the chunked image. + Default value 0, it means that the + available number of cores will be used. + + ensure_parallel: Optional[bool] + True if we want to read the images in + parallel. 
False, otherwise. + + Returns + ------------------------ + + ArrayLike + Array with the image channel volume + """ + + multichannel_image = None + + channel_paths = list(directory_structure.keys()) + + multichannels = [] + read_channels = {} + print(f"Channel in directory structure: {channel_paths}") + + for channel_idx in range(len(channel_paths)): + print(f"Reading images from {channel_paths[channel_idx]}") + start_time = time.time() + read_chunked_channel = channel_parallel_reading( + directory_structure, + channel_idx, + workers=workers, + ensure_parallel=ensure_parallel, + ) + end_time = time.time() + + print(f"Time reading single channel image: {end_time - start_time}") + + # Padding to 4D if necessary + ch_name = Path(channel_paths[channel_idx]).name + + read_chunked_channel = pad_array_n_d(read_chunked_channel, 4) + multichannels.append(read_chunked_channel) + read_channels[ch_name] = read_chunked_channel + + if divide_channels: + return read_channels + + if len(multichannels) > 1: + multichannel_image = concatenate(multichannels, axis=0) + else: + multichannel_image = multichannels[0] + + return multichannel_image diff --git a/src/aind_smartspim_data_transformation/zarr_writer.py b/src/aind_smartspim_data_transformation/zarr_writer.py new file mode 100644 index 0000000..9d86a44 --- /dev/null +++ b/src/aind_smartspim_data_transformation/zarr_writer.py @@ -0,0 +1,218 @@ +""" +This module defines a class that takes +big chunks (compilation of chunks) from +a dask array and writes it on disk in +zarr format +""" + +from typing import Generator, Tuple + +import dask.array as da +import numpy as np +from numpy.typing import ArrayLike + + +def _get_size(shape: Tuple[int, ...], itemsize: int) -> int: + """ + Return the size of an array with the given shape, in bytes + Args: + shape: the shape of the array + itemsize: number of bytes per array element + Returns: + the size of the array, in bytes + """ + if any(s <= 0 for s in shape): + raise ValueError("shape must 
be > 0 in all dimensions") + return np.product(shape) * itemsize + + +def _closer_to_target( + shape1: Tuple[int, ...], + shape2: Tuple[int, ...], + target_bytes: int, + itemsize: int, +) -> Tuple[int, ...]: + """ + Given two shapes with the same number of dimensions, + find which one is closer to target_bytes. + Args: + shape1: the first shape + shape2: the second shape + target_bytes: the target size for the returned shape + itemsize: number of bytes per array element + """ + size1 = _get_size(shape1, itemsize) + size2 = _get_size(shape2, itemsize) + if abs(size1 - target_bytes) < abs(size2 - target_bytes): + return shape1 + return shape2 + + +def expand_chunks( + chunks: Tuple[int, int, int], + data_shape: Tuple[int, int, int], + target_size: int, + itemsize: int, + mode: str = "iso", +) -> Tuple[int, int, int]: + """ + Given the shape and chunk size of a pre-chunked 3D array, determine the optimal chunk shape + closest to target_size. Expanded chunk dimensions are an integer multiple of the base chunk dimension, + to ensure optimal access patterns. + Args: + chunks: the shape of the input array chunks + data_shape: the shape of the input array + target_size: target chunk size in bytes + itemsize: the number of bytes per array element + mode: chunking strategy. 
Must be one of "cycle", or "iso" + Returns: + the optimal chunk shape + """ + if any(c < 1 for c in chunks): + raise ValueError("chunks must be >= 1 for all dimensions") + if any(s < 1 for s in data_shape): + raise ValueError("data_shape must be >= 1 for all dimensions") + if any(c > s for c, s in zip(chunks, data_shape)): + raise ValueError("chunks cannot be larger than data_shape in any dimension") + if target_size <= 0: + raise ValueError("target_size must be > 0") + if itemsize <= 0: + raise ValueError("itemsize must be > 0") + if mode == "cycle": + # get the spatial dimensions only + current = np.array(chunks, dtype=np.uint64) + prev = current.copy() + idx = 0 + ndims = len(current) + while _get_size(current, itemsize) < target_size: + prev = current.copy() + current[idx % ndims] = min( + data_shape[idx % ndims], current[idx % ndims] * 2 + ) + idx += 1 + if all(c >= s for c, s in zip(current, data_shape)): + break + expanded = _closer_to_target(current, prev, target_size, itemsize) + elif mode == "iso": + initial = np.array(chunks, dtype=np.uint64) + current = initial + prev = current + i = 2 + while _get_size(current, itemsize) < target_size: + prev = current + current = initial * i + current = ( + min(data_shape[0], current[0]), + min(data_shape[1], current[1]), + min(data_shape[2], current[2]), + ) + i += 1 + if all(c >= s for c, s in zip(current, data_shape)): + break + expanded = _closer_to_target(current, prev, target_size, itemsize) + else: + raise ValueError(f"Invalid mode {mode}") + + return tuple(int(d) for d in expanded) + + +class BlockedArrayWriter: + """ + Static class to write a lazy array + in big chunks to OMEZarr + """ + + @staticmethod + def gen_slices( + arr_shape: Tuple[int, ...], block_shape: Tuple[int, ...] + ) -> Generator: + """ + Generate a series of slices that can be used to traverse an array in blocks of a given shape. + + The method generates tuples of slices, each representing a block of the array. 
The blocks are generated by + iterating over the array in steps of the block shape along each dimension. + + Parameters + ---------- + arr_shape : tuple of int + The shape of the array to be sliced. + + block_shape : tuple of int + The desired shape of the blocks. This should be a tuple of integers representing the size of each + dimension of the block. The length of `block_shape` should be equal to the length of + `arr_shape`. If the array shape is not divisible by the block shape along a dimension, the last slice + along that dimension is truncated. + + Returns + ------- + generator of tuple of slice + A generator yielding tuples of slices. Each tuple can be used to index the input array. + """ + if len(arr_shape) != len(block_shape): + raise Exception("array shape and block shape have different lengths") + + def _slice_along_dim(dim: int) -> Generator: + """A helper generator function that slices along one dimension.""" + # Base case: if the dimension is beyond the last one, return an empty tuple + if dim >= len(arr_shape): + yield () + else: + # Iterate over the current dimension in steps of the block size + for i in range(0, arr_shape[dim], block_shape[dim]): + # Calculate the end index for this block + end_i = min(i + block_shape[dim], arr_shape[dim]) + # Generate slices for the remaining dimensions + for rest in _slice_along_dim(dim + 1): + yield (slice(i, end_i),) + rest + + # Start slicing along the first dimension + return _slice_along_dim(0) + + @staticmethod + def store(in_array: da.Array, out_array: ArrayLike, block_shape: tuple) -> None: + """ + Partitions the last 3 dimensions of a Dask array into non-overlapping blocks and + writes them sequentially to a Zarr array. This is meant to reduce the scheduling burden + for massive (terabyte-scale) arrays. 
+ + :param in_array: The input Dask array + :param block_shape: Tuple of (block_depth, block_height, block_width) + :param out_array: The output array + """ + # Iterate through the input array in steps equal to the block shape dimensions + for sl in BlockedArrayWriter.gen_slices(in_array.shape, block_shape): + block = in_array[sl] + da.store( + block, + out_array, + regions=sl, + lock=False, + compute=True, + return_stored=False, + ) + + @staticmethod + def get_block_shape(arr, target_size_mb=409600, mode="cycle"): + """ + Given the shape and chunk size of a pre-chunked array, determine the optimal block shape + closest to target_size. Expanded block dimensions are an integer multiple of the chunk dimension + to ensure optimal access patterns. + Args: + arr: the input array + target_size_mb: target block size in megabytes, default is 409600 + mode: strategy. Must be one of "cycle", or "iso" + Returns: + the block shape + """ + if isinstance(arr, da.Array): + chunks = arr.chunksize[-3:] + else: + chunks = arr.chunks[-3:] + + return expand_chunks( + chunks, + arr.shape[-3:], + target_size_mb * 1024**2, + arr.itemsize, + mode, + ) From f2c2d3509b17f97ee3dd5891245ecf21b24763e7 Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Wed, 29 May 2024 20:45:26 +0000 Subject: [PATCH 2/9] updating singularity container to use dask mpi --- scripts/singularity_build.def | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/scripts/singularity_build.def b/scripts/singularity_build.def index b2f9e1a..e376cf1 100644 --- a/scripts/singularity_build.def +++ b/scripts/singularity_build.def @@ -1,12 +1,25 @@ Bootstrap: docker From: python:3.10-bullseye -Stage: build %setup # Copy project directory into container cp -R . 
${SINGULARITY_ROOTFS}/aind-smartspim-data-transformation %post + # Installing dask mpi + wget https://www.mpich.org/static/downloads/3.2/mpich-3.2.tar.gz + tar xfz mpich-3.2.tar.gz + rm mpich-3.2.tar.gz + mkdir mpich-build + cd mpich-build + ../mpich-3.2/configure --disable-fortran 2>&1 | tee c.txt + make 2>&1 | tee m.txt + make install 2>&1 | tee mi.txt + cd .. + cd ${SINGULARITY_ROOTFS}/aind-smartspim-data-transformation pip install . --no-cache-dir - rm -rf ${SINGULARITY_ROOTFS}/aind-smartspim-data-transformation \ No newline at end of file + rm -rf ${SINGULARITY_ROOTFS}/aind-smartspim-data-transformation + + pip install mpi4py --no-cache-dir + pip install dask_mpi \ No newline at end of file From bbbeb8806f764f29739a6461e90fbf1137b85d86 Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Thu, 30 May 2024 15:06:41 +0000 Subject: [PATCH 3/9] adding dask utilities to create a slurm cluster if using the HPC --- .../__init__.py | 1 + .../dask_utils.py | 121 ++++++++++++++++++ .../io/_io.py | 3 +- .../io/utils.py | 6 +- .../models.py | 3 +- .../png_to_zarr.py | 81 +++++++----- .../smartspim_job.py | 75 +++++------ .../zarr_utilities.py | 46 ++++--- .../zarr_writer.py | 12 +- 9 files changed, 256 insertions(+), 92 deletions(-) create mode 100644 src/aind_smartspim_data_transformation/dask_utils.py diff --git a/src/aind_smartspim_data_transformation/__init__.py b/src/aind_smartspim_data_transformation/__init__.py index 831e353..e0247d0 100644 --- a/src/aind_smartspim_data_transformation/__init__.py +++ b/src/aind_smartspim_data_transformation/__init__.py @@ -1,2 +1,3 @@ """Init package""" + __version__ = "0.0.1" diff --git a/src/aind_smartspim_data_transformation/dask_utils.py b/src/aind_smartspim_data_transformation/dask_utils.py new file mode 100644 index 0000000..d149cc9 --- /dev/null +++ b/src/aind_smartspim_data_transformation/dask_utils.py @@ -0,0 +1,121 @@ +import logging +import os +import socket +from enum import Enum +from typing import Optional, Tuple + 
+import distributed +import requests +from distributed import Client, LocalCluster + +try: + from dask_mpi import initialize + + DASK_MPI_INSTALLED = True +except ImportError: + DASK_MPI_INSTALLED = False + +LOGGER = logging.getLogger(__name__) + + +class Deployment(Enum): + LOCAL = "local" + SLURM = "slurm" + + +def log_dashboard_address( + client: distributed.Client, login_node_address: str = "hpc-login" +) -> None: + """ + Logs the terminal command required to access the Dask dashboard + + Args: + client: the Client instance + login_node_address: the address of the cluster login node + """ + host = client.run_on_scheduler(socket.gethostname) + port = client.scheduler_info()["services"]["dashboard"] + user = os.getenv("USER") + LOGGER.info( + f"To access the dashboard, run the following in a terminal: ssh -L {port}:{host}:{port} {user}@" + f"{login_node_address} " + ) + + +def get_deployment() -> str: + if os.getenv("SLURM_JOBID") is None: + deployment = Deployment.LOCAL.value + else: + # we're running on the Allen HPC + deployment = Deployment.SLURM.value + return deployment + + +def get_client( + deployment: str = Deployment.LOCAL.value, + worker_options: Optional[dict] = None, + n_workers: int = 1, + processes=True, +) -> Tuple[distributed.Client, int]: + """ + Create a distributed Client + + Args: + deployment: the type of deployment. Either "local" or "slurm" + worker_options: a dictionary of options to pass to the worker class + n_workers: the number of workers (only applies to "local" deployment) + + Returns: + the distributed Client and number of workers + """ + if deployment == Deployment.SLURM.value: + if not DASK_MPI_INSTALLED: + raise ImportError( + "dask-mpi must be installed to use the SLURM deployment" + ) + if worker_options is None: + worker_options = {} + slurm_job_id = os.getenv("SLURM_JOBID") + if slurm_job_id is None: + raise Exception( + "SLURM_JOBID environment variable is not set. Are you running under SLURM?" 
+ ) + initialize( + nthreads=int(os.getenv("SLURM_CPUS_PER_TASK", 1)), + local_directory=f"/scratch/fast/{slurm_job_id}", + worker_class="distributed.nanny.Nanny", + worker_options=worker_options, + ) + client = Client() + log_dashboard_address(client) + n_workers = int(os.getenv("SLURM_NTASKS")) + elif deployment == Deployment.LOCAL.value: + client = Client( + LocalCluster( + n_workers=n_workers, processes=processes, threads_per_worker=1 + ) + ) + else: + raise NotImplementedError + return client, n_workers + + +def cancel_slurm_job( + job_id: str, api_url: str, headers: dict +) -> requests.Response: + """ + Attempt to release resources and cancel the job + + Args: + job_id: the SLURM job ID + api_url: the URL of the SLURM REST API. E.g., "http://myhost:80/api/slurm/v0.0.36" + + Raises: + HTTPError: if the request to cancel the job fails + """ + # Attempt to release resources and cancel the job + # Workaround for https://github.com/dask/dask-mpi/issues/87 + endpoint = f"{api_url}/job/{job_id}" + response = requests.delete(endpoint, headers=headers) + + return response diff --git a/src/aind_smartspim_data_transformation/io/_io.py b/src/aind_smartspim_data_transformation/io/_io.py index 6b0d216..ae9b686 100644 --- a/src/aind_smartspim_data_transformation/io/_io.py +++ b/src/aind_smartspim_data_transformation/io/_io.py @@ -36,6 +36,7 @@ PathLike = Union[str, Path] ArrayLike = Union[da.Array, np.ndarray] + class ImageReader(ABC): """ Abstract class to create image readers @@ -640,4 +641,4 @@ def create( if ext not in self.__extensions: raise NotImplementedError(f"File type {ext} not supported") - return self.factory[ext](path_cast, **kwargs) \ No newline at end of file + return self.factory[ext](path_cast, **kwargs) diff --git a/src/aind_smartspim_data_transformation/io/utils.py b/src/aind_smartspim_data_transformation/io/utils.py index 0925bb4..aa5ad55 100644 --- a/src/aind_smartspim_data_transformation/io/utils.py +++ 
b/src/aind_smartspim_data_transformation/io/utils.py @@ -4,10 +4,8 @@ import json import os -from typing import Optional - from pathlib import Path -from typing import Union +from typing import Optional, Union import dask.array as da import numpy as np @@ -109,4 +107,4 @@ def read_json_as_dict(filepath: str) -> dict: with open(filepath) as json_file: dictionary = json.load(json_file) - return dictionary \ No newline at end of file + return dictionary diff --git a/src/aind_smartspim_data_transformation/models.py b/src/aind_smartspim_data_transformation/models.py index 19ecd03..5ca3517 100644 --- a/src/aind_smartspim_data_transformation/models.py +++ b/src/aind_smartspim_data_transformation/models.py @@ -1,10 +1,11 @@ """Helpful models used in the ephys compression job""" from enum import Enum - from typing import List + from numcodecs import Blosc + class CompressorName(str, Enum): """Enum for compression algorithms a user can select""" diff --git a/src/aind_smartspim_data_transformation/png_to_zarr.py b/src/aind_smartspim_data_transformation/png_to_zarr.py index 15172e9..bee72ec 100644 --- a/src/aind_smartspim_data_transformation/png_to_zarr.py +++ b/src/aind_smartspim_data_transformation/png_to_zarr.py @@ -9,18 +9,23 @@ import multiprocessing import os import time +from datetime import datetime +from pathlib import Path from typing import Dict, Hashable, List, Optional, Sequence, Tuple, Union, cast import dask import dask.array as da + # import matplotlib.pyplot as plt import numpy as np import pims import xarray_multiscale import zarr +from dask import config as da_cfg from dask.array.core import Array from dask.base import tokenize from dask.distributed import Client, LocalCluster, performance_report + # from distributed import wait from numcodecs import blosc from ome_zarr.format import CurrentFormat @@ -28,12 +33,10 @@ from ome_zarr.writer import write_multiscales_metadata from skimage.io import imread as sk_imread -from 
aind_smartspim_data_transformation.zarr_writer import BlockedArrayWriter -from aind_smartspim_data_transformation.zarr_utilities import * -from pathlib import Path from aind_smartspim_data_transformation.io import PngReader -from datetime import datetime -from dask import config as da_cfg +from aind_smartspim_data_transformation.zarr_utilities import * +from aind_smartspim_data_transformation.zarr_writer import BlockedArrayWriter + def set_dask_config(dask_folder: str): """ @@ -59,6 +62,7 @@ def set_dask_config(dask_folder: str): } ) + def _build_ome( data_shape: Tuple[int, ...], image_name: str, @@ -87,7 +91,9 @@ def _build_ome( Dict: An "omero" metadata object suitable for writing to ome-zarr """ if channel_names is None: - channel_names = [f"Channel:{image_name}:{i}" for i in range(data_shape[1])] + channel_names = [ + f"Channel:{image_name}:{i}" for i in range(data_shape[1]) + ] if channel_colors is None: channel_colors = [i for i in range(data_shape[1])] if channel_minmax is None: @@ -167,7 +173,9 @@ def _compute_scales( ] ] if translation is not None: - transforms[0].append({"type": "translation", "translation": translation}) + transforms[0].append( + {"type": "translation", "translation": translation} + ) chunk_sizes = [] lastz = data_shape[2] lasty = data_shape[3] @@ -528,7 +536,7 @@ def smartspim_channel_zarr_writer( logger: logging.Logger, stack_name: str, writing_options, - client + client, ): """ Writes a fused SmartSPIM channel in OMEZarr @@ -579,7 +587,7 @@ def smartspim_channel_zarr_writer( # Rechunking dask array image_data = image_data.rechunk(final_chunksize) image_data = pad_array_n_d(arr=image_data) - + print(f"About to write {image_data} in {output_path}") # Creating Zarr dataset @@ -598,7 +606,10 @@ def smartspim_channel_zarr_writer( # Getting min max metadata for the dtype channel_minmax = [ - (np_info_func(image_data.dtype).min, np_info_func(image_data.dtype).max) + ( + np_info_func(image_data.dtype).min, + 
np_info_func(image_data.dtype).max, + ) for _ in range(image_data.shape[1]) ] @@ -608,7 +619,9 @@ def smartspim_channel_zarr_writer( # not used that much on neuroglancer channel_startend = [(0.0, 350.0) for _ in range(image_data.shape[1])] - new_channel_group = root_group.create_group(name=stack_name, overwrite=True) + new_channel_group = root_group.create_group( + name=stack_name, overwrite=True + ) # Writing OME-NGFF metadata write_ome_ngff_metadata( @@ -624,7 +637,7 @@ def smartspim_channel_zarr_writer( channel_startend=channel_startend, metadata=_get_pyramid_metadata(), ) - + performance_report_path = f"{output_path}/report_{stack_name}.html" start_time = time.time() @@ -651,7 +664,9 @@ def smartspim_channel_zarr_writer( else: # It's faster to write the scale and then read it back # to compute the next scale - previous_scale = da.from_zarr(pyramid_group, pyramid_group.chunks) + previous_scale = da.from_zarr( + pyramid_group, pyramid_group.chunks + ) new_scale_factor = ( [1] * (len(previous_scale.shape) - len(scale_factor)) ) + scale_factor @@ -678,7 +693,9 @@ def smartspim_channel_zarr_writer( ) # Block Zarr Writer - BlockedArrayWriter.store(array_to_write, pyramid_group, block_shape) + BlockedArrayWriter.store( + array_to_write, pyramid_group, block_shape + ) written_pyramid.append(array_to_write) end_time = time.time() @@ -688,12 +705,12 @@ def smartspim_channel_zarr_writer( def convert_stacks_to_ome_zarr(channel_path, logger, output_path): # channel_path must end with Ex_{wav}_Em_{wav} - + # Setting up local cluster n_workers = multiprocessing.cpu_count() logger.info(f"Setting {n_workers} workers") threads_per_worker = 1 - + # Instantiating local cluster for parallel writing cluster = LocalCluster( n_workers=n_workers, @@ -709,7 +726,9 @@ def convert_stacks_to_ome_zarr(channel_path, logger, output_path): curr_col = channel_path.joinpath(col) for row in os.listdir(curr_col): curr_row = curr_col.joinpath(row) - delayed_stack = 
PngReader(data_path=f"{curr_row}/*.png").as_dask_array() + delayed_stack = PngReader( + data_path=f"{curr_row}/*.png" + ).as_dask_array() print(f"Writing curr stack {curr_row}: {delayed_stack}") smartspim_channel_zarr_writer( @@ -724,12 +743,10 @@ def convert_stacks_to_ome_zarr(channel_path, logger, output_path): channel_name=channel_path.stem, logger=logger, stack_name=f"{col}_{row.split('_')[-1]}.zarr", - client=client + client=client, ) - - client.shutdown() - - + + def create_logger(output_log_path: PathLike) -> logging.Logger: """ Creates a logger that generates @@ -767,28 +784,30 @@ def create_logger(output_log_path: PathLike) -> logging.Logger: return logger + def main(): - + data_folder = Path(os.path.abspath("../data")) results_folder = Path(os.path.abspath("../results")) scratch_folder = Path(os.path.abspath("../scratch")) - + set_dask_config(dask_folder=scratch_folder) - + logger = create_logger(output_log_path=results_folder) - + BASE_PATH = data_folder.joinpath( "SmartSPIM_692911_2023-10-23_11-27-30/SmartSPIM" ) - + channels = os.listdir(str(BASE_PATH)) - + for channel in channels: convert_stacks_to_ome_zarr( channel_path=BASE_PATH.joinpath(channel), logger=logger, - output_path=results_folder + output_path=results_folder, ) - + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index 9293420..5678f68 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -7,7 +7,7 @@ import sys from datetime import datetime from pathlib import Path -from typing import Iterator, Literal, Optional, Union, List +from typing import Iterator, List, Literal, Optional, Union import numpy as np from aind_data_transformation.core import ( @@ -16,16 +16,18 @@ JobResponse, get_parser, ) +from dask.distributed import Client, LocalCluster from 
numcodecs.blosc import Blosc +from png_to_zarr import smartspim_channel_zarr_writer from pydantic import Field -from aind_smartspim_data_transformation.models import ( - CompressorName, +from aind_smartspim_data_transformation.dask_utils import ( + get_client, + get_deployment, ) - from aind_smartspim_data_transformation.io import PngReader -from dask.distributed import Client, LocalCluster -from png_to_zarr import smartspim_channel_zarr_writer +from aind_smartspim_data_transformation.models import CompressorName + class SmartspimJobSettings(BasicJobSettings): """SmartspimCompressionJob settings.""" @@ -46,7 +48,7 @@ class SmartspimJobSettings(BasicJobSettings): ) # It will be safer if these kwargs fields were objects with known schemas compressor_kwargs: dict = Field( - default={"cname": "zstd", "clevel": 3, 'shuffle': Blosc.SHUFFLE}, + default={"cname": "zstd", "clevel": 3, "shuffle": Blosc.SHUFFLE}, description="Arguments to be used for the compressor.", title="Compressor Kwargs", ) @@ -66,9 +68,7 @@ class SmartspimCompressionJob(GenericEtl[SmartspimJobSettings]): """Main class to handle smartspim data compression""" def _get_delayed_channel_stack( - self, - channel_paths: List[str], - output_dir: str + self, channel_paths: List[str], output_dir: str ) -> Iterator[dict]: """ Reads a stack of PNG images into a delayed zarr dataset. 
@@ -85,16 +85,15 @@ def _get_delayed_channel_stack( curr_col = channel_path.joinpath(col) for row in os.listdir(curr_col): curr_row = curr_col.joinpath(row) - delayed_stack = PngReader(data_path=f"{curr_row}/*.png").as_dask_array() + delayed_stack = PngReader( + data_path=f"{curr_row}/*.png" + ).as_dask_array() stack_name = f"{col}_{row.split('_')[-1]}.zarr" - stack_output_path = Path(f"{output_dir}/{channel_path.stem}") - - yield ( - delayed_stack, - stack_output_path, - stack_name + stack_output_path = Path( + f"{output_dir}/{channel_path.stem}" ) + yield (delayed_stack, stack_output_path, stack_name) def _get_compressor(self) -> Blosc: """ @@ -130,19 +129,19 @@ def _compress_and_write_channels( threads_per_worker = 1 # Instantiating local cluster for parallel writing - cluster = LocalCluster( + deployment = get_deployment() + client, _ = get_client( + deployment, + worker_options=None, # worker_options, n_workers=n_workers, - threads_per_worker=threads_per_worker, processes=True, - memory_limit="auto", ) - client = Client(cluster) i = 0 for delayed_arr, output_path, stack_name in read_channel_stacks: if i == 2: break - + i += 1 smartspim_channel_zarr_writer( image_data=delayed_arr, @@ -158,37 +157,36 @@ def _compress_and_write_channels( writing_options=compressor, ) + # Closing client + client.shutdown() + def _compress_raw_data(self) -> None: """Compresses smartspim data""" # Clip the data logging.info("Converting PNG to OMEZarr. 
This may take some minutes.") - output_compressed_data = ( - self.job_settings.output_directory / "SPIM" - ) - + output_compressed_data = self.job_settings.output_directory / "SPIM" + channel_paths = [ Path(self.job_settings.input_source).joinpath(folder) - for folder in os.listdir( - self.job_settings.input_source - ) + for folder in os.listdir(self.job_settings.input_source) ] # Get channel stack iterators and delayed arrays read_delayed_channel_stacks = self._get_delayed_channel_stack( channel_paths=channel_paths, - output_dir=output_compressed_data, + output_dir=output_compressed_data, ) - + # Getting compressors compressor = self._get_compressor() - + # Writing compressed stacks self._compress_and_write_channels( read_channel_stacks=read_delayed_channel_stacks, compressor=compressor, output_format=self.job_settings.compress_write_output_format, - job_kwargs=self.job_settings.compress_job_save_kwargs + job_kwargs=self.job_settings.compress_job_save_kwargs, ) logging.info("Finished compressing source data.") @@ -212,7 +210,7 @@ def run_job(self) -> JobResponse: def main(): - """ Main function """ + """Main function""" sys_args = sys.argv[1:] parser = get_parser() cli_args = parser.parse_args(sys_args) @@ -221,16 +219,19 @@ def main(): cli_args.job_settings ) elif cli_args.config_file is not None: - job_settings = SmartspimJobSettings.from_config_file(cli_args.config_file) + job_settings = SmartspimJobSettings.from_config_file( + cli_args.config_file + ) else: # Construct settings from env vars job_settings = SmartspimJobSettings( input_source="/data/SmartSPIM_714635_2024-03-18_10-47-48/SmartSPIM", - output_directory="/scratch/" + output_directory="/scratch/", ) job = SmartspimCompressionJob(job_settings=job_settings) job_response = job.run_job() logging.info(job_response.model_dump_json()) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/aind_smartspim_data_transformation/zarr_utilities.py 
b/src/aind_smartspim_data_transformation/zarr_utilities.py index 5f599ad..c902021 100644 --- a/src/aind_smartspim_data_transformation/zarr_utilities.py +++ b/src/aind_smartspim_data_transformation/zarr_utilities.py @@ -9,6 +9,7 @@ from typing import Optional, Tuple, Union import dask +import dask.array as da import numpy as np import pims from dask.array import concatenate, pad @@ -17,7 +18,6 @@ from natsort import natsorted from numcodecs import blosc from skimage.io import imread as sk_imread -import dask.array as da PathLike = Union[str, Path] ArrayLike = Union[dask.array.core.Array, np.ndarray] @@ -221,7 +221,9 @@ def fix_image_diff_dims( return new_arr -def concatenate_dask_arrays(arr_1: ArrayLike, arr_2: ArrayLike, axis: int) -> ArrayLike: +def concatenate_dask_arrays( + arr_1: ArrayLike, arr_2: ArrayLike, axis: int +) -> ArrayLike: """ Concatenates two arrays in a given dimension @@ -342,11 +344,13 @@ def read_chunked_stitched_image_per_channel( horizontal = [] shape = None dtype = None - column_names = list(directory_structure[channel_name][row_name].keys()) + column_names = list( + directory_structure[channel_name][row_name].keys() + ) n_cols = len(column_names) - - check_shape = (1, 256,256,256) - + + check_shape = (1, 256, 256, 256) + for column_name_idx in range(n_cols): valid_image = True column_name = column_names[column_name_idx] @@ -367,17 +371,23 @@ def read_chunked_stitched_image_per_channel( .joinpath(slice_name) ) - new_arr = lazy_tiff_reader(filepath)#, dtype=dtype, shape=shape) - + new_arr = lazy_tiff_reader( + filepath + ) # , dtype=dtype, shape=shape) + if (shape is None or dtype is None) and new_arr: shape = new_arr.shape dtype = new_arr.dtype last_col = False except ValueError: - print(f"{filepath} -> No valid image in ", slice_name, slice_pos) -# valid_image = False -# new_arr = da.zeros(check_shape, dtype=new_arr.dtype) + print( + f"{filepath} -> No valid image in ", + slice_name, + slice_pos, + ) + # valid_image = False + # new_arr = 
da.zeros(check_shape, dtype=new_arr.dtype) if valid_image: horizontal.append(new_arr) @@ -391,11 +401,15 @@ def read_chunked_stitched_image_per_channel( shape = horizontal[i].shape dtype = horizontal[i].dtype break - + for i in range(len(horizontal)): - if horizontal[i] is None and shape is not None and dtype is not None: + if ( + horizontal[i] is None + and shape is not None + and dtype is not None + ): horizontal[i] = da.zeros(shape, dtype=dtype) - + print("Before concat: ", horizontal) horizontal_concat = concatenate(horizontal, axis=-1) @@ -537,7 +551,9 @@ def channel_parallel_reading( if not res_idx: dask_array = res[res_idx][0] else: - dask_array = concatenate([dask_array, res[res_idx][0]], axis=-3) + dask_array = concatenate( + [dask_array, res[res_idx][0]], axis=-3 + ) print(f"Slides: {res[res_idx][1]}") diff --git a/src/aind_smartspim_data_transformation/zarr_writer.py b/src/aind_smartspim_data_transformation/zarr_writer.py index 9d86a44..ec18a15 100644 --- a/src/aind_smartspim_data_transformation/zarr_writer.py +++ b/src/aind_smartspim_data_transformation/zarr_writer.py @@ -73,7 +73,9 @@ def expand_chunks( if any(s < 1 for s in data_shape): raise ValueError("data_shape must be >= 1 for all dimensions") if any(c > s for c, s in zip(chunks, data_shape)): - raise ValueError("chunks cannot be larger than data_shape in any dimension") + raise ValueError( + "chunks cannot be larger than data_shape in any dimension" + ) if target_size <= 0: raise ValueError("target_size must be > 0") if itemsize <= 0: @@ -149,7 +151,9 @@ def gen_slices( A generator yielding tuples of slices. Each tuple can be used to index the input array. 
""" if len(arr_shape) != len(block_shape): - raise Exception("array shape and block shape have different lengths") + raise Exception( + "array shape and block shape have different lengths" + ) def _slice_along_dim(dim: int) -> Generator: """A helper generator function that slices along one dimension.""" @@ -169,7 +173,9 @@ def _slice_along_dim(dim: int) -> Generator: return _slice_along_dim(0) @staticmethod - def store(in_array: da.Array, out_array: ArrayLike, block_shape: tuple) -> None: + def store( + in_array: da.Array, out_array: ArrayLike, block_shape: tuple + ) -> None: """ Partitions the last 3 dimensions of a Dask array into non-overlapping blocks and writes them sequentially to a Zarr array. This is meant to reduce the scheduling burden From c6e0f85eb04d378ca68dc7274e7112bddd82f243 Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Thu, 30 May 2024 15:12:48 +0000 Subject: [PATCH 4/9] updating writer --- src/aind_smartspim_data_transformation/smartspim_job.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index 5678f68..fbef0e7 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -137,12 +137,7 @@ def _compress_and_write_channels( processes=True, ) - i = 0 for delayed_arr, output_path, stack_name in read_channel_stacks: - if i == 2: - break - - i += 1 smartspim_channel_zarr_writer( image_data=delayed_arr, output_path=output_path, From f13bed2e36610bfb4785257692ae76e52886e51d Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Thu, 30 May 2024 15:27:55 +0000 Subject: [PATCH 5/9] removing SPIM folder from compression output --- src/aind_smartspim_data_transformation/smartspim_job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py 
index fbef0e7..124acdb 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -160,7 +160,7 @@ def _compress_raw_data(self) -> None: # Clip the data logging.info("Converting PNG to OMEZarr. This may take some minutes.") - output_compressed_data = self.job_settings.output_directory / "SPIM" + output_compressed_data = self.job_settings.output_directory channel_paths = [ Path(self.job_settings.input_source).joinpath(folder) From 60ff78e6c0db03b6543854aa85b149ccc6905c6a Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Thu, 30 May 2024 15:28:21 +0000 Subject: [PATCH 6/9] adding OMEZarr extension --- src/aind_smartspim_data_transformation/smartspim_job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index 124acdb..4dc1aa3 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -88,7 +88,7 @@ def _get_delayed_channel_stack( delayed_stack = PngReader( data_path=f"{curr_row}/*.png" ).as_dask_array() - stack_name = f"{col}_{row.split('_')[-1]}.zarr" + stack_name = f"{col}_{row.split('_')[-1]}.ome.zarr" stack_output_path = Path( f"{output_dir}/{channel_path.stem}" ) From b966018936b11ed91ee78b16da90447df912956c Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Thu, 30 May 2024 15:41:52 +0000 Subject: [PATCH 7/9] updating input and output path --- src/aind_smartspim_data_transformation/smartspim_job.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index 4dc1aa3..87a7774 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -161,10 +161,13 @@ def _compress_raw_data(self) 
-> None: # Clip the data logging.info("Converting PNG to OMEZarr. This may take some minutes.") output_compressed_data = self.job_settings.output_directory + + raw_path = self.job_settings.input_source / "SmartSPIM" + logging.info(f"Raw path: {raw_path} - OS: {os.listdir(self.job_settings.input_source)}") channel_paths = [ - Path(self.job_settings.input_source).joinpath(folder) - for folder in os.listdir(self.job_settings.input_source) + Path(raw_path).joinpath(folder) + for folder in os.listdir(raw_path) ] # Get channel stack iterators and delayed arrays @@ -220,7 +223,7 @@ def main(): else: # Construct settings from env vars job_settings = SmartspimJobSettings( - input_source="/data/SmartSPIM_714635_2024-03-18_10-47-48/SmartSPIM", + input_source="/data/SmartSPIM_714635_2024-03-18_10-47-48", output_directory="/scratch/", ) job = SmartspimCompressionJob(job_settings=job_settings) From 9390dd88685af706d191108f54148ed6feacfdd2 Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Thu, 30 May 2024 16:09:24 +0000 Subject: [PATCH 8/9] organizing compression utilities --- .../compress/__init__.py | 3 + .../{ => compress}/dask_utils.py | 9 ++- .../{ => compress}/png_to_zarr.py | 6 +- .../{ => compress}/zarr_utilities.py | 19 +++--- .../{ => compress}/zarr_writer.py | 60 ++++++++++++------- .../io/_io.py | 17 +----- .../io/utils.py | 1 - .../models.py | 1 - .../smartspim_job.py | 23 ++++--- 9 files changed, 78 insertions(+), 61 deletions(-) create mode 100644 src/aind_smartspim_data_transformation/compress/__init__.py rename src/aind_smartspim_data_transformation/{ => compress}/dask_utils.py (90%) rename src/aind_smartspim_data_transformation/{ => compress}/png_to_zarr.py (99%) rename src/aind_smartspim_data_transformation/{ => compress}/zarr_utilities.py (95%) rename src/aind_smartspim_data_transformation/{ => compress}/zarr_writer.py (76%) diff --git a/src/aind_smartspim_data_transformation/compress/__init__.py b/src/aind_smartspim_data_transformation/compress/__init__.py 
new file mode 100644 index 0000000..f8926e0 --- /dev/null +++ b/src/aind_smartspim_data_transformation/compress/__init__.py @@ -0,0 +1,3 @@ +""" +Init internal package +""" diff --git a/src/aind_smartspim_data_transformation/dask_utils.py b/src/aind_smartspim_data_transformation/compress/dask_utils.py similarity index 90% rename from src/aind_smartspim_data_transformation/dask_utils.py rename to src/aind_smartspim_data_transformation/compress/dask_utils.py index d149cc9..51edce3 100644 --- a/src/aind_smartspim_data_transformation/dask_utils.py +++ b/src/aind_smartspim_data_transformation/compress/dask_utils.py @@ -37,7 +37,8 @@ def log_dashboard_address( port = client.scheduler_info()["services"]["dashboard"] user = os.getenv("USER") LOGGER.info( - f"To access the dashboard, run the following in a terminal: ssh -L {port}:{host}:{port} {user}@" + f"To access the dashboard, run the following in " + "a terminal: ssh -L {port}:{host}:{port} {user}@" f"{login_node_address} " ) @@ -78,7 +79,8 @@ def get_client( slurm_job_id = os.getenv("SLURM_JOBID") if slurm_job_id is None: raise Exception( - "SLURM_JOBID environment variable is not set. Are you running under SLURM?" + "SLURM_JOBID environment variable is not set." + "Are you running under SLURM?" ) initialize( nthreads=int(os.getenv("SLURM_CPUS_PER_TASK", 1)), @@ -108,7 +110,8 @@ def cancel_slurm_job( Args: job_id: the SLURM job ID - api_url: the URL of the SLURM REST API. E.g., "http://myhost:80/api/slurm/v0.0.36" + api_url: the URL of the SLURM REST API. 
+ E.g., "http://myhost:80/api/slurm/v0.0.36" Raises: HTTPError: if the request to cancel the job fails diff --git a/src/aind_smartspim_data_transformation/png_to_zarr.py b/src/aind_smartspim_data_transformation/compress/png_to_zarr.py similarity index 99% rename from src/aind_smartspim_data_transformation/png_to_zarr.py rename to src/aind_smartspim_data_transformation/compress/png_to_zarr.py index bee72ec..02e44fa 100644 --- a/src/aind_smartspim_data_transformation/png_to_zarr.py +++ b/src/aind_smartspim_data_transformation/compress/png_to_zarr.py @@ -33,9 +33,11 @@ from ome_zarr.writer import write_multiscales_metadata from skimage.io import imread as sk_imread +from aind_smartspim_data_transformation.compress.zarr_utilities import * +from aind_smartspim_data_transformation.compress.zarr_writer import ( + BlockedArrayWriter, +) from aind_smartspim_data_transformation.io import PngReader -from aind_smartspim_data_transformation.zarr_utilities import * -from aind_smartspim_data_transformation.zarr_writer import BlockedArrayWriter def set_dask_config(dask_folder: str): diff --git a/src/aind_smartspim_data_transformation/zarr_utilities.py b/src/aind_smartspim_data_transformation/compress/zarr_utilities.py similarity index 95% rename from src/aind_smartspim_data_transformation/zarr_utilities.py rename to src/aind_smartspim_data_transformation/compress/zarr_utilities.py index c902021..0fa9620 100644 --- a/src/aind_smartspim_data_transformation/zarr_utilities.py +++ b/src/aind_smartspim_data_transformation/compress/zarr_utilities.py @@ -83,7 +83,8 @@ def read_image_directory_structure(folder_dir: PathLike) -> dict: ------------------------ dict: Dictionary with the image representation where: - {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } + {channel_1: ... {channel_n: {col_1: ... + col_n: {row_1: ... 
row_n: [image_0, ..., image_n]} } } } """ directory_structure = {} @@ -305,7 +306,8 @@ def read_chunked_stitched_image_per_channel( directory_structure:dict dictionary to store paths of images with the following structure: - {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } + {channel_1: ... {channel_n: {col_1: ... + col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } channel_name : str Channel name to reconstruct the image volume @@ -349,7 +351,7 @@ def read_chunked_stitched_image_per_channel( ) n_cols = len(column_names) - check_shape = (1, 256, 256, 256) + # check_shape = (1, 256, 256, 256) for column_name_idx in range(n_cols): valid_image = True @@ -386,8 +388,6 @@ def read_chunked_stitched_image_per_channel( slice_name, slice_pos, ) - # valid_image = False - # new_arr = da.zeros(check_shape, dtype=new_arr.dtype) if valid_image: horizontal.append(new_arr) @@ -459,7 +459,8 @@ def channel_parallel_reading( directory_structure: dict dictionary to store paths of images with the following structure: - {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } + {channel_1: ... {channel_n: {col_1: ... + col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } channel_name : str Channel name to reconstruct the image volume @@ -511,7 +512,8 @@ def channel_parallel_reading( else: images_per_worker = n_images // workers print( - f"Setting workers to {workers} - {images_per_worker} - total images: {n_images}" + f"Setting workers to {workers}" + f"- {images_per_worker} - total images: {n_images}" ) # Getting 5 dim image TCZYX @@ -576,7 +578,8 @@ def parallel_read_chunked_stitched_multichannel_image( directory_structure: dict dictionary to store paths of images with the following structure: - {channel_1: ... {channel_n: {col_1: ... col_n: {row_1: ... row_n: [image_0, ..., image_n]} } } } + {channel_1: ... {channel_n: {col_1: ... + col_n: {row_1: ... 
row_n: [image_0, ..., image_n]} } } } sample_img: ArrayLike Image used as guide for the chunksize diff --git a/src/aind_smartspim_data_transformation/zarr_writer.py b/src/aind_smartspim_data_transformation/compress/zarr_writer.py similarity index 76% rename from src/aind_smartspim_data_transformation/zarr_writer.py rename to src/aind_smartspim_data_transformation/compress/zarr_writer.py index ec18a15..96628b9 100644 --- a/src/aind_smartspim_data_transformation/zarr_writer.py +++ b/src/aind_smartspim_data_transformation/compress/zarr_writer.py @@ -56,9 +56,11 @@ def expand_chunks( mode: str = "iso", ) -> Tuple[int, int, int]: """ - Given the shape and chunk size of a pre-chunked 3D array, determine the optimal chunk shape - closest to target_size. Expanded chunk dimensions are an integer multiple of the base chunk dimension, - to ensure optimal access patterns. + Given the shape and chunk size of a pre-chunked 3D array, + determine the optimal chunk shape closest to target_size. + Expanded chunk dimensions are an integer multiple of + the base chunk dimension, to ensure optimal access patterns. + Args: chunks: the shape of the input array chunks data_shape: the shape of the input array @@ -129,10 +131,13 @@ def gen_slices( arr_shape: Tuple[int, ...], block_shape: Tuple[int, ...] ) -> Generator: """ - Generate a series of slices that can be used to traverse an array in blocks of a given shape. + Generate a series of slices that can be + used to traverse an array in blocks of a given shape. - The method generates tuples of slices, each representing a block of the array. The blocks are generated by - iterating over the array in steps of the block shape along each dimension. + The method generates tuples of slices, each representing + a block of the array. The blocks are generated by + iterating over the array in steps of the block + shape along each dimension. Parameters ---------- @@ -140,15 +145,19 @@ def gen_slices( The shape of the array to be sliced. 
block_shape : tuple of int - The desired shape of the blocks. This should be a tuple of integers representing the size of each - dimension of the block. The length of `block_shape` should be equal to the length of - `arr_shape`. If the array shape is not divisible by the block shape along a dimension, the last slice + The desired shape of the blocks. This should be a + tuple of integers representing the size of each + dimension of the block. The length of `block_shape` + should be equal to the length of `arr_shape`. + If the array shape is not divisible by the block + shape along a dimension, the last slice along that dimension is truncated. Returns ------- generator of tuple of slice - A generator yielding tuples of slices. Each tuple can be used to index the input array. + A generator yielding tuples of slices. + Each tuple can be used to index the input array. """ if len(arr_shape) != len(block_shape): raise Exception( @@ -156,8 +165,12 @@ def gen_slices( ) def _slice_along_dim(dim: int) -> Generator: - """A helper generator function that slices along one dimension.""" - # Base case: if the dimension is beyond the last one, return an empty tuple + """ + A helper generator function that + slices along one dimension. + """ + # Base case: if the dimension is beyond + # the last one, return an empty tuple if dim >= len(arr_shape): yield () else: @@ -177,15 +190,17 @@ def store( in_array: da.Array, out_array: ArrayLike, block_shape: tuple ) -> None: """ - Partitions the last 3 dimensions of a Dask array into non-overlapping blocks and - writes them sequentially to a Zarr array. This is meant to reduce the scheduling burden - for massive (terabyte-scale) arrays. + Partitions the last 3 dimensions of a Dask array + into non-overlapping blocks and writes them sequentially + to a Zarr array. This is meant to reduce the + scheduling burden for massive (terabyte-scale) arrays. 
:param in_array: The input Dask array :param block_shape: Tuple of (block_depth, block_height, block_width) :param out_array: The output array """ - # Iterate through the input array in steps equal to the block shape dimensions + # Iterate through the input array in + # steps equal to the block shape dimensions for sl in BlockedArrayWriter.gen_slices(in_array.shape, block_shape): block = in_array[sl] da.store( @@ -200,13 +215,18 @@ def store( @staticmethod def get_block_shape(arr, target_size_mb=409600, mode="cycle"): """ - Given the shape and chunk size of a pre-chunked array, determine the optimal block shape - closest to target_size. Expanded block dimensions are an integer multiple of the chunk dimension + Given the shape and chunk size of a pre-chunked + array, determine the optimal block shape closest + to target_size. Expanded block dimensions are + an integer multiple of the chunk dimension to ensure optimal access patterns. + Args: arr: the input array - target_size_mb: target block size in megabytes, default is 409600 - mode: strategy. Must be one of "cycle", or "iso" + target_size_mb: target block size in megabytes, + default is 409600 mode: strategy. 
+ Must be one of "cycle", or "iso" + Returns: the block shape """ diff --git a/src/aind_smartspim_data_transformation/io/_io.py b/src/aind_smartspim_data_transformation/io/_io.py index ae9b686..20a9f81 100644 --- a/src/aind_smartspim_data_transformation/io/_io.py +++ b/src/aind_smartspim_data_transformation/io/_io.py @@ -6,7 +6,7 @@ import os from abc import ABC, abstractmethod, abstractproperty from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import dask.array as da import imageio as iio @@ -21,17 +21,6 @@ from .utils import add_leading_dim, read_json_as_dict -""" -File that defines the constants used -in the package -""" - -from pathlib import Path -from typing import Union - -import dask.array as da -import numpy as np - # IO types PathLike = Union[str, Path] ArrayLike = Union[da.Array, np.ndarray] @@ -403,7 +392,7 @@ def metadata(self) -> Dict: Dictionary with image metadata """ metadata = {} - with pims.open(data_path) as imgs: + with pims.open(self.data_path) as imgs: metadata["shape"] = (1,) + (len(imgs),) + imgs.frame_shape metadata["dtype"] = np.dtype(imgs.pixel_type) @@ -555,7 +544,7 @@ def metadata(self) -> Dict: Dictionary with image metadata """ metadata = {} - with pims.open(data_path) as imgs: + with pims.open(self.data_path) as imgs: metadata["shape"] = (len(imgs),) + imgs.frame_shape return metadata diff --git a/src/aind_smartspim_data_transformation/io/utils.py b/src/aind_smartspim_data_transformation/io/utils.py index aa5ad55..9349dcb 100644 --- a/src/aind_smartspim_data_transformation/io/utils.py +++ b/src/aind_smartspim_data_transformation/io/utils.py @@ -4,7 +4,6 @@ import json import os -from pathlib import Path from typing import Optional, Union import dask.array as da diff --git a/src/aind_smartspim_data_transformation/models.py b/src/aind_smartspim_data_transformation/models.py index 5ca3517..9f3ed53 100644 --- 
a/src/aind_smartspim_data_transformation/models.py +++ b/src/aind_smartspim_data_transformation/models.py @@ -1,7 +1,6 @@ """Helpful models used in the ephys compression job""" from enum import Enum -from typing import List from numcodecs import Blosc diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index 87a7774..30dc83e 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -2,29 +2,27 @@ import logging import os -import platform -import shutil import sys from datetime import datetime from pathlib import Path -from typing import Iterator, List, Literal, Optional, Union +from typing import Iterator, List, Literal, Optional -import numpy as np from aind_data_transformation.core import ( BasicJobSettings, GenericEtl, JobResponse, get_parser, ) -from dask.distributed import Client, LocalCluster from numcodecs.blosc import Blosc -from png_to_zarr import smartspim_channel_zarr_writer from pydantic import Field -from aind_smartspim_data_transformation.dask_utils import ( +from aind_smartspim_data_transformation.compress.dask_utils import ( get_client, get_deployment, ) +from aind_smartspim_data_transformation.compress.png_to_zarr import ( + smartspim_channel_zarr_writer, +) from aind_smartspim_data_transformation.io import PngReader from aind_smartspim_data_transformation.models import CompressorName @@ -126,7 +124,6 @@ def _compress_and_write_channels( job_kwargs["n_jobs"] = os.cpu_count() n_workers = job_kwargs["n_jobs"] - threads_per_worker = 1 # Instantiating local cluster for parallel writing deployment = get_deployment() @@ -161,13 +158,15 @@ def _compress_raw_data(self) -> None: # Clip the data logging.info("Converting PNG to OMEZarr. 
This may take some minutes.") output_compressed_data = self.job_settings.output_directory - + raw_path = self.job_settings.input_source / "SmartSPIM" - logging.info(f"Raw path: {raw_path} - OS: {os.listdir(self.job_settings.input_source)}") + logging.info( + f"Raw path: {raw_path}" + f"OS: {os.listdir(self.job_settings.input_source)}" + ) channel_paths = [ - Path(raw_path).joinpath(folder) - for folder in os.listdir(raw_path) + Path(raw_path).joinpath(folder) for folder in os.listdir(raw_path) ] # Get channel stack iterators and delayed arrays From d23d389d9bd9334b4d7dd0075246d44855467f2b Mon Sep 17 00:00:00 2001 From: camilolaiton Date: Thu, 30 May 2024 16:59:53 +0000 Subject: [PATCH 9/9] updating docs and removing hardcoded paths --- .../compress/dask_utils.py | 15 ++++++++++++ .../compress/png_to_zarr.py | 20 ++++++++++++++++ .../smartspim_job.py | 24 ++++++++++++------- 3 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/aind_smartspim_data_transformation/compress/dask_utils.py b/src/aind_smartspim_data_transformation/compress/dask_utils.py index 51edce3..764b5a7 100644 --- a/src/aind_smartspim_data_transformation/compress/dask_utils.py +++ b/src/aind_smartspim_data_transformation/compress/dask_utils.py @@ -1,3 +1,7 @@ +""" +Module for dask utilities +""" + import logging import os import socket @@ -19,6 +23,8 @@ class Deployment(Enum): + """Deployment enums""" + LOCAL = "local" SLURM = "slurm" @@ -44,6 +50,15 @@ def log_dashboard_address( def get_deployment() -> str: + """ + Gets the SLURM deployment if this + exists + + Returns + ------- + str + SLURM_JOBID + """ if os.getenv("SLURM_JOBID") is None: deployment = Deployment.LOCAL.value else: diff --git a/src/aind_smartspim_data_transformation/compress/png_to_zarr.py b/src/aind_smartspim_data_transformation/compress/png_to_zarr.py index 02e44fa..602bfa1 100644 --- a/src/aind_smartspim_data_transformation/compress/png_to_zarr.py +++ b/src/aind_smartspim_data_transformation/compress/png_to_zarr.py 
@@ -706,6 +706,23 @@ def smartspim_channel_zarr_writer( def convert_stacks_to_ome_zarr(channel_path, logger, output_path): + """ + Converts image stacks from PNG to OMEZarr. + + Parameters + ---------- + channel_path: str + Path where the stacks for the channel + are located. + + logger: logging.Logger + Logging object + + output_path: str + Path where we want to write the converted + stacks to OMEZarr. + + """ # channel_path must end with Ex_{wav}_Em_{wav} # Setting up local cluster @@ -788,6 +805,9 @@ def create_logger(output_log_path: PathLike) -> logging.Logger: def main(): + """ + Main function in how to use these functions + """ data_folder = Path(os.path.abspath("../data")) results_folder = Path(os.path.abspath("../results")) diff --git a/src/aind_smartspim_data_transformation/smartspim_job.py b/src/aind_smartspim_data_transformation/smartspim_job.py index 30dc83e..0e61dbb 100644 --- a/src/aind_smartspim_data_transformation/smartspim_job.py +++ b/src/aind_smartspim_data_transformation/smartspim_job.py @@ -67,12 +67,12 @@ class SmartspimCompressionJob(GenericEtl[SmartspimJobSettings]): def _get_delayed_channel_stack( self, channel_paths: List[str], output_dir: str - ) -> Iterator[dict]: + ) -> Iterator[tuple]: """ Reads a stack of PNG images into a delayed zarr dataset. Returns: - Iterator[dict] + Iterator[tuple] A generator that returns delayed PNG stacks. """ @@ -99,7 +99,7 @@ def _get_compressor(self) -> Blosc: Returns ------- Blosc - Either an instantiated Blosc or WavPack class. + An instantiated Blosc compressor. """ if self.job_settings.compressor_name == CompressorName.BLOSC: @@ -114,11 +114,20 @@ def _get_compressor(self) -> Blosc: @staticmethod def _compress_and_write_channels( - read_channel_stacks: Iterator[dict], + read_channel_stacks: Iterator[tuple], compressor: Blosc, job_kwargs: dict, output_format: str = "zarr", - ) -> None: + ): + """ + Compresses SmartSPIM image data. 
+ + Parameters + ---------- + read_channel_stacks: Iterator[tuple] + Iterator that returns the delayed image stack, + image path and stack name. + """ if job_kwargs["n_jobs"] == -1: job_kwargs["n_jobs"] = os.cpu_count() @@ -221,10 +230,7 @@ def main(): ) else: # Construct settings from env vars - job_settings = SmartspimJobSettings( - input_source="/data/SmartSPIM_714635_2024-03-18_10-47-48", - output_directory="/scratch/", - ) + job_settings = SmartspimJobSettings() job = SmartspimCompressionJob(job_settings=job_settings) job_response = job.run_job() logging.info(job_response.model_dump_json())