diff --git a/tests/data/rwanda_field_boundary/data.py b/tests/data/rwanda_field_boundary/data.py
old mode 100644
new mode 100755
index a3522e8c962..bf9954e8935
--- a/tests/data/rwanda_field_boundary/data.py
+++ b/tests/data/rwanda_field_boundary/data.py
@@ -3,99 +3,46 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import hashlib
import os
-import shutil
import numpy as np
import rasterio
+from rasterio.crs import CRS
+from rasterio.transform import Affine
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
all_bands = ('B01', 'B02', 'B03', 'B04')
SIZE = 32
-NUM_SAMPLES = 5
+DTYPE = np.uint16
+NUM_SAMPLES = 1
np.random.seed(0)
-
-def create_mask(fn: str) -> None:
- profile = {
- 'driver': 'GTiff',
- 'dtype': 'uint8',
- 'nodata': 0.0,
- 'width': SIZE,
- 'height': SIZE,
- 'count': 1,
- 'crs': 'epsg:3857',
- 'compress': 'lzw',
- 'predictor': 2,
- 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
- 'blockysize': 32,
- 'tiled': False,
- 'interleave': 'band',
- }
- with rasterio.open(fn, 'w', **profile) as f:
- f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1)
-
-
-def create_img(fn: str) -> None:
- profile = {
- 'driver': 'GTiff',
- 'dtype': 'uint16',
- 'nodata': 0.0,
- 'width': SIZE,
- 'height': SIZE,
- 'count': 1,
- 'crs': 'epsg:3857',
- 'compress': 'lzw',
- 'predictor': 2,
- 'blockysize': 16,
- 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
- 'tiled': False,
- 'interleave': 'band',
- }
- with rasterio.open(fn, 'w', **profile) as f:
- f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1)
-
-
-if __name__ == '__main__':
- # Train and test images
- for split in ('train', 'test'):
- for i in range(NUM_SAMPLES):
- for date in dates:
- directory = os.path.join(
- f'nasa_rwanda_field_boundary_competition_source_{split}',
- f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501
- )
- os.makedirs(directory, exist_ok=True)
- for band in all_bands:
- create_img(os.path.join(directory, f'{band}.tif'))
-
- # Create collections.json, this isn't used by the dataset but is checked to
- # exist
- with open(
- f'nasa_rwanda_field_boundary_competition_source_{split}/collections.json',
- 'w',
- ) as f:
- f.write('Not used')
-
- # Train labels
- for i in range(NUM_SAMPLES):
- directory = os.path.join(
- 'nasa_rwanda_field_boundary_competition_labels_train',
- f'nasa_rwanda_field_boundary_competition_labels_train_{i:02d}',
- )
- os.makedirs(directory, exist_ok=True)
- create_mask(os.path.join(directory, 'raster_labels.tif'))
-
- # Create directories and compute checksums
- for filename in [
- 'nasa_rwanda_field_boundary_competition_source_train',
- 'nasa_rwanda_field_boundary_competition_source_test',
- 'nasa_rwanda_field_boundary_competition_labels_train',
- ]:
- shutil.make_archive(filename, 'gztar', '.', filename)
- # Compute checksums
- with open(f'{filename}.tar.gz', 'rb') as f:
- md5 = hashlib.md5(f.read()).hexdigest()
- print(f'{filename}: {md5}')
+profile = {
+ 'driver': 'GTiff',
+ 'dtype': DTYPE,
+ 'width': SIZE,
+ 'height': SIZE,
+ 'count': 1,
+ 'crs': CRS.from_epsg(3857),
+ 'transform': Affine(
+ 4.77731426716, 0.0, 3374518.037700199, 0.0, -4.77731426716, -168438.54642526805
+ ),
+}
+Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE)
+
+for sample in range(NUM_SAMPLES):
+ for split in ['train', 'test']:
+ for date in dates:
+ path = os.path.join('source', split, date)
+ os.makedirs(path, exist_ok=True)
+ for band in all_bands:
+ file = os.path.join(path, f'{sample:02}_{band}.tif')
+ with rasterio.open(file, 'w', **profile) as src:
+ src.write(Z, 1)
+
+ path = os.path.join('labels', 'train')
+ os.makedirs(path, exist_ok=True)
+ file = os.path.join(path, f'{sample:02}.tif')
+ with rasterio.open(file, 'w', **profile) as src:
+ src.write(Z, 1)
diff --git a/tests/data/rwanda_field_boundary/labels/train/00.tif b/tests/data/rwanda_field_boundary/labels/train/00.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/labels/train/00.tif differ
diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz
deleted file mode 100644
index ffa98bb53d6..00000000000
Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz and /dev/null differ
diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz
deleted file mode 100644
index a834f66bf38..00000000000
Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz and /dev/null differ
diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz
deleted file mode 100644
index 8239f70c200..00000000000
Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz and /dev/null differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif differ
diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif
new file mode 100644
index 00000000000..bd39a26d5e3
Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif differ
diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py
index 6f83b12a93d..ddf5b5df7fb 100644
--- a/tests/datasets/test_rwanda_field_boundary.py
+++ b/tests/datasets/test_rwanda_field_boundary.py
@@ -1,9 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import glob
import os
-import shutil
from pathlib import Path
import matplotlib.pyplot as plt
@@ -19,45 +17,26 @@
RGBBandsMissingError,
RwandaFieldBoundary,
)
-
-
-class Collection:
- def download(self, output_dir: str, **kwargs: str) -> None:
- glob_path = os.path.join('tests', 'data', 'rwanda_field_boundary', '*.tar.gz')
- for tarball in glob.iglob(glob_path):
- shutil.copy(tarball, output_dir)
-
-
-def fetch(dataset_id: str, **kwargs: str) -> Collection:
- return Collection()
+from torchgeo.datasets.utils import Executable
class TestRwandaFieldBoundary:
@pytest.fixture(params=['train', 'test'])
def dataset(
- self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+ self,
+ azcopy: Executable,
+ monkeypatch: MonkeyPatch,
+ tmp_path: Path,
+ request: SubRequest,
) -> RwandaFieldBoundary:
- radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
- monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
- monkeypatch.setattr(
- RwandaFieldBoundary, 'number_of_patches_per_split', {'train': 5, 'test': 5}
- )
- monkeypatch.setattr(
- RwandaFieldBoundary,
- 'md5s',
- {
- 'train_images': 'af9395e2e49deefebb35fa65fa378ba3',
- 'test_images': 'd104bb82323a39e7c3b3b7dd0156f550',
- 'train_labels': '6cceaf16a141cf73179253a783e7d51b',
- },
- )
+ url = os.path.join('tests', 'data', 'rwanda_field_boundary')
+ monkeypatch.setattr(RwandaFieldBoundary, 'url', url)
+ monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1})
root = str(tmp_path)
split = request.param
transforms = nn.Identity()
- return RwandaFieldBoundary(
- root, split, transforms=transforms, api_key='', download=True, checksum=True
- )
+ return RwandaFieldBoundary(root, split, transforms=transforms, download=True)
def test_getitem(self, dataset: RwandaFieldBoundary) -> None:
x = dataset[0]
@@ -69,23 +48,12 @@ def test_getitem(self, dataset: RwandaFieldBoundary) -> None:
assert 'mask' not in x
def test_len(self, dataset: RwandaFieldBoundary) -> None:
- assert len(dataset) == 5
+ assert len(dataset) == 1
def test_add(self, dataset: RwandaFieldBoundary) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
- assert len(ds) == 10
-
- def test_needs_extraction(self, tmp_path: Path) -> None:
- root = str(tmp_path)
- for fn in [
- 'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
- 'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
- 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
- ]:
- url = os.path.join('tests', 'data', 'rwanda_field_boundary', fn)
- shutil.copy(url, root)
- RwandaFieldBoundary(root, checksum=False)
+ assert len(ds) == 2
def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None:
RwandaFieldBoundary(root=dataset.root)
@@ -94,35 +62,8 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
RwandaFieldBoundary(str(tmp_path))
- def test_corrupted(self, tmp_path: Path) -> None:
- for fn in [
- 'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
- 'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
- 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
- ]:
- with open(os.path.join(tmp_path, fn), 'w') as f:
- f.write('bad')
- with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
- RwandaFieldBoundary(root=str(tmp_path), checksum=True)
-
- def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
- radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
- monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
- monkeypatch.setattr(
- RwandaFieldBoundary,
- 'md5s',
- {'train_images': 'bad', 'test_images': 'bad', 'train_labels': 'bad'},
- )
- root = str(tmp_path)
- with pytest.raises(RuntimeError, match='Dataset not found or corrupted.'):
- RwandaFieldBoundary(root, 'train', api_key='', download=True, checksum=True)
-
- def test_no_api_key(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match='Must provide an API key to download'):
- RwandaFieldBoundary(str(tmp_path), api_key=None, download=True)
-
def test_invalid_bands(self) -> None:
- with pytest.raises(ValueError, match='is an invalid band name.'):
+ with pytest.raises(AssertionError):
RwandaFieldBoundary(bands=('foo', 'bar'))
def test_plot(self, dataset: RwandaFieldBoundary) -> None:
diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py
index 9439e525ab3..07a496ea974 100644
--- a/torchgeo/datasets/rwanda_field_boundary.py
+++ b/torchgeo/datasets/rwanda_field_boundary.py
@@ -3,6 +3,7 @@
"""Rwanda Field Boundary Competition dataset."""
+import glob
import os
from collections.abc import Callable, Sequence
@@ -16,11 +17,11 @@
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
+from .utils import which
class RwandaFieldBoundary(NonGeoDataset):
- r"""Rwanda Field Boundary Competition dataset.
+ """Rwanda Field Boundary Competition dataset.
This dataset contains field boundaries for smallholder farms in eastern Rwanda.
The Nasa Harvest program funded a team of annotators from TaQadam to label Planet
@@ -46,40 +47,20 @@ class RwandaFieldBoundary(NonGeoDataset):
This dataset requires the following additional library to be installed:
- * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
- imagery and labels from the Radiant Earth MLHub
+ * `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
+ dataset from Source Cooperative.
.. versionadded:: 0.5
"""
- dataset_id = 'nasa_rwanda_field_boundary_competition'
- collection_ids = [
- 'nasa_rwanda_field_boundary_competition_source_train',
- 'nasa_rwanda_field_boundary_competition_labels_train',
- 'nasa_rwanda_field_boundary_competition_source_test',
- ]
- number_of_patches_per_split = {'train': 57, 'test': 13}
-
- filenames = {
- 'train_images': 'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
- 'test_images': 'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
- 'train_labels': 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
- }
- md5s = {
- 'train_images': '1f9ec08038218e67e11f82a86849b333',
- 'test_images': '17bb0e56eedde2e7a43c57aa908dc125',
- 'train_labels': '10e4eb761523c57b6d3bdf9394004f5f',
- }
+ url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
+ splits = {'train': 57, 'test': 13}
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
-
all_bands = ('B01', 'B02', 'B03', 'B04')
rgb_bands = ('B03', 'B02', 'B01')
-
classes = ['No field-boundary', 'Field-boundary']
- splits = ['train', 'test']
-
def __init__(
self,
root: str = 'data',
@@ -87,8 +68,6 @@ def __init__(
bands: Sequence[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
- api_key: str | None = None,
- checksum: bool = False,
) -> None:
"""Initialize a new RwandaFieldBoundary instance.
@@ -99,49 +78,29 @@ def __init__(
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
- api_key: a RadiantEarth MLHub API key to use for downloading the dataset
- checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
+ AssertionError: If *split* or *bands* are invalid.
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
- self._validate_bands(bands)
assert split in self.splits
- if download and api_key is None:
- raise RuntimeError('Must provide an API key to download the dataset')
+ assert set(bands) <= set(self.all_bands)
+
self.root = root
+ self.split = split
self.bands = bands
self.transforms = transforms
- self.split = split
self.download = download
- self.api_key = api_key
- self.checksum = checksum
+
self._verify()
- self.image_filenames: list[list[list[str]]] = []
- self.mask_filenames: list[str] = []
- for i in range(self.number_of_patches_per_split[split]):
- dates = []
- for date in self.dates:
- patch = []
- for band in self.bands:
- fn = os.path.join(
- self.root,
- f'nasa_rwanda_field_boundary_competition_source_{split}',
- f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501
- f'{band}.tif',
- )
- patch.append(fn)
- dates.append(patch)
- self.image_filenames.append(dates)
- self.mask_filenames.append(
- os.path.join(
- self.root,
- f'nasa_rwanda_field_boundary_competition_labels_{split}',
- f'nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}',
- 'raster_labels.tif',
- )
- )
+ def __len__(self) -> int:
+ """Return the number of chips in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return self.splits[self.split]
def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
@@ -150,83 +109,34 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
index: index to return
Returns:
- a dict containing image, mask, transform, crs, and metadata at index.
+ a dict containing image and mask at index.
"""
- img_fns = self.image_filenames[index]
- mask_fn = self.mask_filenames[index]
-
- imgs = []
- for date_fns in img_fns:
- bands = []
- for band_fn in date_fns:
- with rasterio.open(band_fn) as f:
- bands.append(f.read(1).astype(np.int32))
- imgs.append(bands)
- img = torch.from_numpy(np.array(imgs))
-
- sample = {'image': img}
+ images = []
+ for date in self.dates:
+ patches = []
+ for band in self.bands:
+ path = os.path.join(self.root, 'source', self.split, date)
+ with rasterio.open(os.path.join(path, f'{index:02}_{band}.tif')) as src:
+ patches.append(src.read(1).astype(np.float32))
+ images.append(patches)
+ sample = {'image': torch.from_numpy(np.array(images))}
if self.split == 'train':
- with rasterio.open(mask_fn) as f:
- mask = f.read(1)
- mask = torch.from_numpy(mask)
- sample['mask'] = mask
+ path = os.path.join(self.root, 'labels', self.split)
+ with rasterio.open(os.path.join(path, f'{index:02}.tif')) as src:
+ sample['mask'] = torch.from_numpy(src.read(1).astype(np.int64))
if self.transforms is not None:
sample = self.transforms(sample)
return sample
- def __len__(self) -> int:
- """Return the number of chips in the dataset.
-
- Returns:
- length of the dataset
- """
- return len(self.image_filenames)
-
- def _validate_bands(self, bands: Sequence[str]) -> None:
- """Validate list of bands.
-
- Args:
- bands: user-provided sequence of bands to load
-
- Raises:
- ValueError: if an invalid band name is provided
- """
- for band in bands:
- if band not in self.all_bands:
- raise ValueError(f"'{band}' is an invalid band name.")
-
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the subdirectories already exist and have the correct number of files
- checks = []
- for split, num_patches in self.number_of_patches_per_split.items():
- path = os.path.join(
- self.root, f'nasa_rwanda_field_boundary_competition_source_{split}'
- )
- if os.path.exists(path):
- num_files = len(os.listdir(path))
- # 6 dates + 1 collection.json file
- checks.append(num_files == (num_patches * 6) + 1)
- else:
- checks.append(False)
-
- if all(checks):
- return
-
- # Check if tar file already exists (if so then extract)
- have_all_files = True
- for group in ['train_images', 'train_labels', 'test_images']:
- filepath = os.path.join(self.root, self.filenames[group])
- if os.path.exists(filepath):
- if self.checksum and not check_integrity(filepath, self.md5s[group]):
- raise RuntimeError('Dataset found, but corrupted.')
- extract_archive(filepath)
- else:
- have_all_files = False
- if have_all_files:
+ path = os.path.join(self.root, 'source', self.split, '*', '*.tif')
+ expected = len(self.dates) * self.splits[self.split] * len(self.all_bands)
+ if len(glob.glob(path)) == expected:
return
# Check if the user requested to download the dataset
@@ -237,15 +147,10 @@ def _verify(self) -> None:
self._download()
def _download(self) -> None:
- """Download the dataset and extract it."""
- for collection_id in self.collection_ids:
- download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
-
- for group in ['train_images', 'train_labels', 'test_images']:
- filepath = os.path.join(self.root, self.filenames[group])
- if self.checksum and not check_integrity(filepath, self.md5s[group]):
- raise RuntimeError('Dataset not found or corrupted.')
- extract_archive(filepath, self.root)
+ """Download the dataset."""
+ os.makedirs(self.root, exist_ok=True)
+ azcopy = which('azcopy')
+ azcopy('sync', self.url, self.root, '--recursive=true')
def plot(
self,