diff --git a/opendap_protocol/protocol.py b/opendap_protocol/protocol.py index 7cc9614..657395c 100644 --- a/opendap_protocol/protocol.py +++ b/opendap_protocol/protocol.py @@ -40,19 +40,26 @@ clients using the netCDF4 library. PyDAP client libraries are not supported. """ +import importlib import re from dataclasses import dataclass +from dask.cache import Cache import dask.array as da import numpy as np +has_xarray = bool(importlib.util.find_spec("xarray")) + +if has_xarray: + from xarray import Variable + INDENT = ' ' SLICE_CONSTRAINT_RE = r'\[([\d,\W]+)\]$' @dataclass class Config: - DASK_ENCODE_CHUNK_SIZE: int = 20e6 + STREAMING_BLOCK_SIZE: int = 20e6 class DAPError(Exception): @@ -488,16 +495,21 @@ def dods_encode(data, dtype): yield packed_length + chunk_size = int(Config.STREAMING_BLOCK_SIZE / data.dtype.itemsize) if isinstance(data, da.Array): - # Encode in chunks of a defined size if we work with dask.Array - chunk_size = int(Config.DASK_ENCODE_CHUNK_SIZE / data.dtype.itemsize) - serialize_data = data.ravel().rechunk(chunk_size) - for block in serialize_data.blocks: - yield block.astype(dtype.str).compute().tobytes() - else: - # Make sure we always encode an array or we will get wrong results - data = np.asarray(data) - yield data.astype(dtype.str).tobytes() + data = data.ravel() + + for start in range(0, data.size, chunk_size): + end = start + chunk_size + if isinstance(data, da.Array): + block = data[slice(start, end)].compute() + elif has_xarray and isinstance(data, Variable): + npidxr = np.unravel_index(np.arange(start, min(end, data.size)), shape=data.shape) + xridxr = tuple(Variable(dims="__points__", data=idxr) for idxr in npidxr) + block = data[xridxr].to_numpy() + else: + block = np.asarray(data).ravel()[slice(start, end)] + yield block.astype(dtype.str).tobytes() def parse_slice_constraint(constraint): @@ -562,6 +574,6 @@ def set_dask_encoding_chunk_size(chunk_size: int): """ chunk_size = int(chunk_size) if chunk_size > 0: - Config.DASK_ENCODE_CHUNK_SIZE = chunk_size + Config.STREAMING_BLOCK_SIZE = chunk_size else: raise ValueError('Encoding chunk size needs to be greather than 0.') diff --git a/setup.py b/setup.py index 7105303..2585ff5 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ test_requirements = [ 'pytest', + 'xarray', ] extras = { diff --git a/tests/test_all.py b/tests/test_all.py index 726ad03..f4ac1a7 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -32,6 +32,7 @@ import dask.array as da import numpy as np import opendap_protocol as dap +import xarray as xr import pytest from opendap_protocol.protocol import dods_encode @@ -68,9 +69,13 @@ def test_dods_encode(): data_vals = da.from_array(np_data, chunks=(14, y_dim, 1, vertical_dim, 1, 1)) + variable = xr.Variable(dims=("x", "y", "time", "vertical", "real", "ref_time"), + data=np_data) + x = dap.dods_encode(data_vals, dap.Int32) y = dap.dods_encode(np_data, dap.Int32) - assert b''.join(x) == b''.join(y) + z = dap.dods_encode(variable, dap.Int32) + assert b''.join(x) == b''.join(y) == b''.join(z) int_arrdata = np.arange(0, 20, 2, dtype='