-
Notifications
You must be signed in to change notification settings - Fork 347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Xarray Dataset Support #1490
Open
nilsleh
wants to merge
20
commits into
microsoft:main
Choose a base branch
from
nilsleh:xarray_ds
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Xarray Dataset Support #1490
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
bf1cd30
first draft
nilsleh 9cb8efe
progress
nilsleh 064cad6
test with time-step indexing
nilsleh 28766a4
add basic unit tests
nilsleh 5ffd160
time step optional
nilsleh 1459b12
xarray as dependency
nilsleh 8f93e07
try to get tests with deps to work
nilsleh dc4cc8c
some requested changes
nilsleh c58972f
rioxarray merge arrays
nilsleh 6d0dc39
resolve conflicts
nilsleh 0b7e0b5
make data_variables optional and infer spatial coordinates automatically
nilsleh 27999f3
working tests locally
nilsleh ff49477
tests remote
nilsleh 1e774fb
merge main
nilsleh a06ab61
bounds and transform inconsistent error for synthetic data
nilsleh 702dfef
latest attempt
nilsleh 67954d7
netcdf4 req
nilsleh 5c99a2f
merge main
nilsleh 48cdbc8
store changes
nilsleh f025a84
store changes
nilsleh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
import shutil | ||
|
||
import cftime | ||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
|
||
# Edge length (in grid points) of each spatial dimension.
SIZE = 32

# (min, max) latitude extents for the three synthetic regions.
LATS: list[tuple[float, float]] = [(40, 42), (60, 62), (80, 82)]

# (min, max) longitude extents; paired elementwise with LATS.
# NOTE(review): (-85, 80) spans 165 degrees, unlike the other 5-degree
# ranges -- possibly a typo for (-85, -80); confirm intent.
LONS: list[tuple[float, float]] = [(-55, -50), (-60, -55), (-85, 80)]

# Data variable names written to the generated files.
VAR_NAMES = ["zos", "tos"]

# Output directory for the generated netCDF files.
DIR = "data"

# Whether each region uses a cftime-based (True) or pandas-based (False)
# time axis; paired elementwise with LATS/LONS.
CF_TIME = [True, False, True]

# Number of time steps in each generated file.
NUM_TIME_STEPS = 3
|
||
|
||
def create_rioxr_dataset(
    lat_min: float,
    lat_max: float,
    lon_min: float,
    lon_max: float,
    cf_time: bool,
    var_name: str,
    filename: str,
):
    """Write a synthetic netCDF file with one random data variable.

    Args:
        lat_min: smallest latitude of the region
        lat_max: largest latitude of the region
        lon_min: smallest longitude of the region
        lon_max: largest longitude of the region
        cf_time: if True use a cftime-based time axis, otherwise pandas
        var_name: name of the (single) data variable to write
        filename: output path of the netCDF file
    """
    # Evenly spaced spatial coordinates across the requested extent.
    lat_vals = np.linspace(lat_min, lat_max, SIZE)
    lon_vals = np.linspace(lon_min, lon_max, SIZE)

    # Two flavors of time axis, to exercise both decoding paths downstream.
    if cf_time:
        time_axis = [cftime.datetime(2000, 1, day) for day in range(1, NUM_TIME_STEPS + 1)]
    else:
        time_axis = pd.date_range(start="2000-01-01", periods=NUM_TIME_STEPS, freq="D")

    # Random values with shape (time, x, y).
    values = np.random.rand(len(time_axis), len(lat_vals), len(lon_vals))

    # NOTE(review): latitudes are stored under "x" and longitudes under "y",
    # mirroring the original test fixture -- confirm this is intentional.
    dataset = xr.Dataset(
        data_vars={var_name: (("time", "x", "y"), values)},
        coords={"x": lat_vals, "y": lon_vals, "time": time_axis},
    )
    dataset.to_netcdf(path=filename)
|
||
|
||
if __name__ == "__main__":
    # Start from an empty output directory so stale files never linger.
    if os.path.isdir(DIR):
        shutil.rmtree(DIR)
    os.makedirs(DIR)

    # One file per (variable, region) pair; regions alternate between
    # cftime and pandas time axes via CF_TIME.
    for variable in VAR_NAMES:
        for lat_pair, lon_pair, use_cftime in zip(LATS, LONS, CF_TIME):
            out_path = os.path.join(DIR, f"{variable}_{lat_pair}_{lon_pair}.nc")
            create_rioxr_dataset(
                lat_pair[0],
                lat_pair[1],
                lon_pair[0],
                lon_pair[1],
                use_cftime,
                variable,
                out_path,
            )
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
|
||
import pytest | ||
import torch | ||
|
||
from torchgeo.datasets import ( | ||
BoundingBox, | ||
IntersectionDataset, | ||
RioXarrayDataset, | ||
UnionDataset, | ||
) | ||
|
||
pytest.importorskip("rioxarray") | ||
|
||
|
||
class TestRioXarrayDataset:
    """Tests for RioXarrayDataset backed by the synthetic netCDF fixtures."""

    @pytest.fixture(scope="class")
    def dataset(self) -> RioXarrayDataset:
        """A dataset over the generated test files, reading both variables."""
        data_root = os.path.join("tests", "data", "rioxr", "data")
        return RioXarrayDataset(root=data_root, data_variables=["zos", "tos"])

    def test_getitem(self, dataset: RioXarrayDataset) -> None:
        # Querying the full bounds must yield a sample dict with an image tensor.
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["image"], torch.Tensor)

    def test_and(self, dataset: RioXarrayDataset) -> None:
        # The & operator composes two GeoDatasets into an intersection.
        intersection = dataset & dataset
        assert isinstance(intersection, IntersectionDataset)

    def test_or(self, dataset: RioXarrayDataset) -> None:
        # The | operator composes two GeoDatasets into a union.
        union = dataset | dataset
        assert isinstance(union, UnionDataset)

    def test_invalid_query(self, dataset: RioXarrayDataset) -> None:
        # A degenerate bounding box intersects nothing and must raise.
        empty_query = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[empty_query]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""In-memory geographical xarray.DataArray and xarray.Dataset.""" | ||
|
||
import glob | ||
import os | ||
import re | ||
import sys | ||
from datetime import datetime | ||
from typing import Any, Callable, Optional, cast | ||
|
||
import numpy as np | ||
import torch | ||
import xarray as xr | ||
from rasterio.crs import CRS | ||
from rioxarray.merge import merge_arrays | ||
from rtree.index import Index, Property | ||
|
||
from .geo import GeoDataset | ||
from .utils import BoundingBox | ||
|
||
|
||
class RioXarrayDataset(GeoDataset):
    """Wrapper for geographical datasets stored as Xarray Datasets.

    Files discovered under ``root`` are opened with xarray, indexed
    spatio-temporally with an R-tree, and served through rioxarray for
    CRS handling, reprojection, clipping, and merging.
    """

    # Glob pattern used to discover candidate files under ``root``.
    filename_glob = "*"
    # Regex a file's basename must match to be indexed.
    filename_regex = ".*"

    # If True, samples are returned under "image" (float); otherwise
    # under "mask" (long).
    is_image = True

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the dataset (overrides the dtype of the data file via a cast).

        Returns:
            the dtype of the dataset

        .. versionadded:: 0.5
        """
        if self.is_image:
            return torch.float32
        else:
            return torch.long

    def __init__(
        self,
        root: str,
        data_variables: Optional[list[str]] = None,
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: directory with files to be opened with xarray
            data_variables: data variables that should be gathered from the
                collection of xarray datasets (defaults to the union of all
                data variables found across the files)
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file read)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file read)
            transforms: a function/transform that takes an input sample
                and returns a transformed version

        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        super().__init__(transforms)

        self.root = root

        if data_variables:
            self.data_variables = data_variables
        else:
            # Collect every data variable seen while indexing; the union
            # becomes the default selection after the loop.
            data_variables_to_collect: list[str] = []

        self.transforms = transforms

        # Create an R-tree to index the dataset over (x, y, time)
        self.index = Index(interleaved=False, properties=Property(dimension=3))

        # Populate the dataset index
        i = 0
        pathname = os.path.join(root, self.filename_glob)
        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
        for filepath in glob.iglob(pathname, recursive=True):
            match = re.match(filename_regex, os.path.basename(filepath))
            if match is not None:
                with xr.open_dataset(filepath, decode_times=True) as ds:
                    # rioxarray expects spatial dimensions to be named x and y
                    x_name, y_name = self._infer_spatial_coordinate_names(ds)
                    ds = ds.rename({x_name: "x", y_name: "y"})

                    # The first file read supplies any CRS/resolution the
                    # caller did not specify.
                    if crs is None:
                        crs = ds.rio.crs

                    if res is None:
                        res = ds.rio.resolution()[0]

                    (minx, miny, maxx, maxy) = ds.rio.bounds()

                    if hasattr(ds, "time"):
                        try:
                            # cftime-backed indexes need an explicit
                            # conversion before timestamps can be taken
                            indices = ds.indexes["time"].to_datetimeindex()
                        except AttributeError:
                            indices = ds.indexes["time"]

                        mint = indices.min().to_pydatetime().timestamp()
                        maxt = indices.max().to_pydatetime().timestamp()
                    else:
                        # No time axis: cover the entire valid time range
                        mint = 0
                        maxt = sys.maxsize
                    coords = (minx, maxx, miny, maxy, mint, maxt)
                    self.index.insert(i, coords, filepath)
                    i += 1

                    # collect all possible data variables if none were given
                    if not data_variables:
                        data_variables_to_collect.extend(list(ds.data_vars))

        if i == 0:
            msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
            raise FileNotFoundError(msg)

        if not data_variables:
            # Deduplicate: a variable may appear in several files.
            self.data_variables = list(set(data_variables_to_collect))

        if not crs:
            # Fall back to WGS84. Use a real CRS object rather than a bare
            # string so downstream comparisons and rioxarray calls receive
            # the rasterio CRS type they expect.
            self._crs = CRS.from_epsg(4326)
        else:
            self._crs = cast(CRS, crs)
        self.res = cast(float, res)

    def _infer_spatial_coordinate_names(self, ds: Any) -> tuple[str, str]:
        """Infer the names of the spatial coordinates.

        Coordinates are recognized by their CF ``units`` attribute:
        ``degrees_north``/``degree_north`` marks the y coordinate and
        ``degrees_east``/``degree_east`` marks the x coordinate.

        Args:
            ds: Dataset or DataArray of which to infer the spatial coordinates

        Returns:
            x and y coordinate names

        Raises:
            ValueError: if the spatial coordinates cannot be inferred
        """
        # Initialize so a missing coordinate triggers the ValueError below
        # instead of an unbound-local error.
        x_name = None
        y_name = None
        for dim_name, dim in ds.coords.items():
            if hasattr(dim, "units"):
                units = dim.units.lower()
                if any(x in units for x in ["degrees_north", "degree_north"]):
                    y_name = dim_name
                elif any(x in units for x in ["degrees_east", "degree_east"]):
                    x_name = dim_name

        if not x_name or not y_name:
            raise ValueError("Spatial Coordinate Units not found in Dataset")

        return x_name, y_name

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image/mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        hits = self.index.intersection(tuple(query), objects=True)
        items = [hit.object for hit in hits]

        if not items:
            raise IndexError(
                f"query: {query} not found in index with bounds: {self.bounds}"
            )

        # Clipped per-variable arrays from every intersecting file,
        # merged spatially below.
        data_arrays: list[xr.DataArray] = []
        for item in items:
            with xr.open_dataset(item, decode_cf=True) as ds:
                # rioxarray expects spatial dimensions to be named x and y
                x_name, y_name = self._infer_spatial_coordinate_names(ds)
                ds = ds.rename({x_name: "x", y_name: "y"})

                # Ensure a CRS is present and matches the dataset CRS.
                if not ds.rio.crs:
                    ds.rio.write_crs(self._crs, inplace=True)
                elif ds.rio.crs != self._crs:
                    ds = ds.rio.reproject(self._crs)

                # clip box ignores time dimension
                clipped = ds.rio.clip_box(
                    minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
                )
                # select time dimension
                if hasattr(ds, "time"):
                    try:
                        # cftime indexes must be converted before they can
                        # be compared against datetime bounds
                        clipped["time"] = clipped.indexes["time"].to_datetimeindex()
                    except AttributeError:
                        clipped["time"] = clipped.indexes["time"]
                    # NOTE(review): fromtimestamp interprets bounds in local
                    # time and overflows for the maxt=sys.maxsize sentinel of
                    # time-less files mixed with timed ones -- confirm.
                    clipped = clipped.sel(
                        time=slice(
                            datetime.fromtimestamp(query.mint),
                            datetime.fromtimestamp(query.maxt),
                        )
                    )

                for variable in self.data_variables:
                    if hasattr(clipped, variable):
                        data_arrays.append(clipped[variable].squeeze())

        # Merge all clipped arrays into a single mosaic over the query bounds.
        merged_data = torch.from_numpy(
            merge_arrays(
                data_arrays, bounds=(query.minx, query.miny, query.maxx, query.maxy)
            ).data
        )
        sample: dict[str, Any] = {"crs": self.crs, "bbox": query}

        merged_data = merged_data.to(self.dtype)
        if self.is_image:
            sample["image"] = merged_data
        else:
            sample["mask"] = merged_data

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need to determine the minimum version that works before merging