WesternUSALiveFuelMoisture: radiant mlhub -> source coop (#2206)
* WesternUSALiveFuelMoisture: radiant mlhub -> source coop

* Finish updating dataset

* Recursive

* Update tests
adamjstewart authored Aug 5, 2024
1 parent 9a92290 commit 06aa33f
Showing 10 changed files with 33 additions and 164 deletions.
66 changes: 4 additions & 62 deletions tests/data/western_usa_live_fuel_moisture/data.py
@@ -3,10 +3,8 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 
-import hashlib
 import json
 import os
-import shutil
 
 NUM_SAMPLES = 3
 
@@ -159,65 +157,9 @@
     'geometry': {'type': 'Point', 'coordinates': [-115.8855556, 42.44111111]},
 }
 
-STAC = {
-    'assets': {
-        'documentation': {
-            'href': '../_common/documentation.pdf',
-            'type': 'application/pdf',
-        },
-        'labels': {'href': 'labels.geojson', 'type': 'application/geo+json'},
-        'training_features_descriptions': {
-            'href': '../_common/training_features_descriptions.csv',
-            'title': 'Training Features Descriptions',
-            'type': 'text/csv',
-        },
-    },
-    'bbox': [-115.8855556, 42.44111111, -115.8855556, 42.44111111],
-    'collection': 'su_sar_moisture_content',
-    'geometry': {'coordinates': [-115.8855556, 42.44111111], 'type': 'Point'},
-    'id': 'su_sar_moisture_content_0001',
-    'links': [
-        {'href': '../collection.json', 'rel': 'collection'},
-        {'href': '../collection.json', 'rel': 'parent'},
-    ],
-    'properties': {
-        'datetime': '2015-06-30T00:00:00Z',
-        'label:description': '',
-        'label:properties': ['percent(t)'],
-        'label:type': 'vector',
-    },
-    'stac_extensions': ['label'],
-    'stac_version': '1.0.0-beta.2',
-    'type': 'Feature',
-}
-
-
-def create_file(path: str) -> None:
-    label_path = os.path.join(path, 'labels.geojson')
-    with open(label_path, 'w') as f:
+os.makedirs(data_dir, exist_ok=True)
+for i in range(1, NUM_SAMPLES + 1):
+    filename = os.path.join(data_dir, f'feature_{i:04}.geojson')
+    with open(filename, 'w') as f:
         json.dump(LABELS, f)
-
-    stac_path = os.path.join(path, 'stac.json')
-    with open(stac_path, 'w') as f:
-        json.dump(STAC, f)
-
-
-if __name__ == '__main__':
-    # Remove old data
-    if os.path.isdir(data_dir):
-        shutil.rmtree(data_dir)
-
-    os.makedirs(os.path.join(os.getcwd(), data_dir))
-
-    for i in range(NUM_SAMPLES):
-        sample_dir = os.path.join(data_dir, data_dir + f'_{i}')
-        os.makedirs(sample_dir)
-        create_file(sample_dir)
-
-    # Compress data
-    shutil.make_archive(data_dir, 'gztar', '.', data_dir)
-
-    # Compute checksums
-    with open(data_dir + '.tar.gz', 'rb') as f:
-        md5 = hashlib.md5(f.read()).hexdigest()
-        print(f'{data_dir}.tar.gz: {md5}')
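
Taken together, the added lines leave a much smaller generator: the script now writes the feature_*.geojson files directly into the collection directory, with no STAC metadata, tarball, or checksum step. A runnable sketch of the resulting script, with the (truncated) LABELS dict stubbed by a minimal placeholder feature and data_dir assumed to keep its previous value, which is defined outside the visible hunks:

import json
import os

NUM_SAMPLES = 3
data_dir = 'su_sar_moisture_content'  # assumed value; not shown in this diff

# Minimal placeholder for the full LABELS feature truncated above.
LABELS = {
    'type': 'Feature',
    'properties': {'percent(t)': 100.0},
    'geometry': {'type': 'Point', 'coordinates': [-115.8855556, 42.44111111]},
}

os.makedirs(data_dir, exist_ok=True)
for i in range(1, NUM_SAMPLES + 1):
    # feature_0001.geojson, feature_0002.geojson, ... match the dataset's
    # recursive glob for 'feature_*.geojson'.
    filename = os.path.join(data_dir, f'feature_{i:04}.geojson')
    with open(filename, 'w') as f:
        json.dump(LABELS, f)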
Binary file not shown.

This file was deleted.

This file was deleted.

This file was deleted.

49 changes: 9 additions & 40 deletions tests/datasets/test_western_usa_live_fuel_moisture.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT License.
 
 import os
-import shutil
 from pathlib import Path
 
 import pytest
@@ -11,63 +10,33 @@
 from pytest import MonkeyPatch
 
 from torchgeo.datasets import DatasetNotFoundError, WesternUSALiveFuelMoisture
-
-
-class Collection:
-    def download(self, output_dir: str, **kwargs: str) -> None:
-        tarball_path = os.path.join(
-            'tests',
-            'data',
-            'western_usa_live_fuel_moisture',
-            'su_sar_moisture_content.tar.gz',
-        )
-        shutil.copy(tarball_path, output_dir)
-
-
-def fetch(collection_id: str, **kwargs: str) -> Collection:
-    return Collection()
+from torchgeo.datasets.utils import Executable
 
 
 class TestWesternUSALiveFuelMoisture:
     @pytest.fixture
     def dataset(
-        self, monkeypatch: MonkeyPatch, tmp_path: Path
+        self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> WesternUSALiveFuelMoisture:
-        radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
-        monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
-        md5 = 'ecbc9269dd27c4efe7aa887960054351'
-        monkeypatch.setattr(WesternUSALiveFuelMoisture, 'md5', md5)
-        root = tmp_path
+        url = os.path.join('tests', 'data', 'western_usa_live_fuel_moisture')
+        monkeypatch.setattr(WesternUSALiveFuelMoisture, 'url', url)
         transforms = nn.Identity()
         return WesternUSALiveFuelMoisture(
-            root, transforms=transforms, download=True, api_key='', checksum=True
+            tmp_path, transforms=transforms, download=True
         )
 
-    @pytest.mark.parametrize('index', [0, 1, 2])
-    def test_getitem(self, dataset: WesternUSALiveFuelMoisture, index: int) -> None:
-        x = dataset[index]
+    def test_getitem(self, dataset: WesternUSALiveFuelMoisture) -> None:
+        x = dataset[0]
         assert isinstance(x, dict)
         assert isinstance(x['input'], torch.Tensor)
         assert isinstance(x['label'], torch.Tensor)
 
     def test_len(self, dataset: WesternUSALiveFuelMoisture) -> None:
         assert len(dataset) == 3
 
-    def test_already_downloaded(self, tmp_path: Path) -> None:
-        pathname = os.path.join(
-            'tests',
-            'data',
-            'western_usa_live_fuel_moisture',
-            'su_sar_moisture_content.tar.gz',
-        )
-        root = tmp_path
-        shutil.copy(pathname, root)
-        WesternUSALiveFuelMoisture(root)
+    def test_already_downloaded(self, dataset: WesternUSALiveFuelMoisture) -> None:
+        WesternUSALiveFuelMoisture(dataset.root)
 
     def test_not_downloaded(self, tmp_path: Path) -> None:
         with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
             WesternUSALiveFuelMoisture(tmp_path)
 
     def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None:
         with pytest.raises(AssertionError, match='Invalid input variable name.'):
             WesternUSALiveFuelMoisture(dataset.root, input_features=['foo'])
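
The azcopy fixture requested by the new dataset fixture comes from torchgeo's shared test conftest, which is not part of this diff. Purely to illustrate the pattern, a hypothetical stand-in could intercept the dataset module's which('azcopy') lookup and turn the sync call into a local directory copy; FakeAzCopy and the patching strategy below are assumptions, not the real conftest code:

import shutil

import pytest
from pytest import MonkeyPatch

import torchgeo.datasets.western_usa_live_fuel_moisture as wuslfm


class FakeAzCopy:
    """Hypothetical stand-in for the azcopy CLI."""

    def __call__(self, *args: str) -> None:
        # Mimic `azcopy sync <src> <dst> --recursive=true` with a local copy.
        if args and args[0] == 'sync':
            src, dst = args[1], args[2]
            shutil.copytree(src, dst, dirs_exist_ok=True)


@pytest.fixture
def azcopy(monkeypatch: MonkeyPatch) -> FakeAzCopy:
    fake = FakeAzCopy()
    # The dataset module binds `which` at import time, so patch that binding;
    # its _download() then receives the fake instead of the real CLI.
    monkeypatch.setattr(wuslfm, 'which', lambda name: fake)
    return fake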
79 changes: 20 additions & 59 deletions torchgeo/datasets/western_usa_live_fuel_moisture.py
@@ -11,11 +11,10 @@
 
 import pandas as pd
 import torch
-from torch import Tensor
 
 from .errors import DatasetNotFoundError
 from .geo import NonGeoDataset
-from .utils import Path, download_radiant_mlhub_collection, extract_archive
+from .utils import Path, which
 
 
 class WesternUSALiveFuelMoisture(NonGeoDataset):
@@ -25,7 +24,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
     (mass of water in vegetation) and remotely sensed variables
     in the western United States. It contains 2615 datapoints and 138
     variables. For more details see the
-    `dataset page <https://mlhub.earth/data/su_sar_moisture_content_main>`_.
+    `dataset page <https://beta.source.coop/stanford/sar-moisture-conent/>`_.
 
     Dataset Format:
@@ -44,15 +43,13 @@ class WesternUSALiveFuelMoisture(NonGeoDataset):
     This dataset requires the following additional library to be installed:
 
-    * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
-      imagery and labels from the Radiant Earth MLHub
+    * `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
+      dataset from Source Cooperative.
 
     .. versionadded:: 0.5
     """
 
-    collection_id = 'su_sar_moisture_content'
-
-    md5 = 'a6c0721f06a3a0110b7d1243b18614f0'
+    url = 'https://radiantearth.blob.core.windows.net/mlhub/su-sar-moisture-content'
 
     label_name = 'percent(t)'
@@ -204,8 +201,6 @@ def __init__(
         input_features: list[str] = all_variable_names,
         transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
         download: bool = False,
-        api_key: str | None = None,
-        checksum: bool = False,
     ) -> None:
         """Initialize a new Western USA Live Fuel Moisture Dataset.
@@ -215,42 +210,22 @@ def __init__(
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
-            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-            checksum: if True, check the MD5 of the downloaded files (may be slow)
 
         Raises:
             AssertionError: if ``input_features`` contains invalid variable names
             DatasetNotFoundError: If dataset is not found and *download* is False.
         """
         super().__init__()
 
+        assert set(input_features) <= set(self.all_variable_names)
+
         self.root = root
+        self.input_features = input_features
         self.transforms = transforms
-        self.checksum = checksum
         self.download = download
-        self.api_key = api_key
 
         self._verify()
 
-        assert all(
-            input in self.all_variable_names for input in input_features
-        ), 'Invalid input variable name.'
-        self.input_features = input_features
-
-        self.collection = self._retrieve_collection()
-
         self.dataframe = self._load_data()
 
-    def _retrieve_collection(self) -> list[str]:
-        """Retrieve dataset collection that maps samples to paths.
-
-        Returns:
-            list of sample paths
-        """
-        return glob.glob(
-            os.path.join(self.root, self.collection_id, '**', 'labels.geojson')
-        )
-
     def __len__(self) -> int:
         """Return the number of data points in the dataset.
@@ -270,7 +245,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
         """
         data = self.dataframe.iloc[index, :]
 
-        sample: dict[str, Tensor] = {
+        sample = {
             'input': torch.tensor(
                 data.drop([self.label_name]).values, dtype=torch.float32
             ),
@@ -289,29 +264,24 @@ def _load_data(self) -> pd.DataFrame:
             the features and label
         """
         data_rows = []
-        for path in self.collection:
+        for path in sorted(self.files):
             with open(path) as f:
                 content = json.load(f)
                 data_dict = content['properties']
                 data_dict['lon'] = content['geometry']['coordinates'][0]
                 data_dict['lat'] = content['geometry']['coordinates'][1]
                 data_rows.append(data_dict)
 
-        df: pd.DataFrame = pd.DataFrame(data_rows)
+        df = pd.DataFrame(data_rows)
         df = df[self.input_features + [self.label_name]]
         return df
 
     def _verify(self) -> None:
         """Verify the integrity of the dataset."""
-        # Check if the extracted files already exist
-        pathname = os.path.join(self.root, self.collection_id)
-        if os.path.exists(pathname):
-            return
-
-        # Check if the zip files have already been downloaded
-        pathname = os.path.join(self.root, self.collection_id) + '.tar.gz'
-        if os.path.exists(pathname):
-            self._extract()
+        # Check if the files already exist
+        file_glob = os.path.join(self.root, '**', 'feature_*.geojson')
+        self.files = glob.glob(file_glob, recursive=True)
+        if self.files:
             return
 
         # Check if the user requested to download the dataset
@@ -320,19 +290,10 @@ def _verify(self) -> None:
 
         # Download the dataset
         self._download()
-        self._extract()
-
-    def _extract(self) -> None:
-        """Extract the dataset."""
-        pathname = os.path.join(self.root, self.collection_id) + '.tar.gz'
-        extract_archive(pathname, self.root)
+        self.files = glob.glob(file_glob, recursive=True)
 
-    def _download(self, api_key: str | None = None) -> None:
-        """Download the dataset and extract it.
-
-        Args:
-            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-        """
-        download_radiant_mlhub_collection(self.collection_id, self.root, api_key)
-        filename = os.path.join(self.root, self.collection_id) + '.tar.gz'
-        extract_archive(filename, self.root)
+    def _download(self) -> None:
+        """Download the dataset and extract it."""
+        os.makedirs(self.root, exist_ok=True)
+        azcopy = which('azcopy')
+        azcopy('sync', self.url, self.root, '--recursive=true')

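For end users the switch is mostly invisible apart from the removed api_key and checksum arguments: pass download=True and have the azcopy CLI on PATH. A minimal usage sketch; the 'data' root is illustrative:

from torchgeo.datasets import WesternUSALiveFuelMoisture

# First call downloads the GeoJSON files via `azcopy sync`; later calls
# find them on disk through the recursive glob and skip the download.
ds = WesternUSALiveFuelMoisture('data', download=True)

print(len(ds))                # number of feature_*.geojson samples found
sample = ds[0]
print(sample['input'].shape)  # remotely sensed feature vector
print(sample['label'])        # live fuel moisture content, percent(t)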