Skip to content

Commit

Permalink
South Africa Crop Type: add download support (#2181)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Jul 26, 2024
1 parent 811b3d9 commit 7ca1342
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 9 deletions.
17 changes: 14 additions & 3 deletions tests/datasets/test_south_africa_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
Expand All @@ -19,15 +20,25 @@
SouthAfricaCropType,
UnionDataset,
)
from torchgeo.datasets.utils import Executable


class TestSouthAfricaCropType:
@pytest.fixture(params=[SouthAfricaCropType.s1_bands, SouthAfricaCropType.s2_bands])
def dataset(self, request: SubRequest) -> SouthAfricaCropType:
path = os.path.join('tests', 'data', 'south_africa_crop_type')
def dataset(
self,
request: SubRequest,
azcopy: Executable,
monkeypatch: MonkeyPatch,
tmp_path: Path,
) -> SouthAfricaCropType:
url = os.path.join('tests', 'data', 'south_africa_crop_type')
monkeypatch.setattr(SouthAfricaCropType, 'url', url)
bands = request.param
transforms = nn.Identity()
return SouthAfricaCropType(path, bands=bands, transforms=transforms)
return SouthAfricaCropType(
tmp_path, bands=bands, transforms=transforms, download=True
)

def test_getitem(self, dataset: SouthAfricaCropType) -> None:
x = dataset[dataset.bounds]
Expand Down
44 changes: 38 additions & 6 deletions torchgeo/datasets/south_africa_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from rasterio.crs import CRS
from torch import Tensor

from .errors import RGBBandsMissingError
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import RasterDataset
from .utils import BoundingBox, Path
from .utils import BoundingBox, Path, which


class SouthAfricaCropType(RasterDataset):
Expand Down Expand Up @@ -60,9 +60,17 @@ class SouthAfricaCropType(RasterDataset):
"Crop Type Classification Dataset for Western Cape, South Africa",
Version 1.0, Radiant MLHub, https://doi.org/10.34911/rdnt.j0co8q
.. note::
This dataset requires the following additional command-line tool to be installed:
* `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
dataset from Source Cooperative.
.. versionadded:: 0.6
"""

url = 'https://radiantearth.blob.core.windows.net/mlhub/ref-south-africa-crops-competition-v1'

filename_glob = '*_07_*_{}_10m.*'
filename_regex = r"""
^(?P<field_id>\d+)
Expand Down Expand Up @@ -108,6 +116,7 @@ def __init__(
classes: list[int] = list(cmap.keys()),
bands: list[str] = s2_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
) -> None:
"""Initialize a new South Africa Crop Type dataset instance.
Expand All @@ -118,6 +127,7 @@ def __init__(
bands: the subset of bands to load
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the *paths* directory
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
Expand All @@ -128,15 +138,17 @@ def __init__(
assert 0 in classes, 'Classes must include the background class: 0'

self.paths = paths
self.classes = classes
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
self.download = download
self.filename_glob = self.filename_glob.format(bands[0])

self._verify()

super().__init__(paths=paths, crs=crs, bands=bands, transforms=transforms)

# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
self.ordinal_cmap = torch.zeros((len(classes), 4), dtype=torch.uint8)
for v, k in enumerate(classes):
self.ordinal_map[k] = v
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])

Expand Down Expand Up @@ -226,6 +238,26 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:

return sample

def _verify(self) -> None:
    """Ensure the dataset is available, downloading it when permitted.

    Does nothing when ``self.files`` is non-empty. Otherwise the dataset is
    downloaded if ``self.download`` is set.

    Raises:
        DatasetNotFoundError: If no files are found and ``self.download``
            is falsy.
    """
    if not self.files:
        # No matching files were found: fetch them or report them missing.
        if self.download:
            self._download()
        else:
            raise DatasetNotFoundError(self)

def _download(self) -> None:
    """Sync the remote dataset at ``self.url`` into ``self.paths`` via azcopy.

    The target directory is created first if it does not already exist.
    """
    target = self.paths
    # NOTE(review): presumably this narrows ``paths`` to a single directory
    # for the type checker before it is used as a download target -- confirm.
    assert isinstance(target, str | pathlib.Path)
    os.makedirs(target, exist_ok=True)

    # Resolve the azcopy executable only after the directory exists, then
    # mirror the remote container recursively into the target directory.
    sync_tool = which('azcopy')
    sync_tool('sync', f'{self.url}', target, '--recursive=true')

def plot(
self,
sample: dict[str, Tensor],
Expand Down

0 comments on commit 7ca1342

Please sign in to comment.