diff --git a/tests/data/rwanda_field_boundary/data.py b/tests/data/rwanda_field_boundary/data.py old mode 100644 new mode 100755 index a3522e8c962..bf9954e8935 --- a/tests/data/rwanda_field_boundary/data.py +++ b/tests/data/rwanda_field_boundary/data.py @@ -3,99 +3,46 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import hashlib import os -import shutil import numpy as np import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') all_bands = ('B01', 'B02', 'B03', 'B04') SIZE = 32 -NUM_SAMPLES = 5 +DTYPE = np.uint16 +NUM_SAMPLES = 1 np.random.seed(0) - -def create_mask(fn: str) -> None: - profile = { - 'driver': 'GTiff', - 'dtype': 'uint8', - 'nodata': 0.0, - 'width': SIZE, - 'height': SIZE, - 'count': 1, - 'crs': 'epsg:3857', - 'compress': 'lzw', - 'predictor': 2, - 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - 'blockysize': 32, - 'tiled': False, - 'interleave': 'band', - } - with rasterio.open(fn, 'w', **profile) as f: - f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1) - - -def create_img(fn: str) -> None: - profile = { - 'driver': 'GTiff', - 'dtype': 'uint16', - 'nodata': 0.0, - 'width': SIZE, - 'height': SIZE, - 'count': 1, - 'crs': 'epsg:3857', - 'compress': 'lzw', - 'predictor': 2, - 'blockysize': 16, - 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - 'tiled': False, - 'interleave': 'band', - } - with rasterio.open(fn, 'w', **profile) as f: - f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1) - - -if __name__ == '__main__': - # Train and test images - for split in ('train', 'test'): - for i in range(NUM_SAMPLES): - for date in dates: - directory = os.path.join( - f'nasa_rwanda_field_boundary_competition_source_{split}', - f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 - ) - os.makedirs(directory, exist_ok=True) - for band in all_bands: - create_img(os.path.join(directory, f'{band}.tif')) - - # Create collections.json, this isn't used by the dataset but is checked to - # exist - with open( - f'nasa_rwanda_field_boundary_competition_source_{split}/collections.json', - 'w', - ) as f: - f.write('Not used') - - # Train labels - for i in range(NUM_SAMPLES): - directory = os.path.join( - 'nasa_rwanda_field_boundary_competition_labels_train', - f'nasa_rwanda_field_boundary_competition_labels_train_{i:02d}', - ) - os.makedirs(directory, exist_ok=True) - create_mask(os.path.join(directory, 'raster_labels.tif')) - - # Create directories and compute checksums - for filename in [ - 'nasa_rwanda_field_boundary_competition_source_train', - 'nasa_rwanda_field_boundary_competition_source_test', - 'nasa_rwanda_field_boundary_competition_labels_train', - ]: - shutil.make_archive(filename, 'gztar', '.', filename) - # Compute checksums - with open(f'{filename}.tar.gz', 'rb') as f: - md5 = hashlib.md5(f.read()).hexdigest() - print(f'{filename}: {md5}') +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(3857), + 'transform': Affine( + 4.77731426716, 0.0, 3374518.037700199, 0.0, -4.77731426716, -168438.54642526805 + ), +} +Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE) + +for sample in range(NUM_SAMPLES): + for split in ['train', 'test']: + for date in dates: + path = os.path.join('source', split, date) + os.makedirs(path, exist_ok=True) + for band in all_bands: + file = os.path.join(path, f'{sample:02}_{band}.tif') + with rasterio.open(file, 'w', **profile) as src: + src.write(Z, 1) + + path = os.path.join('labels', 'train') + os.makedirs(path, exist_ok=True) + file = os.path.join(path, f'{sample:02}.tif') + with rasterio.open(file, 'w', **profile) as src: + src.write(Z, 1) diff --git a/tests/data/rwanda_field_boundary/labels/train/00.tif b/tests/data/rwanda_field_boundary/labels/train/00.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/labels/train/00.tif differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz deleted file mode 100644 index ffa98bb53d6..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz deleted file mode 100644 index a834f66bf38..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz deleted file mode 100644 index 8239f70c200..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif differ diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py index 6f83b12a93d..ddf5b5df7fb 100644 --- a/tests/datasets/test_rwanda_field_boundary.py +++ b/tests/datasets/test_rwanda_field_boundary.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -19,45 +17,26 @@ RGBBandsMissingError, RwandaFieldBoundary, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join('tests', 'data', 'rwanda_field_boundary', '*.tar.gz') - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestRwandaFieldBoundary: @pytest.fixture(params=['train', 'test']) def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + self, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, + request: SubRequest, ) -> RwandaFieldBoundary: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - monkeypatch.setattr( - RwandaFieldBoundary, 'number_of_patches_per_split', {'train': 5, 'test': 5} - ) - monkeypatch.setattr( - RwandaFieldBoundary, - 'md5s', - { - 'train_images': 'af9395e2e49deefebb35fa65fa378ba3', - 'test_images': 'd104bb82323a39e7c3b3b7dd0156f550', - 'train_labels': '6cceaf16a141cf73179253a783e7d51b', - }, - ) + url = os.path.join('tests', 'data', 'rwanda_field_boundary') + monkeypatch.setattr(RwandaFieldBoundary, 'url', url) + monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1}) root = str(tmp_path) split = request.param transforms = nn.Identity() - return RwandaFieldBoundary( - root, split, transforms=transforms, api_key='', download=True, checksum=True - ) + return RwandaFieldBoundary(root, split, transforms=transforms, download=True) def test_getitem(self, dataset: RwandaFieldBoundary) -> None: x = dataset[0] @@ -69,23 +48,12 @@ def test_getitem(self, dataset: RwandaFieldBoundary) -> None: assert 'mask' not in x def test_len(self, dataset: RwandaFieldBoundary) -> None: - assert len(dataset) == 5 + assert len(dataset) == 1 def test_add(self, dataset: RwandaFieldBoundary) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 10 - - def test_needs_extraction(self, tmp_path: Path) -> None: - root = str(tmp_path) - for fn in [ - 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - ]: - url = os.path.join('tests', 'data', 'rwanda_field_boundary', fn) - shutil.copy(url, root) - RwandaFieldBoundary(root, checksum=False) + assert len(ds) == 2 def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None: RwandaFieldBoundary(root=dataset.root) @@ -94,35 +62,8 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): RwandaFieldBoundary(str(tmp_path)) - def test_corrupted(self, tmp_path: Path) -> None: - for fn in [ - 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - ]: - with open(os.path.join(tmp_path, fn), 'w') as f: - f.write('bad') - with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - RwandaFieldBoundary(root=str(tmp_path), checksum=True) - - def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - monkeypatch.setattr( - RwandaFieldBoundary, - 'md5s', - {'train_images': 'bad', 'test_images': 'bad', 'train_labels': 'bad'}, - ) - root = str(tmp_path) - with pytest.raises(RuntimeError, match='Dataset not found or corrupted.'): - RwandaFieldBoundary(root, 'train', api_key='', download=True, checksum=True) - - def test_no_api_key(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match='Must provide an API key to download'): - RwandaFieldBoundary(str(tmp_path), api_key=None, download=True) - def test_invalid_bands(self) -> None: - with pytest.raises(ValueError, match='is an invalid band name.'): + with pytest.raises(AssertionError): RwandaFieldBoundary(bands=('foo', 'bar')) def test_plot(self, dataset: RwandaFieldBoundary) -> None: diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 9439e525ab3..07a496ea974 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -3,6 +3,7 @@ """Rwanda Field Boundary Competition dataset.""" +import glob import os from collections.abc import Callable, Sequence @@ -16,11 +17,11 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which class RwandaFieldBoundary(NonGeoDataset): - r"""Rwanda Field Boundary Competition dataset. + """Rwanda Field Boundary Competition dataset. This dataset contains field boundaries for smallholder farms in eastern Rwanda. The Nasa Harvest program funded a team of annotators from TaQadam to label Planet @@ -46,40 +47,20 @@ class RwandaFieldBoundary(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionadded:: 0.5 """ - dataset_id = 'nasa_rwanda_field_boundary_competition' - collection_ids = [ - 'nasa_rwanda_field_boundary_competition_source_train', - 'nasa_rwanda_field_boundary_competition_labels_train', - 'nasa_rwanda_field_boundary_competition_source_test', - ] - number_of_patches_per_split = {'train': 57, 'test': 13} - - filenames = { - 'train_images': 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'test_images': 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'train_labels': 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - } - md5s = { - 'train_images': '1f9ec08038218e67e11f82a86849b333', - 'test_images': '17bb0e56eedde2e7a43c57aa908dc125', - 'train_labels': '10e4eb761523c57b6d3bdf9394004f5f', - } + url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition' + splits = {'train': 57, 'test': 13} dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') - all_bands = ('B01', 'B02', 'B03', 'B04') rgb_bands = ('B03', 'B02', 'B01') - classes = ['No field-boundary', 'Field-boundary'] - splits = ['train', 'test'] - def __init__( self, root: str = 'data', @@ -87,8 +68,6 @@ def __init__( bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: """Initialize a new RwandaFieldBoundary instance. @@ -99,49 +78,29 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: If *split* or *bands* are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ - self._validate_bands(bands) assert split in self.splits - if download and api_key is None: - raise RuntimeError('Must provide an API key to download the dataset') + assert set(bands) <= set(self.all_bands) + self.root = root + self.split = split self.bands = bands self.transforms = transforms - self.split = split self.download = download - self.api_key = api_key - self.checksum = checksum + self._verify() - self.image_filenames: list[list[list[str]]] = [] - self.mask_filenames: list[str] = [] - for i in range(self.number_of_patches_per_split[split]): - dates = [] - for date in self.dates: - patch = [] - for band in self.bands: - fn = os.path.join( - self.root, - f'nasa_rwanda_field_boundary_competition_source_{split}', - f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 - f'{band}.tif', - ) - patch.append(fn) - dates.append(patch) - self.image_filenames.append(dates) - self.mask_filenames.append( - os.path.join( - self.root, - f'nasa_rwanda_field_boundary_competition_labels_{split}', - f'nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}', - 'raster_labels.tif', - ) - ) + def __len__(self) -> int: + """Return the number of chips in the dataset. + + Returns: + length of the dataset + """ + return self.splits[self.split] def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -150,83 +109,34 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: index: index to return Returns: - a dict containing image, mask, transform, crs, and metadata at index. + a dict containing image and mask at index. """ - img_fns = self.image_filenames[index] - mask_fn = self.mask_filenames[index] - - imgs = [] - for date_fns in img_fns: - bands = [] - for band_fn in date_fns: - with rasterio.open(band_fn) as f: - bands.append(f.read(1).astype(np.int32)) - imgs.append(bands) - img = torch.from_numpy(np.array(imgs)) - - sample = {'image': img} + images = [] + for date in self.dates: + patches = [] + for band in self.bands: + path = os.path.join(self.root, 'source', self.split, date) + with rasterio.open(os.path.join(path, f'{index:02}_{band}.tif')) as src: + patches.append(src.read(1).astype(np.float32)) + images.append(patches) + sample = {'image': torch.from_numpy(np.array(images))} if self.split == 'train': - with rasterio.open(mask_fn) as f: - mask = f.read(1) - mask = torch.from_numpy(mask) - sample['mask'] = mask + path = os.path.join(self.root, 'labels', self.split) + with rasterio.open(os.path.join(path, f'{index:02}.tif')) as src: + sample['mask'] = torch.from_numpy(src.read(1).astype(np.int64)) if self.transforms is not None: sample = self.transforms(sample) return sample - def __len__(self) -> int: - """Return the number of chips in the dataset. - - Returns: - length of the dataset - """ - return len(self.image_filenames) - - def _validate_bands(self, bands: Sequence[str]) -> None: - """Validate list of bands. - - Args: - bands: user-provided sequence of bands to load - - Raises: - ValueError: if an invalid band name is provided - """ - for band in bands: - if band not in self.all_bands: - raise ValueError(f"'{band}' is an invalid band name.") - def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the subdirectories already exist and have the correct number of files - checks = [] - for split, num_patches in self.number_of_patches_per_split.items(): - path = os.path.join( - self.root, f'nasa_rwanda_field_boundary_competition_source_{split}' - ) - if os.path.exists(path): - num_files = len(os.listdir(path)) - # 6 dates + 1 collection.json file - checks.append(num_files == (num_patches * 6) + 1) - else: - checks.append(False) - - if all(checks): - return - - # Check if tar file already exists (if so then extract) - have_all_files = True - for group in ['train_images', 'train_labels', 'test_images']: - filepath = os.path.join(self.root, self.filenames[group]) - if os.path.exists(filepath): - if self.checksum and not check_integrity(filepath, self.md5s[group]): - raise RuntimeError('Dataset found, but corrupted.') - extract_archive(filepath) - else: - have_all_files = False - if have_all_files: + path = os.path.join(self.root, 'source', self.split, '*', '*.tif') + expected = len(self.dates) * self.splits[self.split] * len(self.all_bands) + if len(glob.glob(path)) == expected: return # Check if the user requested to download the dataset @@ -237,15 +147,10 @@ def _verify(self) -> None: self._download() def _download(self) -> None: - """Download the dataset and extract it.""" - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, self.api_key) - - for group in ['train_images', 'train_labels', 'test_images']: - filepath = os.path.join(self.root, self.filenames[group]) - if self.checksum and not check_integrity(filepath, self.md5s[group]): - raise RuntimeError('Dataset not found or corrupted.') - extract_archive(filepath, self.root) + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self,