diff --git a/satpy/readers/netcdf_utils.py b/satpy/readers/netcdf_utils.py index c8b8a3f85f..59bf1829d2 100644 --- a/satpy/readers/netcdf_utils.py +++ b/satpy/readers/netcdf_utils.py @@ -17,16 +17,17 @@ # satpy. If not, see . """Helpers for reading netcdf-based files.""" +import functools import logging +import warnings -import dask.array as da import netCDF4 import numpy as np import xarray as xr from satpy.readers import open_file_or_filename from satpy.readers.file_handlers import BaseFileHandler -from satpy.readers.utils import np2str +from satpy.readers.utils import get_serializable_dask_array, np2str from satpy.utils import get_legacy_chunk_size LOG = logging.getLogger(__name__) @@ -85,10 +86,12 @@ class NetCDF4FileHandler(BaseFileHandler): xarray_kwargs (dict): Addition arguments to `xarray.open_dataset` cache_var_size (int): Cache variables smaller than this size. cache_handle (bool): Keep files open for lifetime of filehandler. + Uses xarray.backends.CachingFileManager, which uses a least + recently used cache. """ - file_handle = None + manager = None def __init__(self, filename, filename_info, filetype_info, auto_maskandscale=False, xarray_kwargs=None, @@ -99,14 +102,22 @@ def __init__(self, filename, filename_info, filetype_info, self.file_content = {} self.cached_file_content = {} self._use_h5netcdf = False - try: - file_handle = self._get_file_handle() - except IOError: - LOG.exception( - "Failed reading file %s. Possibly corrupted file", self.filename) - raise + self._auto_maskandscale = auto_maskandscale + if cache_handle: + self.manager = xr.backends.CachingFileManager( + functools.partial(_nc_dataset_wrapper, + auto_maskandscale=auto_maskandscale), + self.filename, mode="r") + file_handle = self.manager.acquire() + else: + try: + file_handle = self._get_file_handle() + except IOError: + LOG.exception( + "Failed reading file %s. 
Possibly corrupted file", self.filename) + raise - self._set_file_handle_auto_maskandscale(file_handle, auto_maskandscale) + self._set_file_handle_auto_maskandscale(file_handle, auto_maskandscale) self._set_xarray_kwargs(xarray_kwargs, auto_maskandscale) listed_variables = filetype_info.get("required_netcdf_variables") @@ -117,14 +128,22 @@ def __init__(self, filename, filename_info, filetype_info, self.collect_dimensions("", file_handle) self.collect_cache_vars(cache_var_size) - if cache_handle: - self.file_handle = file_handle - else: + if not cache_handle: file_handle.close() def _get_file_handle(self): return netCDF4.Dataset(self.filename, "r") + @property + def file_handle(self): + """Backward-compatible way for file handle caching.""" + warnings.warn( + "attribute .file_handle is deprecated, use .manager instead", + DeprecationWarning) + if self.manager is None: + return None + return self.manager.acquire() + @staticmethod def _set_file_handle_auto_maskandscale(file_handle, auto_maskandscale): if hasattr(file_handle, "set_auto_maskandscale"): @@ -196,11 +215,8 @@ def _get_required_variable_names(listed_variables, variable_name_replacements): def __del__(self): """Delete the file handler.""" - if self.file_handle is not None: - try: - self.file_handle.close() - except RuntimeError: # presumably closed already - pass + if self.manager is not None: + self.manager.close() def _collect_global_attrs(self, obj): """Collect all the global attributes for the provided file object.""" @@ -289,8 +305,8 @@ def _get_variable(self, key, val): group, key = parts else: group = None - if self.file_handle is not None: - val = self._get_var_from_filehandle(group, key) + if self.manager is not None: + val = self._get_var_from_manager(group, key) else: val = self._get_var_from_xr(group, key) return val @@ -319,18 +335,27 @@ def _get_var_from_xr(self, group, key): val.load() return val - def _get_var_from_filehandle(self, group, key): + def _get_var_from_manager(self, group, key): 
# Not getting coordinates as this is more work, therefore more # overhead, and those are not used downstream. + + with self.manager.acquire_context() as ds: + if group is not None: + v = ds[group][key] + else: + v = ds[key] if group is None: - g = self.file_handle + dv = get_serializable_dask_array( + self.manager, key, + chunks=v.shape, dtype=v.dtype) else: - g = self.file_handle[group] - v = g[key] + dv = get_serializable_dask_array( + self.manager, "/".join([group, key]), + chunks=v.shape, dtype=v.dtype) attrs = self._get_object_attrs(v) x = xr.DataArray( - da.from_array(v), dims=v.dimensions, attrs=attrs, - name=v.name) + dv, + dims=v.dimensions, attrs=attrs, name=v.name) return x def __contains__(self, item): @@ -443,3 +468,15 @@ def _get_attr(self, obj, key): if self._use_h5netcdf: return obj.attrs[key] return super()._get_attr(obj, key) + +def _nc_dataset_wrapper(*args, auto_maskandscale, **kwargs): + """Wrap netCDF4.Dataset setting auto_maskandscale globally. + + Helper function that wraps netCDF4.Dataset while setting extra parameters. + By encapsulating this in a helper function, we can + pass it to CachingFileManager directly. Currently sets + auto_maskandscale globally (for all variables). + """ + nc = netCDF4.Dataset(*args, **kwargs) + nc.set_auto_maskandscale(auto_maskandscale) + return nc diff --git a/satpy/readers/utils.py b/satpy/readers/utils.py index 983225acd5..3a6f3dbe0b 100644 --- a/satpy/readers/utils.py +++ b/satpy/readers/utils.py @@ -29,6 +29,7 @@ from shutil import which from subprocess import PIPE, Popen # nosec +import dask.array as da import numpy as np import pyproj import xarray as xr @@ -497,6 +498,48 @@ def remove_earthsun_distance_correction(reflectance, utc_date=None): return reflectance +def get_serializable_dask_array(manager, varname, chunks, dtype): + """Construct a serializable dask array from a variable. 
+ + When we construct a dask array using da.from_array from a file, and use + that to create an xarray DataArray, the result is not serializable + and dask graphs using this DataArray cannot be computed when the dask + distributed scheduler is in use. To circumvent this problem, xarray + provides the CachingFileManager. See GH#2815 for more information. + + The variable should have at least one dimension. + + Example:: + + >>> import netCDF4 + >>> from xarray.backends import CachingFileManager + >>> cfm = CachingFileManager(netCDF4.Dataset, filename, mode="r") + >>> arr = get_serializable_dask_array(cfm, "my_var", 1024, "f4") + + Args: + manager (xarray.backends.CachingFileManager): + Instance of :class:`~xarray.backends.CachingFileManager` encapsulating the + dataset to be read. + varname (str): + Name of the variable (possibly including a group path). + chunks (tuple): + Chunks to use when creating the dask array. + dtype (dtype): + What dtype to use. + """ + def get_chunk(block_info=None): + arrloc = block_info[None]["array-location"] + with manager.acquire_context() as nc: + var = nc[varname] + return var[tuple(slice(*x) for x in arrloc)] + + return da.map_blocks( + get_chunk, + chunks=chunks, + dtype=dtype, + meta=np.array([], dtype=dtype)) + + class _CalibrationCoefficientParser: """Parse user-defined calibration coefficients.""" diff --git a/satpy/tests/reader_tests/test_netcdf_utils.py b/satpy/tests/reader_tests/test_netcdf_utils.py index 60d8be48de..4c994db8b5 100644 --- a/satpy/tests/reader_tests/test_netcdf_utils.py +++ b/satpy/tests/reader_tests/test_netcdf_utils.py @@ -18,7 +18,6 @@ """Module for testing the satpy.readers.netcdf_utils module.""" import os -import unittest import numpy as np import pytest @@ -71,13 +70,15 @@ def get_test_content(self, filename, filename_info, filetype_info): raise NotImplementedError("Fake File Handler subclass must implement 'get_test_content'") -class TestNetCDF4FileHandler(unittest.TestCase): +class TestNetCDF4FileHandler: """Test 
NetCDF4 File Handler Utility class.""" - def setUp(self): + @pytest.fixture() + def dummy_nc_file(self, tmp_path): """Create a test NetCDF4 file.""" from netCDF4 import Dataset - with Dataset("test.nc", "w") as nc: + fn = tmp_path / "test.nc" + with Dataset(fn, "w") as nc: # Create dimensions nc.createDimension("rows", 10) nc.createDimension("cols", 100) @@ -116,17 +117,14 @@ def setUp(self): d.test_attr_str = "test_string" d.test_attr_int = 0 d.test_attr_float = 1.2 + return fn - def tearDown(self): - """Remove the previously created test file.""" - os.remove("test.nc") - - def test_all_basic(self): + def test_all_basic(self, dummy_nc_file): """Test everything about the NetCDF4 class.""" import xarray as xr from satpy.readers.netcdf_utils import NetCDF4FileHandler - file_handler = NetCDF4FileHandler("test.nc", {}, {}) + file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {}) assert file_handler["/dimension/rows"] == 10 assert file_handler["/dimension/cols"] == 100 @@ -165,7 +163,7 @@ def test_all_basic(self): assert file_handler.file_handle is None assert file_handler["ds2_sc"] == 42 - def test_listed_variables(self): + def test_listed_variables(self, dummy_nc_file): """Test that only listed variables/attributes area collected.""" from satpy.readers.netcdf_utils import NetCDF4FileHandler @@ -175,12 +173,12 @@ def test_listed_variables(self): "attr/test_attr_str", ] } - file_handler = NetCDF4FileHandler("test.nc", {}, filetype_info) + file_handler = NetCDF4FileHandler(dummy_nc_file, {}, filetype_info) assert len(file_handler.file_content) == 2 assert "test_group/attr/test_attr_str" in file_handler.file_content assert "attr/test_attr_str" in file_handler.file_content - def test_listed_variables_with_composing(self): + def test_listed_variables_with_composing(self, dummy_nc_file): """Test that composing for listed variables is performed.""" from satpy.readers.netcdf_utils import NetCDF4FileHandler @@ -199,7 +197,7 @@ def test_listed_variables_with_composing(self): 
], } } - file_handler = NetCDF4FileHandler("test.nc", {}, filetype_info) + file_handler = NetCDF4FileHandler(dummy_nc_file, {}, filetype_info) assert len(file_handler.file_content) == 3 assert "test_group/ds1_f/attr/test_attr_str" in file_handler.file_content assert "test_group/ds1_i/attr/test_attr_str" in file_handler.file_content @@ -208,10 +206,10 @@ def test_listed_variables_with_composing(self): assert not any("another_parameter" in var for var in file_handler.file_content) assert "test_group/attr/test_attr_str" in file_handler.file_content - def test_caching(self): + def test_caching(self, dummy_nc_file): """Test that caching works as intended.""" from satpy.readers.netcdf_utils import NetCDF4FileHandler - h = NetCDF4FileHandler("test.nc", {}, {}, cache_var_size=1000, + h = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_var_size=1000, cache_handle=True) assert h.file_handle is not None assert h.file_handle.isopen() @@ -226,8 +224,6 @@ def test_caching(self): np.testing.assert_array_equal( h["ds2_f"], np.arange(10. 
* 100).reshape((10, 100))) - h.__del__() - assert not h.file_handle.isopen() def test_filenotfound(self): """Test that error is raised when file not found.""" @@ -237,21 +233,21 @@ def test_filenotfound(self): with pytest.raises(IOError, match=".*(No such file or directory|Unknown file format).*"): NetCDF4FileHandler("/thisfiledoesnotexist.nc", {}, {}) - def test_get_and_cache_npxr_is_xr(self): + def test_get_and_cache_npxr_is_xr(self, dummy_nc_file): """Test that get_and_cache_npxr() returns xr.DataArray.""" import xarray as xr from satpy.readers.netcdf_utils import NetCDF4FileHandler - file_handler = NetCDF4FileHandler("test.nc", {}, {}, cache_handle=True) + file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_handle=True) data = file_handler.get_and_cache_npxr("test_group/ds1_f") assert isinstance(data, xr.DataArray) - def test_get_and_cache_npxr_data_is_cached(self): + def test_get_and_cache_npxr_data_is_cached(self, dummy_nc_file): """Test that the data are cached when get_and_cache_npxr() is called.""" from satpy.readers.netcdf_utils import NetCDF4FileHandler - file_handler = NetCDF4FileHandler("test.nc", {}, {}, cache_handle=True) + file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_handle=True) data = file_handler.get_and_cache_npxr("test_group/ds1_f") # Delete the dataset from the file content dict, it should be available from the cache @@ -265,7 +261,6 @@ class TestNetCDF4FsspecFileHandler: def test_default_to_netcdf4_lib(self): """Test that the NetCDF4 backend is used by default.""" - import os import tempfile import h5py @@ -393,3 +388,40 @@ def test_get_data_as_xarray_scalar_h5netcdf(tmp_path): res = get_data_as_xarray(fid["test_data"]) np.testing.assert_equal(res.data, np.array(data)) assert res.attrs == NC_ATTRS + + +@pytest.fixture() +def dummy_nc(tmp_path): + """Fixture to create a dummy NetCDF file and return its path.""" + import xarray as xr + + fn = tmp_path / "sjaunja.nc" + ds = xr.Dataset(data_vars={"kaitum": (["x"], 
np.arange(10))}) + ds.to_netcdf(fn) + return fn + + +def test_caching_distributed(dummy_nc): + """Test that the distributed scheduler works with file handle caching. + + This is a test for GitHub issue 2815. + """ + from dask.distributed import Client + + from satpy.readers.netcdf_utils import NetCDF4FileHandler + + fh = NetCDF4FileHandler(dummy_nc, {}, {}, cache_handle=True) + + def doubler(x): + return x * 2 + + # As documented in GH issue 2815, using dask distributed with the file + # handle cacher might fail in non-trivial ways, such as giving incorrect + # results. Testing map_blocks is one way to reproduce the problem + # reliably, even though the problem also manifests itself (in different + # ways) without map_blocks. + + + with Client(): + dask_doubler = fh["kaitum"].map_blocks(doubler) + dask_doubler.compute() diff --git a/satpy/tests/reader_tests/test_utils.py b/satpy/tests/reader_tests/test_utils.py index b36a2b1d60..40a872db29 100644 --- a/satpy/tests/reader_tests/test_utils.py +++ b/satpy/tests/reader_tests/test_utils.py @@ -514,6 +514,64 @@ def test_generic_open_binary(tmp_path, data, filename, mode): assert read_binary_data == dummy_data +class TestDistributed: + """Distributed-related tests. + + Distributed-related tests are grouped so that they can share a class-scoped + fixture setting up the distributed client, as this setup is relatively + slow. 
+ """ + + @pytest.fixture(scope="class") + def dask_dist_client(self): + """Set up and close a dask distributed client.""" + from dask.distributed import Client + cl = Client() + yield cl + cl.close() + + + @pytest.mark.parametrize("shape", [(2,), (2, 3), (2, 3, 4)]) + @pytest.mark.parametrize("dtype", ["i4", "f4", "f8"]) + @pytest.mark.parametrize("grp", ["/", "/in/a/group"]) + def test_get_serializable_dask_array(self, tmp_path, dask_dist_client, shape, dtype, grp): + """Test getting a dask distributed friendly serialisable dask array.""" + import netCDF4 + from xarray.backends import CachingFileManager + + fn = tmp_path / "sjaunja.nc" + ds = xr.Dataset( + data_vars={ + "kaitum": (["x", "y", "z"][:len(shape)], + np.arange(np.prod(shape), + dtype=dtype).reshape(shape))}) + ds.to_netcdf(fn, group=grp) + + cfm = CachingFileManager(netCDF4.Dataset, fn, mode="r") + arr = hf.get_serializable_dask_array(cfm, "/".join([grp, "kaitum"]), + chunks=shape, dtype=dtype) + + # As documented in GH issue 2815, using dask distributed with the file + # handle cacher might fail in non-trivial ways, such as giving incorrect + # results. Testing map_blocks is one way to reproduce the problem + # reliably, even though the problem also manifests itself (in different + # ways) without map_blocks. + + def doubler(x): + # with a workaround for https://github.com/numpy/numpy/issues/27029 + return x * x.dtype.type(2) + + dask_doubler = arr.map_blocks(doubler, dtype=arr.dtype) + res = dask_doubler.compute() + # test before and after computation, as to confirm we have the correct + # shape and dtype and that computing doesn't change them + assert shape == dask_doubler.shape + assert shape == res.shape + assert dtype == dask_doubler.dtype + assert dtype == res.dtype + np.testing.assert_array_equal(res, np.arange(np.prod(shape)).reshape(shape)*2) + + class TestCalibrationCoefficientPicker: """Unit tests for calibration coefficient selection."""