From f75fe9571792745795af20b43c0ffd0e7639939c Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 5 Jun 2024 10:34:25 -0400 Subject: [PATCH] add MultiRasterSource.from_stac() constructor --- .../data/raster_source/multi_raster_source.py | 91 ++++++++++++++++++- .../temporal_multi_raster_source.py | 7 ++ .../core/data/raster_source/xarray_source.py | 22 +++-- .../raster_source/test_multi_raster_source.py | 37 ++++++++ .../test_temporal_multi_raster_source.py | 4 + 5 files changed, 146 insertions(+), 15 deletions(-) diff --git a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py index e35964955..f617e9c78 100644 --- a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py @@ -1,13 +1,17 @@ -from typing import Optional, Sequence, List, Tuple +from typing import TYPE_CHECKING, Optional, Sequence, Self, Tuple from pydantic import conint import numpy as np +from pystac import Item from rastervision.core.box import Box -from rastervision.core.data.raster_source import RasterSource -from rastervision.core.data.crs_transformer import CRSTransformer +from rastervision.core.data.raster_source import RasterSource, RasterioSource +from rastervision.core.data.raster_source.stac_config import subset_assets from rastervision.core.data.utils import all_equal +if TYPE_CHECKING: + from rastervision.core.data import RasterTransformer, CRSTransformer + class MultiRasterSource(RasterSource): """Merge multiple ``RasterSources`` by concatenating along channel dim.""" @@ -69,6 +73,83 @@ def __init__(self, self.validate_raster_sources() + @classmethod + def from_stac( + cls, + item: Item, + assets: list[str] | None, + primary_source_idx: conint(ge=0) = 0, + raster_transformers: list['RasterTransformer'] = [], + force_same_dtype: bool = False, + channel_order: Sequence[int] | 
None = None, + bbox: Box | tuple[int, int, int, int] | None = None, + bbox_map_coords: Box | tuple[int, int, int, int] | None = None, + allow_streaming: bool = False) -> Self: + """Construct a ``MultiRasterSource`` from a STAC Item. + + This creates a :class:`.RasterioSource` for each asset and puts all + the raster sources together into a ``MultiRasterSource``. If ``assets`` + is not specified, all the assets in the STAC item are used. + + Only assets that are readable by rasterio are supported. + + Args: + item: STAC Item. + assets: List of names of assets to use. If ``None``, all assets + present in the item will be used. Defaults to ``None``. + primary_source_idx (0 <= int < len(raster_sources)): Index of the + raster source whose CRS, dtype, and other attributes will + override those of the other raster sources. + raster_transformers: RasterTransformers to use to transform chips + after they are read. + force_same_dtype: If true, force all sub-chips to have the + same dtype as the primary_source_idx-th sub-chip. No careful + conversion is done, just a quick cast. Use with caution. + channel_order: List of indices of channels to extract from raw + imagery. Can be a subset of the available channels. If None, + all channels available in the image will be read. + Defaults to None. + bbox: User-specified crop of the extent. Can be :class:`.Box` or + (ymin, xmin, ymax, xmax) tuple. If None, the full extent + available in the source file is used. Mutually exclusive with + ``bbox_map_coords``. Defaults to ``None``. + bbox_map_coords: User-specified bbox in EPSG:4326 coords. Can be + :class:`.Box` or (ymin, xmin, ymax, xmax) tuple. Useful for + cropping the raster source so that only part of the raster is + read from. Mutually exclusive with ``bbox``. + Defaults to ``None``. + allow_streaming: Passed to :class:`.RasterioSource`. If ``False``, + assets will be downloaded. Defaults to ``False``. 
+ """ + if bbox is not None and bbox_map_coords is not None: + raise ValueError('Specify either bbox or bbox_map_coords, ' + 'but not both.') + + if assets is not None: + item = subset_assets(item, assets) + + uris = [asset.href for asset in item.assets.values()] + raster_sources = [ + RasterioSource(uri, allow_streaming=allow_streaming) + for uri in uris + ] + + crs_transformer = raster_sources[primary_source_idx].crs_transformer + if bbox_map_coords is not None: + bbox_map_coords = Box(*bbox_map_coords) + bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize() + elif bbox is not None: + bbox = Box(*bbox) + + raster_source = MultiRasterSource( + raster_sources, + primary_source_idx=primary_source_idx, + raster_transformers=raster_transformers, + channel_order=channel_order, + force_same_dtype=force_same_dtype, + bbox=bbox) + return raster_source + def validate_raster_sources(self) -> None: """Validate sub-``RasterSources``. @@ -101,13 +182,13 @@ def dtype(self) -> np.dtype: return self.primary_source.dtype @property - def crs_transformer(self) -> CRSTransformer: + def crs_transformer(self) -> 'CRSTransformer': return self.primary_source.crs_transformer def _get_sub_chips(self, window: Box, out_shape: Optional[Tuple[int, int]] = None - ) -> List[np.ndarray]: + ) -> list[np.ndarray]: """Return chips from sub raster sources as a list. 
If all extents are identical, simply retrieves chips from each sub diff --git a/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py b/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py index 4c1173442..e3ad95052 100644 --- a/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py @@ -70,6 +70,13 @@ def __init__(self, self.validate_raster_sources() + @classmethod + def from_stac(cls, *args, **kwargs): + """Not implemented for ``TemporalMultiRasterSource``.""" + raise NotImplementedError( + 'Create raster sources by calling MultiRasterSource.from_stac() ' + 'on each Item and then pass them to TemporalMultiRasterSource.') + def _get_chip(self, window: Box, out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray: diff --git a/rastervision_core/rastervision/core/data/raster_source/xarray_source.py b/rastervision_core/rastervision/core/data/raster_source/xarray_source.py index 0b65611f4..2e81fe8c3 100644 --- a/rastervision_core/rastervision/core/data/raster_source/xarray_source.py +++ b/rastervision_core/rastervision/core/data/raster_source/xarray_source.py @@ -95,11 +95,11 @@ def __init__(self, @classmethod def from_stac( cls, - item_or_item_collection: Union['Item', 'ItemCollection'], - raster_transformers: List['RasterTransformer'] = [], - channel_order: Optional[Sequence[int]] = None, - bbox: Optional[Box] = None, - bbox_map_coords: Optional[Box] = None, + item_or_item_collection: 'Item | ItemCollection', + raster_transformers: list['RasterTransformer'] = [], + channel_order: Sequence[int] | None = None, + bbox: Box | tuple[int, int, int, int] | None = None, + bbox_map_coords: Box | tuple[int, int, int, int] | None = None, temporal: bool = False, allow_streaming: bool = False, stackstac_args: dict = dict(rescale=False)) -> 'XarraySource': @@ -113,13 +113,15 @@ def 
from_stac( imagery. Can be a subset of the available channels. If None, all channels available in the image will be read. Defaults to None. - bbox: User-specified crop of the extent. If None, the full extent + bbox: User-specified crop of the extent. Can be :class:`.Box` or + (ymin, xmin, ymax, xmax) tuple. If None, the full extent available in the source file is used. Mutually exclusive with ``bbox_map_coords``. Defaults to ``None``. - bbox_map_coords: User-specified bbox in EPSG:4326 coords of the - form (ymin, xmin, ymax, xmax). Useful for cropping the raster - source so that only part of the raster is read from. Mutually - exclusive with ``bbox``. Defaults to ``None``. + bbox_map_coords: User-specified bbox in EPSG:4326 coords. Can be + :class:`.Box` or (ymin, xmin, ymax, xmax) tuple. Useful for + cropping the raster source so that only part of the raster is + read from. Mutually exclusive with ``bbox``. + Defaults to ``None``. temporal: If True, data_array is expected to have a "time" dimension and the chips returned will be of shape (T, H, W, C). allow_streaming: If False, load the entire DataArray into memory. 
diff --git a/tests/core/data/raster_source/test_multi_raster_source.py b/tests/core/data/raster_source/test_multi_raster_source.py index 2ab4add74..5dbbae8cb 100644 --- a/tests/core/data/raster_source/test_multi_raster_source.py +++ b/tests/core/data/raster_source/test_multi_raster_source.py @@ -3,6 +3,7 @@ import numpy as np from xarray import DataArray +from pystac import Item from rastervision.pipeline.file_system import get_tmp_dir from rastervision.core.box import Box @@ -84,6 +85,12 @@ def test_build_temporal(self): class TestMultiRasterSource(unittest.TestCase): + def assertNoError(self, fn: Callable, msg: str = ''): + try: + fn() + except Exception: + self.fail(msg) + def setUp(self): self.tmp_dir_obj = get_tmp_dir() self.tmp_dir = self.tmp_dir_obj.name @@ -218,6 +225,36 @@ def test_temporal_sub_raster_sources(self): chip_expected[..., 4:] *= np.arange(4, dtype=np.uint8) np.testing.assert_array_equal(chip, chip_expected) + def test_from_stac(self): + item = Item.from_file(data_file_path('stac/item.json')) + + # avoid reading actual remote files + mock_raster_uri = data_file_path('ones.tif') + item.assets['red'].__setattr__('href', mock_raster_uri) + item.assets['green'].__setattr__('href', mock_raster_uri) + + # test bbox + bbox = Box(ymin=0, xmin=0, ymax=100, xmax=100) + rs = MultiRasterSource.from_stac( + item, assets=['red', 'green'], bbox=bbox) + self.assertEqual(rs.bbox, bbox) + + # test bbox_map_coords + bbox_map_coords = Box( + ymin=29.978710, xmin=31.134949, ymax=29.977309, xmax=31.136567) + rs = MultiRasterSource.from_stac( + item, assets=['red', 'green'], bbox_map_coords=bbox_map_coords) + self.assertEqual(rs.bbox, Box(ymin=50, xmin=50, ymax=206, xmax=206)) + + # test error if both bbox and bbox_map_coords specified + args = dict( + item=item, + assets=['red', 'green'], + bbox=bbox, + bbox_map_coords=bbox_map_coords) + self.assertRaises(ValueError, + lambda: MultiRasterSource.from_stac(**args)) + if __name__ == '__main__': unittest.main() diff 
--git a/tests/core/data/raster_source/test_temporal_multi_raster_source.py b/tests/core/data/raster_source/test_temporal_multi_raster_source.py index a5646e385..22cf04058 100644 --- a/tests/core/data/raster_source/test_temporal_multi_raster_source.py +++ b/tests/core/data/raster_source/test_temporal_multi_raster_source.py @@ -105,6 +105,10 @@ def test_getitem(self): ], dtype=dtype) np.testing.assert_array_equal(chip, chip_expected) + def test_from_stac(self): + self.assertRaises(NotImplementedError, + TemporalMultiRasterSource.from_stac) + if __name__ == '__main__': unittest.main()