Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiRasterSource.from_stac() constructor #2156

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Optional, Sequence, List, Tuple
from typing import TYPE_CHECKING, Optional, Sequence, Self, Tuple
from pydantic import conint

import numpy as np
from pystac import Item

from rastervision.core.box import Box
from rastervision.core.data.raster_source import RasterSource
from rastervision.core.data.crs_transformer import CRSTransformer
from rastervision.core.data.raster_source import RasterSource, RasterioSource
from rastervision.core.data.raster_source.stac_config import subset_assets
from rastervision.core.data.utils import all_equal

if TYPE_CHECKING:
from rastervision.core.data import RasterTransformer, CRSTransformer


class MultiRasterSource(RasterSource):
"""Merge multiple ``RasterSources`` by concatenating along channel dim."""
Expand Down Expand Up @@ -69,6 +73,83 @@ def __init__(self,

self.validate_raster_sources()

@classmethod
def from_stac(
        cls,
        item: Item,
        assets: list[str] | None = None,
        primary_source_idx: conint(ge=0) = 0,
        raster_transformers: list['RasterTransformer'] = [],
        force_same_dtype: bool = False,
        channel_order: Sequence[int] | None = None,
        bbox: Box | tuple[int, int, int, int] | None = None,
        bbox_map_coords: Box | tuple[int, int, int, int] | None = None,
        allow_streaming: bool = False) -> Self:
    """Construct a ``MultiRasterSource`` from a STAC Item.

    This creates a :class:`.RasterioSource` for each asset and puts all
    the raster sources together into a ``MultiRasterSource``. If ``assets``
    is not specified, all the assets in the STAC item are used. The
    sub raster sources are created in the iteration order of
    ``item.assets``, which is the order ``primary_source_idx`` refers to.

    Only assets that are readable by rasterio are supported.

    Args:
        item: STAC Item.
        assets: List of names of assets to use. If ``None``, all assets
            present in the item will be used. Defaults to ``None``.
        primary_source_idx (0 <= int < len(raster_sources)): Index of the
            raster source whose CRS, dtype, and other attributes will
            override those of the other raster sources.
        raster_transformers: RasterTransformers to use to transform chips
            after they are read.
        force_same_dtype: If true, force all sub-chips to have the
            same dtype as the primary_source_idx-th sub-chip. No careful
            conversion is done, just a quick cast. Use with caution.
        channel_order: List of indices of channels to extract from raw
            imagery. Can be a subset of the available channels. If None,
            all channels available in the image will be read.
            Defaults to None.
        bbox: User-specified crop of the extent. Can be :class:`.Box` or
            (ymin, xmin, ymax, xmax) tuple. If None, the full extent
            available in the source file is used. Mutually exclusive with
            ``bbox_map_coords``. Defaults to ``None``.
        bbox_map_coords: User-specified bbox in EPSG:4326 coords. Can be
            :class:`.Box` or (ymin, xmin, ymax, xmax) tuple. Useful for
            cropping the raster source so that only part of the raster is
            read from. Mutually exclusive with ``bbox``.
            Defaults to ``None``.
        allow_streaming: Passed to :class:`.RasterioSource`. If ``False``,
            assets will be downloaded. Defaults to ``False``.

    Returns:
        A new instance of ``cls`` wrapping one :class:`.RasterioSource`
        per selected asset.

    Raises:
        ValueError: If both ``bbox`` and ``bbox_map_coords`` are specified.
    """
    if bbox is not None and bbox_map_coords is not None:
        raise ValueError('Specify either bbox or bbox_map_coords, '
                         'but not both.')

    if assets is not None:
        item = subset_assets(item, assets)

    raster_sources = [
        RasterioSource(asset.href, allow_streaming=allow_streaming)
        for asset in item.assets.values()
    ]

    if bbox_map_coords is not None:
        # Map coords are converted to pixel coords using the primary
        # source's transformer; normalize() is applied to the result
        # before it is used as the pixel bbox.
        crs_transformer = raster_sources[primary_source_idx].crs_transformer
        bbox_map_coords = Box(*bbox_map_coords)
        bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()
    elif bbox is not None:
        # Accept both Box and plain 4-tuples.
        bbox = Box(*bbox)

    # Instantiate via ``cls`` so the declared ``Self`` return type also
    # holds for subclasses. A copy of ``raster_transformers`` is passed so
    # that the shared default list is never handed to an instance.
    return cls(
        raster_sources,
        primary_source_idx=primary_source_idx,
        raster_transformers=list(raster_transformers),
        channel_order=channel_order,
        force_same_dtype=force_same_dtype,
        bbox=bbox)

def validate_raster_sources(self) -> None:
"""Validate sub-``RasterSources``.

Expand Down Expand Up @@ -101,13 +182,13 @@ def dtype(self) -> np.dtype:
return self.primary_source.dtype

@property
def crs_transformer(self) -> CRSTransformer:
def crs_transformer(self) -> 'CRSTransformer':
return self.primary_source.crs_transformer

def _get_sub_chips(self,
window: Box,
out_shape: Optional[Tuple[int, int]] = None
) -> List[np.ndarray]:
) -> list[np.ndarray]:
"""Return chips from sub raster sources as a list.

If all extents are identical, simply retrieves chips from each sub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def __init__(self,

self.validate_raster_sources()

@classmethod
def from_stac(cls, *args, **kwargs):
"""Not implemented for ``TemporalMultiRasterSource``."""
raise NotImplementedError(
'Create raster sources by calling MultiRasterSource.from_stac() '
'on each Item and then pass them to TemporalMultiRasterSource.')

def _get_chip(self,
window: Box,
out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def __init__(self,
@classmethod
def from_stac(
cls,
item_or_item_collection: Union['Item', 'ItemCollection'],
raster_transformers: List['RasterTransformer'] = [],
channel_order: Optional[Sequence[int]] = None,
bbox: Optional[Box] = None,
bbox_map_coords: Optional[Box] = None,
item_or_item_collection: 'Item | ItemCollection',
raster_transformers: list['RasterTransformer'] = [],
channel_order: Sequence[int] | None = None,
bbox: Box | tuple[int, int, int, int] | None = None,
bbox_map_coords: Box | tuple[int, int, int, int] | None = None,
temporal: bool = False,
allow_streaming: bool = False,
stackstac_args: dict = dict(rescale=False)) -> 'XarraySource':
Expand All @@ -113,13 +113,15 @@ def from_stac(
imagery. Can be a subset of the available channels. If None,
all channels available in the image will be read.
Defaults to None.
bbox: User-specified crop of the extent. If None, the full extent
bbox: User-specified crop of the extent. Can be :class:`.Box` or
(ymin, xmin, ymax, xmax) tuple. If None, the full extent
available in the source file is used. Mutually exclusive with
``bbox_map_coords``. Defaults to ``None``.
bbox_map_coords: User-specified bbox in EPSG:4326 coords of the
form (ymin, xmin, ymax, xmax). Useful for cropping the raster
source so that only part of the raster is read from. Mutually
exclusive with ``bbox``. Defaults to ``None``.
bbox_map_coords: User-specified bbox in EPSG:4326 coords. Can be
:class:`.Box` or (ymin, xmin, ymax, xmax) tuple. Useful for
cropping the raster source so that only part of the raster is
read from. Mutually exclusive with ``bbox``.
Defaults to ``None``.
temporal: If True, data_array is expected to have a "time"
dimension and the chips returned will be of shape (T, H, W, C).
allow_streaming: If False, load the entire DataArray into memory.
Expand Down
37 changes: 37 additions & 0 deletions tests/core/data/raster_source/test_multi_raster_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from xarray import DataArray
from pystac import Item

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.core.box import Box
Expand Down Expand Up @@ -84,6 +85,12 @@ def test_build_temporal(self):


class TestMultiRasterSource(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
    """Fail the current test if calling ``fn`` raises an exception.

    Args:
        fn: Zero-argument callable to invoke.
        msg: Optional message to report on failure. The type and text of
            the caught exception are always included so that the failure
            is diagnosable even when ``msg`` is empty.
    """
    try:
        fn()
    except Exception as e:
        # Catching broadly is deliberate: any exception counts as a
        # failure here. Surface what was raised instead of discarding it.
        detail = f'{type(e).__name__}: {e}'
        self.fail(f'{msg} ({detail})' if msg else detail)

def setUp(self):
self.tmp_dir_obj = get_tmp_dir()
self.tmp_dir = self.tmp_dir_obj.name
Expand Down Expand Up @@ -218,6 +225,36 @@ def test_temporal_sub_raster_sources(self):
chip_expected[..., 4:] *= np.arange(4, dtype=np.uint8)
np.testing.assert_array_equal(chip, chip_expected)

def test_from_stac(self):
    """Check ``MultiRasterSource.from_stac()`` bbox handling.

    Covers a pixel-coordinate ``bbox``, a ``bbox_map_coords`` that is
    converted to pixel coordinates, and the error raised when both are
    given.
    """
    item = Item.from_file(data_file_path('stac/item.json'))

    # avoid reading actual remote files
    mock_raster_uri = data_file_path('ones.tif')
    for asset_name in ('red', 'green'):
        item.assets[asset_name].href = mock_raster_uri

    # test bbox
    bbox = Box(ymin=0, xmin=0, ymax=100, xmax=100)
    rs = MultiRasterSource.from_stac(
        item, assets=['red', 'green'], bbox=bbox)
    self.assertEqual(rs.bbox, bbox)

    # test bbox_map_coords
    bbox_map_coords = Box(
        ymin=29.978710, xmin=31.134949, ymax=29.977309, xmax=31.136567)
    rs = MultiRasterSource.from_stac(
        item, assets=['red', 'green'], bbox_map_coords=bbox_map_coords)
    self.assertEqual(rs.bbox, Box(ymin=50, xmin=50, ymax=206, xmax=206))

    # test error if both bbox and bbox_map_coords specified
    with self.assertRaises(ValueError):
        MultiRasterSource.from_stac(
            item=item,
            assets=['red', 'green'],
            bbox=bbox,
            bbox_map_coords=bbox_map_coords)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def test_getitem(self):
], dtype=dtype)
np.testing.assert_array_equal(chip, chip_expected)

def test_from_stac(self):
    """``from_stac()`` called with no arguments must raise
    ``NotImplementedError``."""
    with self.assertRaises(NotImplementedError):
        TemporalMultiRasterSource.from_stac()


if __name__ == '__main__':
unittest.main()