-
Notifications
You must be signed in to change notification settings - Fork 347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add methods for extracting true footprint for sampling valid data only #1881
base: main
Are you sure you want to change the base?
Changes from all commits
b3c9efc
ae7a22e
691dfc0
2daeaab
aafae24
9102520
0e015a0
82895db
6069d2f
8cf0e8f
a1e98f6
28ca6b7
93b3395
9ec068c
c7e85ef
ce6f74a
6ddee5a
8e16150
f614f08
22ae05f
58821fa
4332601
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,8 @@ | |
from rasterio.io import DatasetReader | ||
from rasterio.vrt import WarpedVRT | ||
from rtree.index import Index, Property | ||
from rtree.index import Item as IndexItem | ||
from shapely.geometry import MultiPolygon, Polygon | ||
from torch import Tensor | ||
from torch.utils.data import Dataset | ||
from torchvision.datasets import ImageFolder | ||
|
@@ -36,10 +38,13 @@ | |
from .errors import DatasetNotFoundError | ||
from .utils import ( | ||
BoundingBox, | ||
IndexData, | ||
Path, | ||
array_to_tensor, | ||
calc_valid_data_footprint_from_datasource, | ||
concat_samples, | ||
disambiguate_timestamp, | ||
get_valid_footprint_between_datasets, | ||
merge_samples, | ||
path_is_vsi, | ||
) | ||
|
@@ -325,6 +330,31 @@ def files(self) -> list[Path]: | |
# Sort the output to enforce deterministic behavior. | ||
return sorted(files) | ||
|
||
def filepath_for_hit(self, hit: IndexItem) -> Path: | ||
"""Utility method for fetching filepath from rtree index item. | ||
|
||
Alleviates the type casting from the user. | ||
|
||
Args: | ||
hit: Item of rtree Index | ||
|
||
Returns: | ||
filepath for the hit | ||
""" | ||
return cast(IndexData, hit.object)['filepath'] | ||
|
||
def filespaths_intersecting_query(self, query: BoundingBox) -> list[Path]: | ||
"""Find all filepaths that intersect with query. | ||
|
||
Args: | ||
query: BoundingBox to intersect with | ||
|
||
Returns: | ||
list of all filepaths in rtree that intersect with query | ||
""" | ||
hits = self.index.intersection(tuple(query), objects=True) | ||
return [self.filepath_for_hit(hit) for hit in hits] | ||
|
||
|
||
class RasterDataset(GeoDataset): | ||
"""Abstract base class for :class:`GeoDataset` stored as raster files.""" | ||
|
@@ -465,6 +495,15 @@ def __init__( | |
if crs is None: | ||
crs = src.crs | ||
|
||
valid_footprint = calc_valid_data_footprint_from_datasource( | ||
masks=src.read_masks(), | ||
src_crs=src.crs, | ||
src_transform=src.transform, | ||
raster_width=src.width, | ||
raster_resolution_x=src.res[0], | ||
dst_crs=crs, | ||
) | ||
|
||
with WarpedVRT(src, crs=crs) as vrt: | ||
minx, miny, maxx, maxy = vrt.bounds | ||
if res is None: | ||
|
@@ -485,7 +524,13 @@ def __init__( | |
_, maxt = disambiguate_timestamp(stop, self.date_format) | ||
|
||
coords = (minx, maxx, miny, maxy, mint, maxt) | ||
self.index.insert(i, coords, filepath) | ||
|
||
index_object: IndexData = { | ||
'filepath': filepath, | ||
'valid_footprint': valid_footprint, | ||
} | ||
|
||
self.index.insert(i, coords, index_object) | ||
i += 1 | ||
|
||
if i == 0: | ||
|
@@ -520,8 +565,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: | |
Raises: | ||
IndexError: if query is not found in the index | ||
""" | ||
hits = self.index.intersection(tuple(query), objects=True) | ||
filepaths = cast(list[Path], [hit.object for hit in hits]) | ||
filepaths = self.filespaths_intersecting_query(query) | ||
|
||
if not filepaths: | ||
raise IndexError( | ||
|
@@ -612,11 +656,21 @@ def _load_warp_file(self, filepath: Path) -> DatasetReader: | |
Returns: | ||
file handle of warped VRT | ||
""" | ||
# Todo: need to update meta kwarg `nodata` if it is not set. | ||
# create memory-file? | ||
# https://rasterio.groups.io/g/main/topic/change_the_nodata_value_in_a/28801885?p= | ||
src = rasterio.open(filepath) | ||
|
||
# Only warp if necessary | ||
if src.crs != self.crs: | ||
vrt = WarpedVRT(src, crs=self.crs) | ||
valid_nodatavals = [val for val in src.nodatavals if val is not None] | ||
if not valid_nodatavals: # Case for Sentinel2 L1C | ||
nodata = 0.0 # Is it safe to assume? For Sentinel2 L1C it is 0.0 | ||
else: | ||
# Get the first valid nodata value. | ||
# Usually the same value for all bands | ||
nodata = valid_nodatavals[0] | ||
vrt = WarpedVRT(src, nodata=nodata, crs=self.crs) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As far as I can see, Sentinel-2 has no value set for nodata. I looked everywhere. Even after enabling the alpha layer in the Sentinel-2 GDAL driver and looking through the MSK_QUALIT-file, I found nothing. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change will set the nodata value. Some datasets have other nodata values, and we should probably let the user override this, for example in their subclass of RasterDataset. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Currently, the nodata value is only overridden for the warped data sources. The non-warped ones (already in the correct CRS) are opened as is, but would also need to have the nodata value overridden. |
||
src.close() | ||
return vrt | ||
else: | ||
|
@@ -711,7 +765,7 @@ def __init__( | |
date = match.group('date') | ||
mint, maxt = disambiguate_timestamp(date, self.date_format) | ||
coords = (minx, maxx, miny, maxy, mint, maxt) | ||
self.index.insert(i, coords, filepath) | ||
self.index.insert(i, coords, {'filepath': filepath}) | ||
i += 1 | ||
|
||
if i == 0: | ||
|
@@ -732,8 +786,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: | |
Raises: | ||
IndexError: if query is not found in the index | ||
""" | ||
hits = self.index.intersection(tuple(query), objects=True) | ||
filepaths = [hit.object for hit in hits] | ||
filepaths = self.filespaths_intersecting_query(query) | ||
|
||
if not filepaths: | ||
raise IndexError( | ||
|
@@ -993,14 +1046,22 @@ def _merge_dataset_indices(self) -> None: | |
"""Create a new R-tree out of the individual indices from two datasets.""" | ||
i = 0 | ||
ds1, ds2 = self.datasets | ||
|
||
for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True): | ||
for hit2 in ds2.index.intersection(hit1.bounds, objects=True): | ||
box1 = BoundingBox(*hit1.bounds) | ||
box2 = BoundingBox(*hit2.bounds) | ||
box3 = box1 & box2 | ||
box_intersection = box1 & box2 | ||
# Skip 0 area overlap (unless 0 area dataset) | ||
if box3.area > 0 or box1.area == 0 or box2.area == 0: | ||
self.index.insert(i, tuple(box3)) | ||
if box_intersection.area > 0 or box1.area == 0 or box2.area == 0: | ||
valid_footprint: Polygon | MultiPolygon | None = ( | ||
get_valid_footprint_between_datasets( | ||
ds1.index if isinstance(ds1, RasterDataset) else None, | ||
ds2.index if isinstance(ds2, RasterDataset) else None, | ||
box_intersection, | ||
) | ||
) | ||
self.index.insert(i, tuple(box_intersection), valid_footprint) | ||
i += 1 | ||
|
||
if i == 0: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
chesapeake and enviroatlas insert special objects into the index, so this breaks for them.