
Add methods for extracting true footprint for sampling valid data only #1881

Open · wants to merge 22 commits into main
Conversation

@adriantre (Contributor) commented Feb 14, 2024

Fix #1330

RasterDatasets may contain nodata regions, both because all files are reprojected to a common CRS and because some images have inherent nodata regions.
When IntersectionDataset joins such a dataset with a VectorDataset, this may yield

  1. false positive samples (bad for learning)
  2. empty negative samples (may be bad for learning)

The solution can be summarised as:

  • In RasterDataset, when opening each file, extract footprint and add to rtree index object
  • In IntersectionDataset._merge_dataset_indices copy over the footprint to the new rtree index.
  • In the same method, could optimise by minimizing bbox to cover only actual intersection of valid data.
  • In RandomGeoSampler.__iter__, use this footprint to validate that sample bbox actually overlaps, and don't yield until a valid box is found.
  • Enable the same for GridGeoSampler (probably other PR)
  • Remove the label mask over any nodata regions that fall outside the valid-data boundary. (Since the criterion above is overlaps rather than contains, corners of the resulting sample may still contain nodata while the label mask still covers them.) (probably other PR)
  • Add the ability to balance positive and negative samples. The VectorData can be intersected with the raster valid-data footprint in the GeoSampler to facilitate balancing positives and negatives; right now torchgeo gives the user no control over this. (probably other PR, see Add BalancedRandomGeoSampler balancing positives and negatives #1883)

Useful resources:
Rasterio nodata masks:
https://rasterio.readthedocs.io/en/latest/topics/masks.html#nodata-masks

Extract valid data footprint as vector
https://gist.github.com/sgillies/9713809

Reproject valid data footprint with rasterio
https://geopandas.org/en/stable/docs/user_guide/reproject_fiona.html#rasterio-example
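
A minimal sketch combining the linked resources, showing how a valid-data footprint might be extracted as a shapely geometry. This is not the PR's exact code; the file name and the sieve size of 500 are placeholders.

```python
import rasterio
from rasterio.features import shapes, sieve
from shapely.geometry import shape
from shapely.ops import unary_union

with rasterio.open("scene.tif") as src:
    # Combined valid-data mask across all bands: 0 = nodata, 255 = valid
    mask = src.dataset_mask()
    # Remove small holes/speckles (regions smaller than ~500 pixels)
    sieved = sieve(mask, size=500)
    # Vectorize the mask and keep only the polygons covering valid data
    polygons = [
        shape(geom)
        for geom, value in shapes(sieved, transform=src.transform)
        if value == 255
    ]
    footprint = unary_union(polygons)
```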

@github-actions bot added the datasets (Geospatial or benchmark datasets) and samplers (Samplers for indexing datasets) labels on Feb 14, 2024.
@adamjstewart added this to the 0.6.0 milestone on Feb 14, 2024.
Review thread on torchgeo/datasets/utils.py (outdated):
# Read valid/nodata-mask
mask = src.read_masks()
# Close holes
sieved_mask = sieve(mask, 500)
Reviewer comment:

It would be nice if we didn't have to know the minimum accepted size for every raster... Maybe it could be a factor computed from the mask shape?

@adriantre (Contributor, Author):

The target here is a polygon with no holes, so 500 (about 22 x 22 pixels) is probably never too big. We could also increase it.

If there are more (bigger) holes left, we could close them using shapely after converting to vector. What do you think?

Reviewer comment:

Something I thought of, if it's possible, was to use the size of the sampling window to compute the size used to close the polygons.

@adriantre (Contributor, Author):

Hmm, you are probably on to something. I'm struggling to decide what effect it might have if we set the size too big or too small.

Reviewer comment:

What I had in mind: if we close holes that are bigger than the window size, we can still get cases where samples contain only nodata, considering the multi-polygon situation here.

One example: if we have already masked out Sentinel-2 clouds as nodata and we close holes using a size bigger than the window size, we can still get random samples that fall entirely within these nodata regions.

@adriantre (Contributor, Author):

Sounds smart!

@adriantre (Contributor, Author):

One caveat is that the patch_size to be used by the sampler is not available at this point in the code: the footprint is computed in RasterDataset's init, which is separate from the Sampler's init.
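
For illustration only, a hypothetical helper capturing the idea discussed in this thread (deriving the sieve threshold from the sampler's window size); the function name and the choice of one full window area are assumptions, not part of the PR.

```python
def sieve_size_for_window(patch_size: int) -> int:
    # A nodata hole smaller than patch_size x patch_size pixels cannot fully
    # contain a sampling window, so closing (filling) it cannot by itself
    # produce a nodata-only sample.
    return patch_size * patch_size
```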

Review threads on torchgeo/datasets/utils.py and torchgeo/samplers/single.py (outdated, resolved).
# Get the first valid nodata value.
# Usually the same value for all bands.
nodata = valid_nodatavals[0]
vrt = WarpedVRT(src, nodata=nodata, crs=self.crs)
@adriantre (Contributor, Author) commented Feb 15, 2024:

As far as I can see, Sentinel-2 has no nodata value set. I looked everywhere: even after enabling the alpha layer in the Sentinel-2 GDAL driver and looking through the MSK_QUALIT file, I found nothing.

@adriantre (Contributor, Author) commented Feb 15, 2024:

This change will set the nodata value. Some datasets have other nodata values, and we should probably let the user override this, for example in their subclass of RasterDataset.

@adriantre (Contributor, Author):

Currently, the nodata is only overridden for the warped data sources. The non-warped ones (already in the correct CRS) are opened as is, but would also need to have the nodata overridden.
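
A minimal sketch of the kind of nodata override discussed here; the function name, the default nodata of 0, and the handling of the same-CRS branch are illustrative assumptions rather than the PR's implementation.

```python
import rasterio
from rasterio.vrt import WarpedVRT


def open_with_nodata(path, dst_crs, nodata=0):
    """Open a raster, overriding nodata; reproject on the fly if needed."""
    src = rasterio.open(path)
    if src.crs == dst_crs:
        # Already in the target CRS. As noted above, a real implementation
        # would still need to apply the nodata override in this branch too.
        return src
    # Warp to the target CRS, treating `nodata` as the invalid/fill value
    return WarpedVRT(src, crs=dst_crs, nodata=nodata)
```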

Review thread on torchgeo/datasets/geo.py (outdated, resolved).

@adamjstewart mentioned this pull request on Apr 1, 2024.
@adamjstewart (Collaborator) commented:
I'm trying to get back up to speed on this. Can you rebase? Also, would be good to test with the new I/O Bench dataset I added. See https://torchgeo.readthedocs.io/en/latest/user/contributing.html#i-o-benchmarking for instructions.

@adriantre force-pushed the feature/geosampler_discard_nodata branch from bbb21d5 to 0e015a0 on August 12, 2024 11:09.
@adriantre (Contributor, Author):

> I'm trying to get back up to speed on this. Can you rebase? Also, would be good to test with the new I/O Bench dataset I added. See https://torchgeo.readthedocs.io/en/latest/user/contributing.html#i-o-benchmarking for instructions.

I rebased now, have a look. I'll look into your new dataset!

@github-actions bot added and then removed the testing (Continuous integration testing) label on Aug 12, 2024.
@adriantre (Contributor, Author) commented Aug 12, 2024:

I did a quick test.

My initial implementation is slower on your benchmark dataset io_raw. One thought: it could be useful to also benchmark training effectiveness, since a change can be slower per epoch while the model reaches a lower loss in fewer epochs.

I assume the slowdown is due to two reasons, since setup and dataloader_next are the culprits:

  1. Computing the true_footprints takes about 2 seconds. (This can be precomputed and stored/cached)
  2. My RandomGeoSampler is naive: it just retries until it finds a random box that overlaps the true_footprint, which may result in thousands of retries (see the sketch at the end of this comment).

With patch_size=128 and num_workers=0:
6 seconds slower in total.
Reduces the number of nodata-only samples from 1206 to 0.

What threw me off a little is that a lot of the nodata pixels seem, for some reason, to get the value 1. Is this correct?
This can e.g. be observed in RasterDataset.__getitem__ under if self.is_image.

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
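
For reference, a hypothetical sketch of the naive retry loop described in point 2 above; draw_random_bbox, footprint (a shapely geometry), and max_tries are illustrative names, not the PR's code.

```python
from shapely.geometry import box


def sample_valid_bbox(draw_random_bbox, footprint, max_tries=10_000):
    """Draw random (minx, maxx, miny, maxy) boxes until one overlaps valid data."""
    for _ in range(max_tries):
        minx, maxx, miny, maxy = draw_random_bbox()
        # shapely's box takes (minx, miny, maxx, maxy)
        if footprint.intersects(box(minx, miny, maxx, maxy)):
            return minx, maxx, miny, maxy
    raise RuntimeError("No sample overlapping valid data found within max_tries")
```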

@adriantre requested a review from johnnv1 on August 13, 2024.
@@ -604,8 +604,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
         Raises:
             IndexError: if query is not found in the index
         """
-        hits = self.index.intersection(tuple(query), objects=True)
-        filepaths = cast(list[dict[str, str]], [hit.object for hit in hits])
+        filepaths = self.filespaths_intersecting_query(query)
@adriantre (Contributor, Author):

chesapeake and enviroatlas insert special objects into the index, so this breaks for them.

@adamjstewart (Collaborator):
So if I'm understanding correctly, the current state of this PR is that the naïve implementation of this is actually slower than loading all nodata pixels? In that case, we may want to push this to the 0.7.0 milestone so we have more time to ensure that everything is bug-free.

May be worth looking into how Raster Vision does this, they've been around longer than us: #1330 (comment)

P.S. Need to resolve merge conflicts.

@adriantre (Contributor, Author) commented Aug 21, 2024:

If the only goal is low time per epoch, then yes. But if the goal is low total time for the model to converge, this might still be faster.

In my opinion, this PR is the first step towards #1883 which is really important for faster overall training.

I'm fine with pushing this to 0.7, but would love some dialogue on how this could evolve.

@adamjstewart (Collaborator):
I guess I would need to see additional tests to verify if it is actually faster for model convergence. Unfortunately that's even harder to time since it's inherently stochastic. If we can't also make it faster to load data, it seems unlikely to be merged.

@adamjstewart removed this from the 0.6.0 milestone on Aug 21, 2024.
@adriantre (Contributor, Author) commented Aug 22, 2024:

> There's no mechanism for on-the-fly filtering based on nodata content in RV GeoDatasets. However, RV also provides a way to pre-chip a dataset and then treat it as a non-geospatial dataset. And that does allow filtering based on nodata percentage via a nodata_threshold option.

from #1330 (comment)

This "pre-chipping" is my idea for how this could be solved, but my idea still keeps it as a GeoDataset.

> Alternatives
> One good alternative/evolution that I see is replacing the rtree index created in IntersectionDataset. Instead of the current rtree index, add the vectorized raster footprints and/or vector data (actual features/shapes from e.g. shapefiles) to GeoPandas GeoDataFrames, and use their implementation of rtree to rapidly find areas where they intersect. The temporal dimension would not be part of the index (not supported), but can be added as a filter once spatial matches are found. At the end of this article they split their polygons [raster footprints] into small grid cells [of patch size] and can rapidly retrieve cells [samples] where there is overlap with the other data set [labels], or no overlap. The underlying rtree indices in RasterDatasets and VectorDatasets will be the same, and still be used to read the data upon sampling.

from first post in #1883

(In this nodata-PR it would only use the raster footprint, but this could then be expanded to any number of label classes)

I would need an indication that this is a good/promising solution before starting to dive into it.
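
To make the gridding idea above more concrete, a hedged sketch using GeoPandas; all names (grid_cells, labels_gdf) and the overlap criterion are assumptions, not an agreed design.

```python
import geopandas as gpd
from shapely.geometry import box


def grid_cells(footprint, patch_size, crs):
    """Split a footprint into patch_size x patch_size cells (in CRS units)."""
    minx, miny, maxx, maxy = footprint.bounds
    cells = []
    y = miny
    while y < maxy:
        x = minx
        while x < maxx:
            cell = box(x, y, x + patch_size, y + patch_size)
            if cell.intersects(footprint):
                cells.append(cell)
            x += patch_size
        y += patch_size
    return gpd.GeoDataFrame(geometry=cells, crs=crs)


# Cells overlapping label geometries are candidate positives; the rest negatives.
# positives = gpd.sjoin(grid_cells(footprint, 2560.0, crs), labels_gdf, predicate="intersects")
```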

@adamjstewart (Collaborator):
Both of these would be pretty big core changes and would need a lot of justification. Losing explicit temporal sampling would be unfortunate, as we're trying to improve time series support in the next 6 months. But the ability to check for overlap of arbitrary polygons (not just boxes) would be very nice. I am worried about the decrease in performance though, especially as you approach millions of images. A formal mathematical R-tree will always be significantly faster than a hacky R-tree that supports arbitrary shapes. Basically, I don't know the best way to do this, and I don't want you to sink tens of hours into something that may never be accepted, but I also wish someone would solve this problem and don't have the time/need myself.

@adamjstewart added this to the 0.7.0 milestone on Aug 27, 2024.
Labels: datasets (Geospatial or benchmark datasets), samplers (Samplers for indexing datasets)
Projects: None yet

Development

Successfully merging this pull request may close these issues.

How to avoid nodata-only patches
3 participants