diff --git a/src/nd2/nd2file.py b/src/nd2/nd2file.py index c432aac..c46bb74 100644 --- a/src/nd2/nd2file.py +++ b/src/nd2/nd2file.py @@ -888,7 +888,12 @@ def write_tiff( modify_ome=modify_ome, ) - def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Array: + def to_dask( + self, + wrapper: bool = True, + copy: bool = True, + frame_chunks: int | tuple | None = None, + ) -> dask.array.core.Array: """Create dask array (delayed reader) representing image. This generally works well, but it remains to be seen whether performance @@ -913,6 +918,11 @@ def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Ar If `True` (the default), the dask chunk-reading function will return an array copy. This can avoid segfaults in certain cases, though it may also add overhead. + frame_chunks : tuple | int | None + If `None` (the default), the file will not be chunked on the frame level. + Otherwise, expects dask-compatible chunks used to chunk each frame along + the channel, y, and x axes. If a tuple, it must have the same length as + `self._frame_shape`. Returns ------- @@ -922,7 +932,46 @@ def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Ar from dask.array.core import map_blocks chunks = [(1,) * x for x in self._coord_shape] - chunks += [(x,) for x in self._frame_shape] + if frame_chunks is None: + chunks += [(x,) for x in self._frame_shape] + elif isinstance(frame_chunks, int): + for frame_len in self._frame_shape: + div = frame_len // frame_chunks + if div == 0: + chunks.append((frame_len,)) + else: + _chunks = (frame_chunks,) * div + if frame_len % frame_chunks != 0: + _chunks += (frame_len - div * frame_chunks,) + chunks.append(_chunks) + elif len(frame_chunks) != len(self._frame_shape): + raise ValueError( + f"frame_chunks must be of length {len(self._frame_shape)}." 
+ ) + elif isinstance(frame_chunks[0], int): + if not all(isinstance(frame_chunk, int) for frame_chunk in frame_chunks): + raise ValueError( + "frame_chunks must be a tuple of ints or tuple of tuple of ints." + ) + for frame_len, frame_chunk in zip(self._frame_shape, frame_chunks): + div = frame_len // frame_chunk + if div == 0: + chunks.append((frame_len,)) + else: + _chunks = (frame_chunk,) * div + if frame_len % frame_chunk != 0: + _chunks += (frame_len - div * frame_chunk,) + chunks.append(_chunks) + else: + if not all( + sum(frame_chunk) == frame_len + for frame_chunk, frame_len in zip(frame_chunks, self._frame_shape) + ): + raise ValueError( + "Sum of frame_chunks does not align with frame shape of file." + ) + chunks.extend(frame_chunks) + dask_arr = map_blocks( self._dask_block, copy=copy, @@ -944,8 +993,8 @@ def _seq_index_from_coords(self, coords: Sequence) -> Sequence[int] | SupportsIn return self._NO_IDX return np.ravel_multi_index(coords, self._coord_shape) # type: ignore - def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray: - if isinstance(block_id, np.ndarray): + def _dask_block(self, copy: bool, block_info: dict) -> np.ndarray: + if isinstance(block_info, np.ndarray): return None with self._lock: was_closed = self.closed @@ -953,6 +1002,7 @@ def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray: self.open() try: ncoords = len(self._coord_shape) + block_id = block_info[None]["chunk-location"] idx = self._seq_index_from_coords(block_id[:ncoords]) if idx == self._NO_IDX: @@ -962,6 +1012,11 @@ def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray: ) idx = 0 data = self.read_frame(int(idx)) # type: ignore + slices = tuple( + slice(al[0], al[1]) + for al in block_info[None]["array-location"][ncoords:] + ) + data = data[slices] data = data.copy() if copy else data return data[(np.newaxis,) * ncoords] finally: @@ -1207,7 +1262,10 @@ def binary_data(self) -> BinaryLayers | None: return 
self._rdr.binary_data() def ome_metadata( - self, *, include_unstructured: bool = True, tiff_file_name: str | None = None + self, + *, + include_unstructured: bool = True, + tiff_file_name: str | None = None, ) -> OME: """Return `ome_types.OME` metadata object for this file. diff --git a/tests/test_reader.py b/tests/test_reader.py index 6b3d6e5..90d08a7 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -52,6 +52,44 @@ def test_dask_closed(single_nd2): assert isinstance(dsk.compute(), np.ndarray) +PASSING_FRAME_CHUNKS = [ + (None, ((2,), (32,), (32,))), + (2, ((2,), (2,) * 16, (2,) * 16)), + (3, ((2,), (3,) * 10 + (2,), (3,) * 10 + (2,))), + ((3, 17, 33), ((2,), (17, 15), (32,))), + ((2, 16, 16), ((2,), (16, 16), (16, 16))), + ( + ((1, 1), (8, 8, 8, 8), (16, 16)), + ((1, 1), (8, 8, 8, 8), (16, 16)), + ), + (((2,), (20, 12), (32,)), ((2,), (20, 12), (32,))), +] + + +@pytest.mark.parametrize("frame_chunks, expected_chunks", PASSING_FRAME_CHUNKS) +def test_dask_chunking(single_nd2, frame_chunks, expected_chunks): + with ND2File(single_nd2) as nd: + dsk = nd.to_dask(frame_chunks=frame_chunks) + assert len(dsk.chunks) == 4 + assert dsk.chunks[1:] == expected_chunks + unchunked = nd.to_dask() + assert (dsk.compute() == unchunked.compute()).all() + + +FAILING_FRAME_CHUNKS = [ + (2, 3), + (2, (16, 16), (16, 16)), + ((1, 1, 1), (16, 16), (32,)), +] + + +@pytest.mark.parametrize("frame_chunks", FAILING_FRAME_CHUNKS) +def test_value_error_dask_chunking(single_nd2, frame_chunks): + with ND2File(single_nd2) as nd: + with pytest.raises(ValueError): + nd.to_dask(frame_chunks=frame_chunks) + + def test_full_read(new_nd2): pytest.importorskip("xarray") with ND2File(new_nd2) as nd: