diff --git a/pyproject.toml b/pyproject.toml index 59982abe5ea..6c5b892b468 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -243,6 +243,8 @@ filterwarnings = [ "ignore:ANTIALIAS is deprecated and will be removed in Pillow 10:DeprecationWarning:tensorboardX.summary", # https://github.com/Lightning-AI/lightning/issues/16756 "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", + # https://github.com/pydata/xarray/issues/7259 + "ignore: numpy.ndarray size changed, may indicate binary incompatibility. Expected 16 from C header, got 96 from PyObject", "ignore:pkg_resources is deprecated as an API.:DeprecationWarning:lightning_utilities.core.imports", "ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning:jsonargparse", # https://github.com/pytorch/pytorch/issues/110549 diff --git a/requirements/datasets.txt b/requirements/datasets.txt index 15f710b0126..35c9a8f5158 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -7,5 +7,9 @@ pyvista==0.42.3 radiant-mlhub==0.4.1 rarfile==4.1 scikit-image==0.22.0 +xarray==2023.7.0 +rioxarray==0.14.1 +xarray +netCDF4 scipy==1.11.3 zipfile-deflate64==0.2.0 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index c19e625314b..b7e79a4801a 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -31,6 +31,7 @@ radiant-mlhub==0.3.0 rarfile==4.0 scikit-image==0.18.0 scipy==1.6.2 +xarray zipfile-deflate64==0.2.0 # docs diff --git a/tests/data/rioxr/data.py b/tests/data/rioxr/data.py new file mode 100644 index 00000000000..44cf63446db --- /dev/null +++ b/tests/data/rioxr/data.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil + +import cftime +import numpy as np +import pandas as pd +import xarray as xr + +SIZE = 32 + +LATS: list[tuple[float]] = [(40, 42), (60, 62), (80, 82)] + +LONS: list[tuple[float]] = [(-55, -50), (-5, 5), (80, 85)] + +VAR_NAMES = ["zos", "tos"] + +DIR = "data" + +CF_TIME = [True, False, True] + +NUM_TIME_STEPS = 3 + + +def create_rioxr_dataset( + lat_min: float, + lat_max: float, + lon_min: float, + lon_max: float, + cf_time: bool, + var_name: str, + filename: str, +): + # Generate x and y coordinates + lats = np.linspace(lat_min, lat_max, SIZE) + lons = np.linspace(lon_min, lon_max, SIZE) + + if cf_time: + times = [cftime.datetime(2000, 1, i + 1) for i in range(NUM_TIME_STEPS)] + else: + times = pd.date_range(start="2000-01-01", periods=NUM_TIME_STEPS, freq="D") + + # data with shape (time, x, y) + data = np.random.rand(len(times), len(lons), len(lats)) + + # Create the xarray dataset + ds = xr.Dataset( + data_vars={var_name: (("time", "x", "y"), data)}, + coords={"x": lons, "y": lats, "time": times}, + ) + ds["x"].attrs["units"] = "degrees_east" + ds["x"].attrs["crs"] = "EPSG:4326" + ds["y"].attrs["units"] = "degrees_north" + ds["y"].attrs["crs"] = "EPSG:4326" + ds.to_netcdf(path=filename) + + +if __name__ == "__main__": + if os.path.isdir(DIR): + shutil.rmtree(DIR) + os.makedirs(DIR) + for var_name in VAR_NAMES: + for lats, lons, cf_time in zip(LATS, LONS, CF_TIME): + path = os.path.join(DIR, f"{var_name}_{lats}_{lons}.nc") + create_rioxr_dataset( + lats[0], lats[1], lons[0], lons[1], cf_time, var_name, path + ) diff --git a/tests/data/rioxr/data/tos_(40, 42)_(-55, -50).nc b/tests/data/rioxr/data/tos_(40, 42)_(-55, -50).nc new file mode 100644 index 00000000000..33a9ccc92ac Binary files /dev/null and b/tests/data/rioxr/data/tos_(40, 42)_(-55, -50).nc differ diff --git a/tests/data/rioxr/data/tos_(60, 62)_(-5, 5).nc b/tests/data/rioxr/data/tos_(60, 62)_(-5, 5).nc new file mode 100644 index 00000000000..735f07d5039 Binary files /dev/null and b/tests/data/rioxr/data/tos_(60, 62)_(-5, 5).nc differ diff --git a/tests/data/rioxr/data/tos_(80, 82)_(80, 85).nc b/tests/data/rioxr/data/tos_(80, 82)_(80, 85).nc new file mode 100644 index 00000000000..451b43cb0ee Binary files /dev/null and b/tests/data/rioxr/data/tos_(80, 82)_(80, 85).nc differ diff --git a/tests/data/rioxr/data/zos_(40, 42)_(-55, -50).nc b/tests/data/rioxr/data/zos_(40, 42)_(-55, -50).nc new file mode 100644 index 00000000000..ee835c65220 Binary files /dev/null and b/tests/data/rioxr/data/zos_(40, 42)_(-55, -50).nc differ diff --git a/tests/data/rioxr/data/zos_(60, 62)_(-5, 5).nc b/tests/data/rioxr/data/zos_(60, 62)_(-5, 5).nc new file mode 100644 index 00000000000..9d38556b595 Binary files /dev/null and b/tests/data/rioxr/data/zos_(60, 62)_(-5, 5).nc differ diff --git a/tests/data/rioxr/data/zos_(80, 82)_(80, 85).nc b/tests/data/rioxr/data/zos_(80, 82)_(80, 85).nc new file mode 100644 index 00000000000..9f40b482f36 Binary files /dev/null and b/tests/data/rioxr/data/zos_(80, 82)_(80, 85).nc differ diff --git a/tests/datasets/test_rioxr.py b/tests/datasets/test_rioxr.py new file mode 100644 index 00000000000..5b547ea8306 --- /dev/null +++ b/tests/datasets/test_rioxr.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +import torch + +from torchgeo.datasets import ( + BoundingBox, + IntersectionDataset, + RioXarrayDataset, + UnionDataset, +) + +pytest.importorskip("rioxarray") + + +class TestRioXarrayDataset: + @pytest.fixture(scope="class") + def dataset(self) -> RioXarrayDataset: + root = os.path.join("tests", "data", "rioxr", "data") + return RioXarrayDataset(root=root, data_variables=["zos", "tos"]) + + def test_getitem(self, dataset: RioXarrayDataset) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + + def test_and(self, dataset: RioXarrayDataset) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: RioXarrayDataset) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_invalid_query(self, dataset: RioXarrayDataset) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* not found in index with bounds:" + ): + dataset[query] diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index c4a0db0caa8..0a2261aab3b 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -85,6 +85,7 @@ from .potsdam import Potsdam2D from .reforestree import ReforesTree from .resisc45 import RESISC45 +from .rioxr import RioXarrayDataset from .rwanda_field_boundary import RwandaFieldBoundary from .seasonet import SeasoNet from .seco import SeasonalContrastS2 @@ -240,6 +241,7 @@ "NonGeoClassificationDataset", "NonGeoDataset", "RasterDataset", + "RioXarrayDataset", "UnionDataset", "VectorDataset", # Utilities diff --git a/torchgeo/datasets/rioxr.py b/torchgeo/datasets/rioxr.py new file mode 100644 index 00000000000..d28c1481efb --- /dev/null +++ b/torchgeo/datasets/rioxr.py @@ -0,0 +1,280 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""In-memory geographical xarray.DataArray and xarray.Dataset.""" + +import glob +import os +import re +import sys +from datetime import datetime +from typing import Any, Callable, Optional, cast + +import numpy as np +import torch +import xarray as xr +from rasterio.crs import CRS +from rioxarray.merge import merge_arrays +from rtree.index import Index, Property + +from .geo import GeoDataset +from .utils import BoundingBox + + +class RioXarrayDataset(GeoDataset): + """Wrapper for geographical datasets stored as Xarray Datasets. + + Relies on rioxarray. + + .. versionadded:: 5.0 + """ + + filename_glob = "*" + filename_regex = ".*" + + is_image = True + + spatial_x_name = "x" + spatial_y_name = "y" + + transform = None + + @property + def dtype(self) -> torch.dtype: + """The dtype of the dataset (overrides the dtype of the data file via a cast). + + Returns: + the dtype of the dataset + """ + if self.is_image: + return torch.float32 + else: + return torch.long + + def harmonize_format(self, ds): + """Convert the dataset to the standard format. + + Args: + ds: dataset or array to harmonize + + Returns: + the harmonized dataset or array + """ + # rioxarray expects spatial dimensions to be named x and y + ds.rio.set_spatial_dims(self.spatial_x_name, self.spatial_y_name, inplace=True) + + # if x coords go from 0 to 360, convert to -180 to 180 + if ds[self.spatial_x_name].min() > 180: + ds = ds.assign_coords( + {self.spatial_x_name: ds[self.spatial_x_name] % 360 - 180} + ) + + # if y coords go from 0 to 180, convert to -90 to 90 + if ds[self.spatial_x_name].min() > 90: + ds = ds.assign_coords( + {self.spatial_y_name: ds[self.spatial_y_name] % 180 - 90} + ) + # expect asceding coordinate values + ds = ds.sortby(self.spatial_x_name, ascending=True) + ds = ds.sortby(self.spatial_y_name, ascending=True) + return ds + + def __init__( + self, + root: str, + data_variables: Optional[list[str]] = None, + crs: Optional[CRS] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + ) -> None: + """Initialize a new Dataset instance. + + Args: + root: directory with files to be opened with xarray + data_variables: data variables that should be gathered from the collection + of xarray datasets + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of dataarray) + transforms: a function/transform that takes an input sample + and returns a transformed version + + Raises: + FileNotFoundError: if no files are found in ``root`` + """ + super().__init__(transforms) + + self.root = root + + if data_variables: + self.data_variables = data_variables + else: + data_variables_to_collect: list[str] = [] + + self.transforms = transforms + + # Create an R-tree to index the dataset + self.index = Index(interleaved=False, properties=Property(dimension=3)) + + # Populate the dataset index + i = 0 + pathname = os.path.join(root, self.filename_glob) + filename_regex = re.compile(self.filename_regex, re.VERBOSE) + for filepath in glob.iglob(pathname, recursive=True): + match = re.match(filename_regex, os.path.basename(filepath)) + if match is not None: + with xr.open_dataset(filepath, decode_times=True) as ds: + ds = self.harmonize_format(ds) + + if crs is None: + crs = ds.rio.crs + + try: + (minx, miny, maxx, maxy) = ds.rio.bounds() + except AttributeError: + # or take the shape of the data variable? + continue + + if hasattr(ds, "time"): + try: + indices = ds.indexes["time"].to_datetimeindex() + except AttributeError: + indices = ds.indexes["time"] + + mint = indices.min().to_pydatetime().timestamp() + maxt = indices.max().to_pydatetime().timestamp() + else: + mint = 0 + maxt = sys.maxsize + coords = (minx, maxx, miny, maxy, mint, maxt) + self.index.insert(i, coords, filepath) + i += 1 + + # collect all possible data variables if self.data_variables is None + if not data_variables: + data_variables_to_collect.extend(list(ds.data_vars)) + + if i == 0: + import pdb + + pdb.set_trace() + msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`" + raise FileNotFoundError(msg) + + if not data_variables: + self.data_variables = list(set(data_variables_to_collect)) + + if not crs: + self._crs = "EPSG:4326" + else: + self._crs = cast(CRS, crs) + self.res = 1.0 + + def _infer_spatial_coordinate_names(self, ds) -> tuple[str]: + """Infer the names of the spatial coordinates. + + Args: + ds: Dataset or DataArray of which to infer the spatial coordinates + + Returns: + x and y coordinate names + """ + x_name = None + y_name = None + for coord_name, coord in ds.coords.items(): + if hasattr(coord, "units"): + if any( + [ + x in coord.units.lower() + for x in ["degrees_north", "degree_north"] + ] + ): + y_name = coord_name + elif any( + [x in coord.units.lower() for x in ["degrees_east", "degree_east"]] + ): + x_name = coord_name + + if not x_name or not y_name: + raise ValueError("Spatial Coordinate Units not found in Dataset.") + + return x_name, y_name + + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + """Retrieve image/mask and metadata indexed by query. + + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of image/mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + hits = self.index.intersection(tuple(query), objects=True) + items = [hit.object for hit in hits] + + if not items: + raise IndexError( + f"query: {query} not found in index with bounds: {self.bounds}" + ) + + data_arrays: list["np.typing.NDArray"] = [] + for item in items: + with xr.open_dataset(item, decode_cf=True) as ds: + ds = self.harmonize_format(ds) + # select time dimension + if hasattr(ds, "time"): + try: + ds["time"] = ds.indexes["time"].to_datetimeindex() + except AttributeError: + ds["time"] = ds.indexes["time"] + ds = ds.sel( + time=slice( + datetime.fromtimestamp(query.mint), + datetime.fromtimestamp(query.maxt), + ) + ) + + for variable in self.data_variables: + if hasattr(ds, variable): + da = ds[variable] + if not da.rio.crs: + da.rio.write_crs(self._crs, inplace=True) + elif da.rio.crs != self._crs: + da = da.rio.reproject(self._crs) + # clip box ignores time dimension + clipped = da.rio.clip_box( + minx=query.minx, + miny=query.miny, + maxx=query.maxx, + maxy=query.maxy, + ) + # rioxarray expects this order + clipped = clipped.transpose( + "time", self.spatial_y_name, self.spatial_x_name, ... + ) + + # set proper transform # TODO not working + clipped.rio.write_transform(self.transform) + data_arrays.append(clipped.squeeze()) + + import pdb + + pdb.set_trace() + merged_data = torch.from_numpy( + merge_arrays( + data_arrays, bounds=(query.minx, query.miny, query.maxx, query.maxy) + ).data + ) + sample = {"crs": self.crs, "bbox": query} + + merged_data = merged_data.to(self.dtype) + if self.is_image: + sample["image"] = merged_data + else: + sample["mask"] = merged_data + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample