Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xarray Dataset Support #1490

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements/datasets.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ pyvista==0.41.1
radiant-mlhub==0.4.1
rarfile==4.0
scikit-image==0.21.0
xarray==2023.7.0
rioxarray==0.14.1
scipy==1.11.2
zipfile-deflate64==0.2.0
1 change: 1 addition & 0 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ radiant-mlhub==0.3.0
rarfile==4.0
scikit-image==0.18.0
scipy==1.6.2
xarray
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need to determine the minimum version that works before merging

zipfile-deflate64==0.2.0

# docs
Expand Down
65 changes: 65 additions & 0 deletions tests/data/rioxr/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil

import cftime
import numpy as np
import pandas as pd
import xarray as xr

SIZE = 32

LATS: list[tuple[float]] = [(40, 42), (60, 62), (80, 82)]

LONS: list[tuple[float]] = [(-55, -50), (-60, -55), (-85, 80)]

VAR_NAMES = ["zos", "tos"]

DIR = "data"

CF_TIME = [True, False, True]

NUM_TIME_STEPS = 3


def create_rioxr_dataset(
lat_min: float,
lat_max: float,
lon_min: float,
lon_max: float,
cf_time: bool,
var_name: str,
filename: str,
):
# Generate x and y coordinates
lats = np.linspace(lat_min, lat_max, SIZE)
lons = np.linspace(lon_min, lon_max, SIZE)

if cf_time:
times = [cftime.datetime(2000, 1, i + 1) for i in range(NUM_TIME_STEPS)]
else:
times = pd.date_range(start="2000-01-01", periods=NUM_TIME_STEPS, freq="D")

# data with shape (time, x, y)
data = np.random.rand(len(times), len(lats), len(lons))

# Create the xarray dataset
ds = xr.Dataset(
data_vars={var_name: (("time", "x", "y"), data)},
coords={"x": lats, "y": lons, "time": times},
)
ds.to_netcdf(path=filename)


if __name__ == "__main__":
if os.path.isdir(DIR):
shutil.rmtree(DIR)
os.makedirs(DIR)
for var_name in VAR_NAMES:
for lats, lons, cf_time in zip(LATS, LONS, CF_TIME):
path = os.path.join(DIR, f"{var_name}_{lats}_{lons}.nc")
create_rioxr_dataset(
lats[0], lats[1], lons[0], lons[1], cf_time, var_name, path
)
Binary file added tests/data/rioxr/data/tos_(40, 42)_(-55, -50).nc
Binary file not shown.
Binary file added tests/data/rioxr/data/tos_(60, 62)_(-60, -55).nc
Binary file not shown.
Binary file added tests/data/rioxr/data/tos_(80, 82)_(-85, 80).nc
Binary file not shown.
Binary file added tests/data/rioxr/data/zos_(40, 42)_(-55, -50).nc
Binary file not shown.
Binary file added tests/data/rioxr/data/zos_(60, 62)_(-60, -55).nc
Binary file not shown.
Binary file added tests/data/rioxr/data/zos_(80, 82)_(-85, 80).nc
Binary file not shown.
43 changes: 43 additions & 0 deletions tests/datasets/test_rioxr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
import torch

from torchgeo.datasets import (
BoundingBox,
IntersectionDataset,
RioXarrayDataset,
UnionDataset,
)

pytest.importorskip("rioxarray")


class TestRioXarrayDataset:
@pytest.fixture(scope="class")
def dataset(self) -> RioXarrayDataset:
root = os.path.join("tests", "data", "rioxr", "data")
return RioXarrayDataset(root=root, data_variables=["zos", "tos"])

def test_getitem(self, dataset: RioXarrayDataset) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)

def test_and(self, dataset: RioXarrayDataset) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: RioXarrayDataset) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_invalid_query(self, dataset: RioXarrayDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from .potsdam import Potsdam2D
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rioxr import RioXarrayDataset
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel1, Sentinel2
Expand Down Expand Up @@ -231,6 +232,7 @@
"NonGeoClassificationDataset",
"NonGeoDataset",
"RasterDataset",
"RioXarrayDataset",
"UnionDataset",
"VectorDataset",
# Utilities
Expand Down
233 changes: 233 additions & 0 deletions torchgeo/datasets/rioxr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""In-memory geographical xarray.DataArray and xarray.Dataset."""

import glob
import os
import re
import sys
from datetime import datetime
from typing import Any, Callable, Optional, cast

import numpy as np
import torch
import xarray as xr
from rasterio.crs import CRS
from rioxarray.merge import merge_arrays
from rtree.index import Index, Property

from .geo import GeoDataset
from .utils import BoundingBox


class RioXarrayDataset(GeoDataset):
"""Wrapper for geographical datasets stored as Xarray Datasets.

Relies on rioxarray.
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"""

filename_glob = "*"
filename_regex = ".*"

is_image = True

@property
def dtype(self) -> torch.dtype:
"""The dtype of the dataset (overrides the dtype of the data file via a cast).

Returns:
the dtype of the dataset

.. versionadded:: 5.0
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.is_image:
return torch.float32
else:
return torch.long

def __init__(
self,
root: str,
data_variables: Optional[list[str]] = None,
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
) -> None:
"""Initialize a new Dataset instance.

Args:
root: directory with files to be opened with xarray
data_variables: data variables that should be gathered from the collection
of xarray datasets
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of dataarray)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the dataarray)
transforms: a function/transform that takes an input sample
and returns a transformed version

Raises:
FileNotFoundError: if no files are found in ``root``
"""
super().__init__(transforms)

self.root = root

if data_variables:
self.data_variables = data_variables
else:
data_variables_to_collect: list[str] = []

self.transforms = transforms

# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))

# Populate the dataset index
i = 0
pathname = os.path.join(root, self.filename_glob)
filename_regex = re.compile(self.filename_regex, re.VERBOSE)
for filepath in glob.iglob(pathname, recursive=True):
match = re.match(filename_regex, os.path.basename(filepath))
if match is not None:
with xr.open_dataset(filepath, decode_times=True) as ds:
# rioxarray expects spatial dimensions to be named x and y
x_name, y_name = self._infer_spatial_coordinate_names(ds)
ds = ds.rename({x_name: "x", y_name: "y"})

if crs is None:
crs = ds.rio.crs

if res is None:
res = ds.rio.resolution()[0]

(minx, miny, maxx, maxy) = ds.rio.bounds()

if hasattr(ds, "time"):
try:
indices = ds.indexes["time"].to_datetimeindex()
except AttributeError:
indices = ds.indexes["time"]

mint = indices.min().to_pydatetime().timestamp()
maxt = indices.max().to_pydatetime().timestamp()
else:
mint = 0
maxt = sys.maxsize
coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(i, coords, filepath)
i += 1

# collect all possible data variables if self.data_variables is None
if not data_variables:
data_variables_to_collect.extend(list(ds.data_vars))

if i == 0:
msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
raise FileNotFoundError(msg)

if not data_variables:
self.data_variables = list(set(data_variables_to_collect))

if not crs:
self._crs = "EPSG:4326"
else:
self._crs = cast(CRS, crs)
self.res = cast(float, res)

def _infer_spatial_coordinate_names(self, ds):
"""Infer the names of the spatial coordinates.

Args:
ds: Dataset or DataArray of which to infer the spatial coordinates

Returns:
x and y coordinate names
"""
for dim_name, dim in ds.coords.items():
if hasattr(dim, "units"):
if any(
[x in dim.units.lower() for x in ["degrees_north", "degree_north"]]
):
y_name = dim_name
elif any(
[x in dim.units.lower() for x in ["degrees_east", "degree_east"]]
):
x_name = dim_name

if not x_name or not y_name:
raise ValueError("Spatial Coordinate Units not found in Dataset")

return x_name, y_name

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.

Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

Returns:
sample of image/mask and metadata at that index

Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
items = [hit.object for hit in hits]

if not items:
raise IndexError(
f"query: {query} not found in index with bounds: {self.bounds}"
)

data_arrays: list["np.typing.NDArray"] = []
for item in items:
with xr.open_dataset(item, decode_cf=True) as ds:
# rioxarray expects spatial dimensions to be named x and y
x_name, y_name = self._infer_spatial_coordinate_names(ds)
ds = ds.rename({x_name: "x", y_name: "y"})

if not ds.rio.crs:
ds.rio.write_crs(self._crs, inplace=True)
elif ds.rio.crs != self._crs:
ds = ds.rio.reproject(self._crs)

# clip box ignores time dimension
clipped = ds.rio.clip_box(
minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
)
# select time dimension
if hasattr(ds, "time"):
try:
clipped["time"] = clipped.indexes["time"].to_datetimeindex()
except AttributeError:
clipped["time"] = clipped.indexes["time"]
clipped = clipped.sel(
time=slice(
datetime.fromtimestamp(query.mint),
datetime.fromtimestamp(query.maxt),
)
)

for variable in self.data_variables:
if hasattr(clipped, variable):
data_arrays.append(clipped[variable].squeeze())

merged_data = torch.from_numpy(
merge_arrays(
data_arrays, bounds=(query.minx, query.miny, query.maxx, query.maxy)
).data
)
sample = {"crs": self.crs, "bbox": query}

merged_data = merged_data.to(self.dtype)
if self.is_image:
sample["image"] = merged_data
else:
sample["mask"] = merged_data

if self.transforms is not None:
sample = self.transforms(sample)

return sample
Loading