Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ StackSTACMosaicIterDataPipe to mosaic tiles into one piece #63

Merged
merged 4 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@

```{eval-rst}
.. automodule:: zen3geo.datapipes.stackstac
.. autoclass:: zen3geo.datapipes.StackSTACMosaic
.. autoclass:: zen3geo.datapipes.stackstac.StackSTACMosaicIterDataPipe
.. autoclass:: zen3geo.datapipes.StackSTACStacker
.. autoclass:: zen3geo.datapipes.stackstac.StackSTACStackerIterDataPipe
:show-inheritance:
Expand Down
5 changes: 4 additions & 1 deletion zen3geo/datapipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
PySTACAPISearchIterDataPipe as PySTACAPISearch,
)
from zen3geo.datapipes.rioxarray import RioXarrayReaderIterDataPipe as RioXarrayReader
from zen3geo.datapipes.stackstac import StackSTACStackerIterDataPipe as StackSTACStacker
from zen3geo.datapipes.stackstac import (
StackSTACMosaicIterDataPipe as StackSTACMosaic,
StackSTACStackerIterDataPipe as StackSTACStacker,
)
from zen3geo.datapipes.xbatcher import XbatcherSlicerIterDataPipe as XbatcherSlicer
90 changes: 90 additions & 0 deletions zen3geo/datapipes/stackstac.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,96 @@
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("mosaic_dataarray")
class StackSTACMosaicIterDataPipe(IterDataPipe[xr.DataArray]):
"""
Takes :py:class:`xarray.DataArray` objects, flattens a dimension by picking
the first valid pixel, to yield mosaicked :py:class:`xarray.DataArray`
objects (functional name: ``mosaic_dataarray``).

Parameters
----------
source_datapipe : IterDataPipe[xarray.DataArray]
A DataPipe that contains :py:class:`xarray.DataArray` objects, with
e.g. dimensions ("time", "band", "y", "x").

kwargs : Optional
Extra keyword arguments to pass to :py:func:`stackstac.mosaic`.

Yields
------
dataarray : xarray.DataArray
An :py:class:`xarray.DataArray` that has been mosaicked with e.g.
dimensions ("band", "y", "x").

Raises
------
ModuleNotFoundError
If ``stackstac`` is not installed. See
:doc:`install instructions for stackstac <stackstac:index>`, (e.g. via
``pip install stackstac``) before using this class.

Example
-------
>>> import pytest
>>> import xarray as xr
>>> pystac = pytest.importorskip("pystac")
>>> stackstac = pytest.importorskip("stackstac")
...
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from zen3geo.datapipes import StackSTACMosaic
...
>>> # Get list of ALOS DEM tiles to mosaic together later
>>> item_urls = [
... "https://planetarycomputer.microsoft.com/api/stac/v1/collections/alos-dem/items/ALPSMLC30_N022E113_DSM",
... "https://planetarycomputer.microsoft.com/api/stac/v1/collections/alos-dem/items/ALPSMLC30_N022E114_DSM",
... ]
>>> stac_items = [pystac.Item.from_file(href=url) for url in item_urls]
>>> dataarray = stackstac.stack(items=stac_items)
>>> assert dataarray.sizes == {'time': 2, 'band': 1, 'y': 3600, 'x': 7200}
...
>>> # Mosaic different tiles in an xarray.DataArray using DataPipe
>>> dp = IterableWrapper(iterable=[dataarray])
>>> dp_mosaic = dp.mosaic_dataarray()
...
>>> # Loop or iterate over the DataPipe stream
>>> it = iter(dp_mosaic)
>>> dataarray = next(it)
>>> print(dataarray.sizes)
Frozen({'band': 1, 'y': 3600, 'x': 7200})
>>> print(dataarray.coords)
Coordinates:
* band (band) <U4 'data'
* x (x) float64 113.0 113.0 113.0 113.0 ... 115.0 115.0 115.0 115.0
* y (y) float64 23.0 23.0 23.0 23.0 23.0 ... 22.0 22.0 22.0 22.0
...
>>> print(dataarray.attrs["spec"])
RasterSpec(epsg=4326, bounds=(113.0, 22.0, 115.0, 23.0), resolutions_xy=(0.0002777777777777778, 0.0002777777777777778))
"""

def __init__(
self,
source_datapipe: IterDataPipe[xr.DataArray],
**kwargs: Optional[Dict[str, Any]]
) -> None:
if stackstac is None:
raise ModuleNotFoundError(
"Package `stackstac` is required to be installed to use this datapipe. "
"Please use `pip install stackstac` or "
"`conda install -c conda-forge stackstac` "
"to install the package"
)
self.source_datapipe: IterDataPipe = source_datapipe
self.kwargs = kwargs

def __iter__(self) -> Iterator[xr.DataArray]:
for dataarray in self.source_datapipe:
yield stackstac.mosaic(arr=dataarray, **self.kwargs)

def __len__(self) -> int:
return len(self.source_datapipe)


@functional_datapipe("stack_stac_items")
class StackSTACStackerIterDataPipe(IterDataPipe[xr.DataArray]):
"""
Expand Down
15 changes: 15 additions & 0 deletions zen3geo/tests/test_datapipes_stackstac.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
Tests for stackstac datapipes.
"""
import numpy as np
import pytest
import xarray as xr
from torchdata.datapipes.iter import IterableWrapper

from zen3geo.datapipes import StackSTACStacker
Expand All @@ -10,6 +12,19 @@
stackstac = pytest.importorskip("stackstac")

# %%
def test_stackstac_mosaic():
"""
Ensure that StackSTACMosaic works to mosaic tiles within a 4D
xarray.DataArray to a 3D xarray.DataArray.
"""
datacube: xr.DataArray = xr.DataArray(
data=np.ones(shape=(3, 1, 32, 32)), dims=["tile", "band", "y", "x"]
)
dataarray = stackstac.mosaic(arr=datacube, dim="tile")
assert dataarray.sizes == {"band": 1, "y": 32, "x": 32}
assert dataarray.sum() == 1 * 32 * 32


def test_stackstac_stacker():
"""
Ensure that StackSTACStacker works to stack multiple bands within a STAC
Expand Down