From 3fe232ebf3f912de21aaf3c02f1cec5f3147fc5a Mon Sep 17 00:00:00 2001 From: Niklas Netter Date: Fri, 27 Sep 2024 15:38:00 +0200 Subject: [PATCH 1/3] dask chunking on frame level implemented --- src/nd2/nd2file.py | 68 ++++++++++++++++++++++++++++++++++++++++---- tests/test_reader.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/src/nd2/nd2file.py b/src/nd2/nd2file.py index c432aac..c46bb74 100644 --- a/src/nd2/nd2file.py +++ b/src/nd2/nd2file.py @@ -888,7 +888,12 @@ def write_tiff( modify_ome=modify_ome, ) - def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Array: + def to_dask( + self, + wrapper: bool = True, + copy: bool = True, + frame_chunks: int | tuple | None = None, + ) -> dask.array.core.Array: """Create dask array (delayed reader) representing image. This generally works well, but it remains to be seen whether performance @@ -913,6 +918,11 @@ def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Ar If `True` (the default), the dask chunk-reading function will return an array copy. This can avoid segfaults in certain cases, though it may also add overhead. + frame_chunks : tuple | int | None + If `None` (the default), the file will not be chunked on the frame level. + Otherwise expects the dask compatible chunks to chunk the frames along + channel, y, and x axis. If a tuple, must have same length as + `self._frame_shape`. Returns ------- @@ -922,7 +932,46 @@ def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Ar from dask.array.core import map_blocks chunks = [(1,) * x for x in self._coord_shape] - chunks += [(x,) for x in self._frame_shape] + if frame_chunks is None: + chunks += [(x,) for x in self._frame_shape] + elif isinstance(frame_chunks, int): + for frame_len in self._frame_shape: + div = frame_len // frame_chunks + if div == 0: + chunks.append((frame_len,)) + else: + _chunks = (frame_chunks,) * div + if frame_len % frame_chunks != 0: + _chunks += (frame_len - div * frame_chunks,) + chunks.append(_chunks) + elif len(frame_chunks) != len(self._frame_shape): + raise ValueError( + f"frame_chunks must be of length {len(self._frame_shape)}." + ) + elif isinstance(frame_chunks[0], int): + if not all(isinstance(frame_chunk, int) for frame_chunk in frame_chunks): + raise ValueError( + "frame_chunks must be a tuple of ints or tuple of tuple of ints." + ) + for frame_len, frame_chunk in zip(self._frame_shape, frame_chunks): + div = frame_len // frame_chunk + if div == 0: + chunks.append((frame_len,)) + else: + _chunks = (frame_chunk,) * div + if frame_len % frame_chunk != 0: + _chunks += (frame_len - div * frame_chunk,) + chunks.append(_chunks) + else: + if not all( + sum(frame_chunk) == frame_len + for frame_chunk, frame_len in zip(frame_chunks, self._frame_shape) + ): + raise ValueError( + "Sum of frame_chunks does not align with frame shape of file." + ) + chunks.extend(frame_chunks) + dask_arr = map_blocks( self._dask_block, copy=copy, @@ -944,8 +993,8 @@ def _seq_index_from_coords(self, coords: Sequence) -> Sequence[int] | SupportsIn return self._NO_IDX return np.ravel_multi_index(coords, self._coord_shape) # type: ignore - def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray: - if isinstance(block_id, np.ndarray): + def _dask_block(self, copy: bool, block_info: dict) -> np.ndarray: + if isinstance(block_info, np.ndarray): return None with self._lock: was_closed = self.closed @@ -953,6 +1002,7 @@ def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray: self.open() try: ncoords = len(self._coord_shape) + block_id = block_info[None]["chunk-location"] idx = self._seq_index_from_coords(block_id[:ncoords]) if idx == self._NO_IDX: @@ -962,6 +1012,11 @@ def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray: ) idx = 0 data = self.read_frame(int(idx)) # type: ignore + slices = tuple( + slice(al[0], al[1]) + for al in block_info[None]["array-location"][ncoords:] + ) + data = data[slices] data = data.copy() if copy else data return data[(np.newaxis,) * ncoords] finally: @@ -1207,7 +1262,10 @@ def binary_data(self) -> BinaryLayers | None: return self._rdr.binary_data() def ome_metadata( - self, *, include_unstructured: bool = True, tiff_file_name: str | None = None + self, + *, + include_unstructured: bool = True, + tiff_file_name: str | None = None, ) -> OME: """Return `ome_types.OME` metadata object for this file. diff --git a/tests/test_reader.py b/tests/test_reader.py index 6b3d6e5..4575b02 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -52,6 +52,49 @@ def test_dask_closed(single_nd2): assert isinstance(dsk.compute(), np.ndarray) +@pytest.fixture( + params=[ + (None, ((2,), (32,), (32,))), + (2, ((2,), (2,) * 16, (2,) * 16)), + (3, ((2,), (3,) * 10 + (2,), (3,) * 10 + (2,))), + ((3, 17, 33), ((2,), (17, 15), (32,))), + ((2, 16, 16), ((2,), (16, 16), (16, 16))), + (((1, 1), (8, 8, 8, 8), (16, 16)), ((1, 1), (8, 8, 8, 8), (16, 16))), + (((2,), (20, 12), (32,)), ((2,), (20, 12), (32,))), + ], + ids=lambda x: str(x), +) +def passing_frame_chunks(request): + return request.param + + +def test_dask_chunking(single_nd2, passing_frame_chunks): + with ND2File(single_nd2) as nd: + dsk = nd.to_dask(frame_chunks=passing_frame_chunks[0]) + assert len(dsk.chunks) == 4 + assert dsk.chunks[1:] == passing_frame_chunks[1] + unchunked = nd.to_dask() + assert (dsk.compute() == unchunked.compute()).all() + + +@pytest.fixture( + params=[ + (2, 3), + (2, (16, 16), (16, 16)), + ((1, 1, 1), (16, 16), (32,)), + ], + ids=lambda x: str(x), +) +def failing_frame_chunks(request): + return request.params + + +def test_value_error_dask_chunking(single_nd2, failing_frame_chunks): + with ND2File(single_nd2) as nd: + with pytest.raises(ValueError): + nd.to_dask(frame_chunks=passing_frame_chunks) + + def test_full_read(new_nd2): pytest.importorskip("xarray") with ND2File(new_nd2) as nd: From 7fc4787529376f969fd1595c9d9ff0ecb1d803ac Mon Sep 17 00:00:00 2001 From: Niklas Netter Date: Fri, 27 Sep 2024 15:54:07 +0200 Subject: [PATCH 2/3] fix: test_value_error_dask_chunking called wrong fixture --- tests/test_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_reader.py b/tests/test_reader.py index 4575b02..a1f1cbd 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -86,13 +86,13 @@ def test_dask_chunking(single_nd2, passing_frame_chunks): ids=lambda x: str(x), ) def failing_frame_chunks(request): - return request.params + return request.param def test_value_error_dask_chunking(single_nd2, failing_frame_chunks): with ND2File(single_nd2) as nd: with pytest.raises(ValueError): - nd.to_dask(frame_chunks=passing_frame_chunks) + nd.to_dask(frame_chunks=failing_frame_chunks) def test_full_read(new_nd2): From 83b185aa293e644b0656a876fdc379b218aa8dcb Mon Sep 17 00:00:00 2001 From: Niklas Netter Date: Fri, 27 Sep 2024 16:30:00 +0200 Subject: [PATCH 3/3] test: simplify with @pytest.mark.parametrize --- tests/test_reader.py | 53 ++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/tests/test_reader.py b/tests/test_reader.py index a1f1cbd..90d08a7 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -52,47 +52,42 @@ def test_dask_closed(single_nd2): assert isinstance(dsk.compute(), np.ndarray) -@pytest.fixture( - params=[ - (None, ((2,), (32,), (32,))), - (2, ((2,), (2,) * 16, (2,) * 16)), - (3, ((2,), (3,) * 10 + (2,), (3,) * 10 + (2,))), - ((3, 17, 33), ((2,), (17, 15), (32,))), - ((2, 16, 16), ((2,), (16, 16), (16, 16))), - (((1, 1), (8, 8, 8, 8), (16, 16)), ((1, 1), (8, 8, 8, 8), (16, 16))), - (((2,), (20, 12), (32,)), ((2,), (20, 12), (32,))), - ], - ids=lambda x: str(x), -) -def passing_frame_chunks(request): - return request.param +PASSING_FRAME_CHUNKS = [ + (None, ((2,), (32,), (32,))), + (2, ((2,), (2,) * 16, (2,) * 16)), + (3, ((2,), (3,) * 10 + (2,), (3,) * 10 + (2,))), + ((3, 17, 33), ((2,), (17, 15), (32,))), + ((2, 16, 16), ((2,), (16, 16), (16, 16))), + ( + ((1, 1), (8, 8, 8, 8), (16, 16)), + ((1, 1), (8, 8, 8, 8), (16, 16)), + ), + (((2,), (20, 12), (32,)), ((2,), (20, 12), (32,))), +] -def test_dask_chunking(single_nd2, passing_frame_chunks): +@pytest.mark.parametrize("frame_chunks, expected_chunks", PASSING_FRAME_CHUNKS) +def test_dask_chunking(single_nd2, frame_chunks, expected_chunks): with ND2File(single_nd2) as nd: - dsk = nd.to_dask(frame_chunks=passing_frame_chunks[0]) + dsk = nd.to_dask(frame_chunks=frame_chunks) assert len(dsk.chunks) == 4 - assert dsk.chunks[1:] == passing_frame_chunks[1] + assert dsk.chunks[1:] == expected_chunks unchunked = nd.to_dask() assert (dsk.compute() == unchunked.compute()).all() -@pytest.fixture( - params=[ - (2, 3), - (2, (16, 16), (16, 16)), - ((1, 1, 1), (16, 16), (32,)), - ], - ids=lambda x: str(x), -) -def failing_frame_chunks(request): - return request.param +FAILING_FRAME_CHUNKS = [ + (2, 3), + (2, (16, 16), (16, 16)), + ((1, 1, 1), (16, 16), (32,)), +] -def test_value_error_dask_chunking(single_nd2, failing_frame_chunks): +@pytest.mark.parametrize("frame_chunks", FAILING_FRAME_CHUNKS) +def test_value_error_dask_chunking(single_nd2, frame_chunks): with ND2File(single_nd2) as nd: with pytest.raises(ValueError): - nd.to_dask(frame_chunks=failing_frame_chunks) + nd.to_dask(frame_chunks=frame_chunks) def test_full_read(new_nd2):