Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add methods for extracting true footprint for sampling valid data only #1881

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b3c9efc
Add methods for extracting true footprint for sampling valid data only
adriantre Feb 14, 2024
ae7a22e
Add max_retries to get_random_bounding_box_check_valid_overlap
adriantre Feb 14, 2024
691dfc0
Set nodata value for raster if None
adriantre Feb 15, 2024
2daeaab
Handle spatial shift between band an multipolygonal footprint
adriantre Feb 15, 2024
aafae24
When merging dataset indices, also merge raster footprints
adriantre Feb 15, 2024
9102520
Use correct bounds in merge_indices
adriantre Feb 16, 2024
0e015a0
Add missing check for masks dimension
adriantre Aug 12, 2024
82895db
Run ruff after merge
adriantre Aug 12, 2024
6069d2f
Fix incorrect import
adriantre Aug 12, 2024
8cf0e8f
Set max_hole_size to be closed on raster footprint dynamically
adriantre Aug 12, 2024
a1e98f6
Compute nodata-mask across more channels than three
adriantre Aug 12, 2024
28ca6b7
Cap max_hole_size to 800 pixels
adriantre Aug 12, 2024
93b3395
Undo download: true in io_raw.yaml
adriantre Aug 12, 2024
9ec068c
Add failsafe for max_hole_size = 0
adriantre Aug 12, 2024
c7e85ef
Merge branch 'main' into feature/geosampler_discard_nodata
adriantre Aug 14, 2024
ce6f74a
Refactor valid_footprint unification and respect VectorDataset
adriantre Aug 14, 2024
6ddee5a
Make VectorDataset store filpath in dict within hit.object
adriantre Aug 14, 2024
8e16150
Introduce TypedDict for rtree.index object
adriantre Aug 14, 2024
f614f08
Suggest method for extracting filepaths from rtree
adriantre Aug 14, 2024
22ae05f
Use filespaths_intersecting_query in subclasses of RasterDataset
adriantre Aug 14, 2024
58821fa
Merge branch 'refs/heads/main' into feature/geosampler_discard_nodata
adriantre Aug 22, 2024
4332601
Fix wrong import from shapely
adriantre Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions torchgeo/datasets/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any

import matplotlib.pyplot as plt
import torch
Expand Down Expand Up @@ -183,8 +183,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""
assert isinstance(self.paths, str | pathlib.Path)

hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = self.filespaths_intersecting_query(query)

if not filepaths:
raise IndexError(
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pathlib
import sys
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any

import fiona
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -604,8 +604,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[dict[str, str]], [hit.object for hit in hits])
filepaths = self.filespaths_intersecting_query(query)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chesapeake and enviroatlas inserts special objects into the index, so this breaks for them.


sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query}

Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/enviroatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import sys
from collections.abc import Callable, Sequence
from typing import Any, cast
from typing import Any

import fiona
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -344,8 +344,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[dict[str, str]], [hit.object for hit in hits])
filepaths = self.filespaths_intersecting_query(query)

sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query}

Expand Down
81 changes: 71 additions & 10 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from rasterio.io import DatasetReader
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from rtree.index import Item as IndexItem
from shapely import MultiPolygon, Polygon
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
Expand All @@ -35,10 +37,13 @@
from .errors import DatasetNotFoundError
from .utils import (
BoundingBox,
IndexData,
Path,
array_to_tensor,
calc_valid_data_footprint_from_datasource,
concat_samples,
disambiguate_timestamp,
get_valid_footprint_between_datasets,
merge_samples,
path_is_vsi,
)
Expand Down Expand Up @@ -322,6 +327,31 @@ def files(self) -> list[Path]:
# Sort the output to enforce deterministic behavior.
return sorted(files)

def filepath_for_hit(self, hit: IndexItem) -> Path:
"""Utility method for fetching filepath from rtee index item.

Alleviates the type casting from the user.

Args:
hit: Item of rtree Index

Returns:
filepath for the hit
"""
return cast(IndexData, hit.object)['filepath']

def filespaths_intersecting_query(self, query: BoundingBox) -> list[Path]:
"""Find all filepaths that intersects with query.

Args:
query: BoundingBox to intersect with

Returns:
list of all filepaths in rtree that intersects with query
"""
hits = self.index.intersection(tuple(query), objects=True)
return [self.filepath_for_hit(hit) for hit in hits]


class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
Expand Down Expand Up @@ -462,6 +492,15 @@ def __init__(
if crs is None:
crs = src.crs

valid_footprint = calc_valid_data_footprint_from_datasource(
masks=src.read_masks(),
src_crs=src.crs,
src_transform=src.transform,
raster_width=src.width,
raster_resolution_x=src.res[0],
dst_crs=crs,
)

with WarpedVRT(src, crs=crs) as vrt:
minx, miny, maxx, maxy = vrt.bounds
if res is None:
Expand All @@ -482,7 +521,13 @@ def __init__(
_, maxt = disambiguate_timestamp(stop, self.date_format)

coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(i, coords, filepath)

index_object: IndexData = {
'filepath': filepath,
'valid_footprint': valid_footprint,
}

self.index.insert(i, coords, index_object)
i += 1

if i == 0:
Expand Down Expand Up @@ -517,8 +562,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = self.filespaths_intersecting_query(query)

if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -609,11 +653,21 @@ def _load_warp_file(self, filepath: Path) -> DatasetReader:
Returns:
file handle of warped VRT
"""
# Todo: need to update meta kwarg `nodata` if it is not set.
# create memory-file?
# https://rasterio.groups.io/g/main/topic/change_the_nodata_value_in_a/28801885?p=
src = rasterio.open(filepath)

# Only warp if necessary
if src.crs != self.crs:
vrt = WarpedVRT(src, crs=self.crs)
valid_nodatavals = [val for val in src.nodatavals if val is not None]
if not valid_nodatavals: # Case for Sentinel2 L1C
nodata = 0.0 # Is it safe to assume? For Sentinel2 L1C it is 0.0
else:
# Get the first valid nodata value.
# Usualy the same value for all bands
nodata = valid_nodatavals[0]
vrt = WarpedVRT(src, nodata=nodata, crs=self.crs)
Copy link
Contributor Author

@adriantre adriantre Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sentinel-2 has as far I can see no value set for nodata. I looked everywhere. Even enabling alpha-layer in the Sentinel-2 gdal driver, and looking through the MSK_QUALIT-file I found nothing.

Copy link
Contributor Author

@adriantre adriantre Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change will set the nodata-value. Some datasets have other nodata-values, and we should probably let the user overwrite this, for example in their subclass of RasterDataset.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the nodata is only overridden for the warped datasources. The non-warped (already correct CRS) are opened as is, but would also need to have the nodata overridden.

src.close()
return vrt
else:
Expand Down Expand Up @@ -708,7 +762,7 @@ def __init__(
date = match.group('date')
mint, maxt = disambiguate_timestamp(date, self.date_format)
coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(i, coords, filepath)
self.index.insert(i, coords, {'filepath': filepath})
i += 1

if i == 0:
Expand All @@ -729,8 +783,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
filepaths = self.filespaths_intersecting_query(query)

if not filepaths:
raise IndexError(
Expand Down Expand Up @@ -990,14 +1043,22 @@ def _merge_dataset_indices(self) -> None:
"""Create a new R-tree out of the individual indices from two datasets."""
i = 0
ds1, ds2 = self.datasets

for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True):
for hit2 in ds2.index.intersection(hit1.bounds, objects=True):
box1 = BoundingBox(*hit1.bounds)
box2 = BoundingBox(*hit2.bounds)
box3 = box1 & box2
box_intersection = box1 & box2
# Skip 0 area overlap (unless 0 area dataset)
if box3.area > 0 or box1.area == 0 or box2.area == 0:
self.index.insert(i, tuple(box3))
if box_intersection.area > 0 or box1.area == 0 or box2.area == 0:
valid_footprint: Polygon | MultiPolygon | None = (
get_valid_footprint_between_datasets(
ds1.index if isinstance(ds1, RasterDataset) else None,
ds2.index if isinstance(ds2, RasterDataset) else None,
box_intersection,
)
)
self.index.insert(i, tuple(box_intersection), valid_footprint)
i += 1

if i == 0:
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/globbiomass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any, cast
from typing import Any

import matplotlib.pyplot as plt
import torch
Expand Down Expand Up @@ -192,8 +192,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = self.filespaths_intersecting_query(query)

if not filepaths:
raise IndexError(
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/l7irish.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any

import matplotlib.pyplot as plt
import torch
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(
filename_regex = re.compile(L7IrishImage.filename_regex, re.VERBOSE)
index = Index(interleaved=False, properties=Property(dimension=3))
for hit in self.index.intersection(self.index.bounds, objects=True):
dirname = os.path.dirname(cast(Path, hit.object))
dirname = os.path.dirname(self.filepath_for_hit(hit))
image = glob.glob(os.path.join(dirname, L7IrishImage.filename_glob))[0]
minx, maxx, miny, maxy, mint, maxt = hit.bounds
if match := re.match(filename_regex, os.path.basename(image)):
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
from collections.abc import Callable
from functools import lru_cache
from typing import Any, cast
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -253,8 +253,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
img_filepaths = cast(list[Path], [hit.object for hit in hits])
img_filepaths = self.filespaths_intersecting_query(query)
mask_filepaths = [
str(path).replace('images', 'masks') for path in img_filepaths
]
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/openbuildings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pathlib
import sys
from collections.abc import Callable, Iterable
from typing import Any, cast
from typing import Any

import fiona
import fiona.transform
Expand Down Expand Up @@ -304,8 +304,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = self.filespaths_intersecting_query(query)

if not filepaths:
raise IndexError(
Expand Down
5 changes: 2 additions & 3 deletions torchgeo/datasets/south_africa_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pathlib
import re
from collections.abc import Callable, Iterable
from typing import Any, cast
from typing import Any

import matplotlib.pyplot as plt
import torch
Expand Down Expand Up @@ -164,8 +164,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
assert isinstance(self.paths, str | pathlib.Path)

# Get all files matching the given query
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[Path], [hit.object for hit in hits])
filepaths = self.filespaths_intersecting_query(query)

if not filepaths:
raise IndexError(
Expand Down
Loading
Loading