adding unit tests and updating package version
camilolaiton committed Jun 10, 2024
1 parent d79e6f4 commit 9410f94
Showing 17 changed files with 589 additions and 743 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml

@@ -22,9 +22,8 @@ dependencies = [
     'numcodecs==0.11.0',
     'dask-image==2024.5.3',
     'xarray_multiscale==1.1.0',
-    'bokeh==2.4.2',
     'pims==0.6.1',
-    'dask[distributed]==2024.5.1',
+    'dask[distributed]==2024.5.2',
     'ome-zarr==0.8.2',
     'imagecodecs[all]==2023.3.16',
     'natsort==8.4.0',
@@ -79,7 +78,8 @@ exclude_lines = [
     "import",
     "pragma: no cover"
 ]
-fail_under = 100
+fail_under = 75
+show_missing = true
 
 [tool.isort]
 line_length = 79
2 changes: 1 addition & 1 deletion src/aind_smartspim_data_transformation/__init__.py

@@ -1,3 +1,3 @@
 """Init package"""
 
-__version__ = "0.0.6"
+__version__ = "0.0.7"
3 changes: 3 additions & 0 deletions src/aind_smartspim_data_transformation/_shared/__init__.py

@@ -0,0 +1,3 @@
+"""
+Shared constants
+"""
12 changes: 12 additions & 0 deletions src/aind_smartspim_data_transformation/_shared/types.py

@@ -0,0 +1,12 @@
+"""
+Defines all the types used in the module
+"""
+
+from pathlib import Path
+from typing import Union
+
+import dask.array as da
+import numpy as np
+
+PathLike = Union[Path, str]
+ArrayLike = Union[da.Array, np.ndarray]
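These aliases let downstream code accept either `str` or `pathlib.Path` for paths, and either NumPy or Dask arrays, interchangeably. A minimal sketch of the intended usage (`describe_array` is a hypothetical helper, not part of this package):

```python
from pathlib import Path

import dask.array as da
import numpy as np

from aind_smartspim_data_transformation._shared.types import ArrayLike, PathLike


def describe_array(arr: ArrayLike, out: PathLike) -> None:
    """Accepts a NumPy or Dask array and a str or Path destination."""
    # Path(out) works for both members of PathLike = Union[Path, str]
    Path(out).write_text(f"shape={arr.shape}, dtype={arr.dtype}")


describe_array(np.zeros((2, 3)), "numpy_info.txt")
describe_array(da.zeros((2, 3)), Path("dask_info.txt"))
```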
@@ -44,7 +44,7 @@ def log_dashboard_address(
     user = os.getenv("USER")  # noqa: F841
     LOGGER.info(
         f"To access the dashboard, run the following in "
-        "a terminal: ssh -L {port}:{host}:{port} {user}@"
+        f"a terminal: ssh -L {port}:{host}:{port} {user}@"
         f"{login_node_address} "
     )
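The one-character change above fixes a subtle bug: in implicitly concatenated string literals, only the pieces carrying the `f` prefix are interpolated, so the middle line previously printed literal `{port}` placeholders. A standalone repro (values invented for illustration):

```python
port, host, user = 8787, "login-node", "someuser"

broken = (
    f"To access the dashboard, run the following in "
    "a terminal: ssh -L {port}:{host}:{port} {user}@"  # no f-prefix: braces stay literal
)
fixed = (
    f"To access the dashboard, run the following in "
    f"a terminal: ssh -L {port}:{host}:{port} {user}@"  # f-prefix: values interpolated
)
print(broken)  # ... ssh -L {port}:{host}:{port} {user}@
print(fixed)   # ... ssh -L 8787:login-node:8787 someuser@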
106 changes: 51 additions & 55 deletions src/aind_smartspim_data_transformation/compress/png_to_zarr.py

@@ -22,7 +22,7 @@
 from dask import config as da_cfg
 from dask.array.core import Array
 from dask.base import tokenize
-from dask.distributed import Client, LocalCluster, performance_report
+from dask.distributed import Client, LocalCluster  # , performance_report
 
 # from distributed import wait
 from numcodecs import blosc
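For reference, `performance_report` is a context manager from `dask.distributed` that profiles everything computed inside it and writes a standalone HTML report; this commit comments it out rather than removing it. A self-contained sketch of the API being disabled, using an unrelated toy computation:

```python
import dask.array as da
from dask.distributed import Client, LocalCluster, performance_report

if __name__ == "__main__":
    cluster = LocalCluster(n_workers=2, processes=False)
    client = Client(cluster)

    # Everything computed inside the block is captured in the HTML report
    with performance_report(filename="dask_report.html"):
        da.random.random((4_000, 4_000), chunks=(1_000, 1_000)).mean().compute()

    client.close()
    cluster.close()
```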
@@ -31,15 +31,15 @@
 from ome_zarr.writer import write_multiscales_metadata
 from skimage.io import imread as sk_imread
 
-from aind_smartspim_data_transformation.compress.zarr_utilities import (
+from aind_smartspim_data_transformation._shared.types import (
     ArrayLike,
     PathLike,
-    pad_array_n_d,
 )
 from aind_smartspim_data_transformation.compress.zarr_writer import (
     BlockedArrayWriter,
 )
 from aind_smartspim_data_transformation.io import PngReader
+from aind_smartspim_data_transformation.io.utils import pad_array_n_d
 
 
 def set_dask_config(dask_folder: str):
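`pad_array_n_d` itself is not shown in this diff; it simply moved from `compress.zarr_utilities` to `io.utils`. Judging by the 5D block-shape handling later in this file, its assumed behavior is to left-pad an array with singleton dimensions, roughly like this sketch (an assumption, not the package's actual implementation):

```python
import numpy as np


def pad_array_n_d(arr, dim: int = 5):
    """Assumed: prepend singleton axes until arr.ndim == dim (NumPy or Dask)."""
    if arr.ndim > dim:
        raise ValueError("Array has more dimensions than requested")
    while arr.ndim < dim:
        arr = arr[np.newaxis, ...]  # add a leading length-1 axis
    return arr


assert pad_array_n_d(np.zeros((64, 64))).shape == (1, 1, 1, 64, 64)
```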
@@ -646,66 +646,62 @@ def smartspim_channel_zarr_writer(
         metadata=_get_pyramid_metadata(),
     )
 
-    performance_report_path = f"{output_path}/report_{stack_name}.html"
+    # performance_report_path = f"{output_path}/report_{stack_name}.html"
 
     start_time = time.time()
     # Writing zarr and performance report
-    with performance_report(filename=performance_report_path):
-        logger.info(f"{'='*40}Writing channel {channel_name}{'='*40}")
+    # with performance_report(filename=performance_report_path):
+    logger.info(f"{'='*40}Writing channel {channel_name}{'='*40}")
 
-        # Writing zarr
-        block_shape = list(
-            BlockedArrayWriter.get_block_shape(
-                arr=image_data, target_size_mb=12800  # 51200,
-            )
-        )
+    # Writing zarr
+    block_shape = list(
+        BlockedArrayWriter.get_block_shape(
+            arr=image_data, target_size_mb=12800  # 51200,
+        )
+    )
 
-        # Formatting to 5D block shape
-        block_shape = ([1] * (5 - len(block_shape))) + block_shape
-        written_pyramid = []
-        pyramid_group = None
-
-        # Writing multiple levels
-        for level in range(n_lvls):
-            if not level:
-                array_to_write = image_data
-
-            else:
-                # It's faster to write the scale and then read it back
-                # to compute the next scale
-                previous_scale = da.from_zarr(
-                    pyramid_group, pyramid_group.chunks
-                )
-                new_scale_factor = (
-                    [1] * (len(previous_scale.shape) - len(scale_factor))
-                ) + scale_factor
-
-                previous_scale_pyramid, _ = compute_pyramid(
-                    data=previous_scale,
-                    scale_axis=new_scale_factor,
-                    chunks=image_data.chunksize,
-                    n_lvls=2,
-                )
-                array_to_write = previous_scale_pyramid[-1]
-
-            logger.info(f"[level {level}]: pyramid level: {array_to_write}")
-
-            # Create the scale dataset
-            pyramid_group = new_channel_group.create_dataset(
-                name=level,
-                shape=array_to_write.shape,
-                chunks=array_to_write.chunksize,
-                dtype=array_to_write.dtype,
-                compressor=writing_options,
-                dimension_separator="/",
-                overwrite=True,
-            )
-
-            # Block Zarr Writer
-            BlockedArrayWriter.store(
-                array_to_write, pyramid_group, block_shape
-            )
-            written_pyramid.append(array_to_write)
+    # Formatting to 5D block shape
+    block_shape = ([1] * (5 - len(block_shape))) + block_shape
+    written_pyramid = []
+    pyramid_group = None
+
+    # Writing multiple levels
+    for level in range(n_lvls):
+        if not level:
+            array_to_write = image_data
+
+        else:
+            # It's faster to write the scale and then read it back
+            # to compute the next scale
+            previous_scale = da.from_zarr(pyramid_group, pyramid_group.chunks)
+            new_scale_factor = (
+                [1] * (len(previous_scale.shape) - len(scale_factor))
+            ) + scale_factor
+
+            previous_scale_pyramid, _ = compute_pyramid(
+                data=previous_scale,
+                scale_axis=new_scale_factor,
+                chunks=image_data.chunksize,
+                n_lvls=2,
+            )
+            array_to_write = previous_scale_pyramid[-1]
+
+        logger.info(f"[level {level}]: pyramid level: {array_to_write}")
+
+        # Create the scale dataset
+        pyramid_group = new_channel_group.create_dataset(
+            name=level,
+            shape=array_to_write.shape,
+            chunks=array_to_write.chunksize,
+            dtype=array_to_write.dtype,
+            compressor=writing_options,
+            dimension_separator="/",
+            overwrite=True,
+        )
+
+        # Block Zarr Writer
+        BlockedArrayWriter.store(array_to_write, pyramid_group, block_shape)
+        written_pyramid.append(array_to_write)
 
     end_time = time.time()
     logger.info(f"Time to write the dataset: {end_time - start_time}")