Skip to content

Commit

Permalink
NASA Marine Debris: radiant mlhub -> source coop (#2183)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Jul 27, 2024
1 parent f6a49b8 commit 900e8a1
Show file tree
Hide file tree
Showing 23 changed files with 102 additions and 173 deletions.
58 changes: 58 additions & 0 deletions tests/data/nasa_marine_debris/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
import rasterio as rio
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 32
DTYPE = np.uint8

np.random.seed(0)

profile = {
'driver': 'GTiff',
'dtype': DTYPE,
'width': SIZE,
'height': SIZE,
'count': 3,
'crs': CRS.from_epsg(4326),
'transform': Affine(
2.1457672119140625e-05,
0.0,
-87.626953125,
0.0,
-2.0629065249348766e-05,
15.977172621632805,
),
}

os.makedirs('source', exist_ok=True)
os.makedirs('labels', exist_ok=True)

files = [
'20160928_153233_0e16_16816-29821-16',
'20160928_153233_0e16_16816-29824-16',
'20160928_153233_0e16_16816-29825-16',
'20160928_153233_0e16_16816-29828-16',
'20160928_153233_0e16_16816-29829-16',
]
for file in files:
with rio.open(os.path.join('source', f'{file}.tif'), 'w', **profile) as f:
for i in range(1, 4):
Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE)
f.write(Z, i)

count = np.random.randint(5)
x = np.random.randint(SIZE, size=count)
y = np.random.randint(SIZE, size=count)
dx = np.random.randint(5, size=count)
dy = np.random.randint(5, size=count)
label = np.ones(count)
Z = np.stack([x, y, x + dx, y + dy, label], axis=-1)
np.save(os.path.join('labels', f'{file}.npy'), Z)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
69 changes: 10 additions & 59 deletions tests/datasets/test_nasa_marine_debris.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
Expand All @@ -13,41 +11,18 @@
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, NASAMarineDebris


class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join('tests', 'data', 'nasa_marine_debris', '*.tar.gz')
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)


def fetch(collection_id: str, **kwargs: str) -> Collection:
return Collection()


class Collection_corrupted:
def download(self, output_dir: str, **kwargs: str) -> None:
filenames = NASAMarineDebris.filenames
for filename in filenames:
with open(os.path.join(output_dir, filename), 'w') as f:
f.write('bad')


def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted:
return Collection_corrupted()
from torchgeo.datasets.utils import Executable


class TestNASAMarineDebris:
@pytest.fixture()
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris:
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
md5s = ['6f4f0d2313323950e45bf3fc0c09b5de', '540cf1cf4fd2c13b609d0355abe955d7']
monkeypatch.setattr(NASAMarineDebris, 'md5s', md5s)
root = tmp_path
@pytest.fixture
def dataset(
self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
) -> NASAMarineDebris:
url = os.path.join('tests', 'data', 'nasa_marine_debris')
monkeypatch.setattr(NASAMarineDebris, 'url', url)
transforms = nn.Identity()
return NASAMarineDebris(root, transforms, download=True, checksum=True)
return NASAMarineDebris(tmp_path, transforms, download=True)

def test_getitem(self, dataset: NASAMarineDebris) -> None:
x = dataset[0]
Expand All @@ -58,36 +33,12 @@ def test_getitem(self, dataset: NASAMarineDebris) -> None:
assert x['boxes'].shape[-1] == 4

def test_len(self, dataset: NASAMarineDebris) -> None:
assert len(dataset) == 4
assert len(dataset) == 5

def test_already_downloaded(
self, dataset: NASAMarineDebris, tmp_path: Path
) -> None:
NASAMarineDebris(root=tmp_path, download=True)

def test_already_downloaded_not_extracted(
self, dataset: NASAMarineDebris, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
os.makedirs(tmp_path, exist_ok=True)
Collection().download(output_dir=str(tmp_path))
NASAMarineDebris(root=tmp_path, download=False)

def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None:
filenames = NASAMarineDebris.filenames
for filename in filenames:
with open(os.path.join(tmp_path, filename), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
NASAMarineDebris(root=tmp_path, download=False, checksum=True)

def test_corrupted_new_download(
self, tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_corrupted)
NASAMarineDebris(root=tmp_path, download=True, checksum=True)
NASAMarineDebris(tmp_path, download=True)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Expand Down
148 changes: 34 additions & 114 deletions torchgeo/datasets/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""NASA Marine Debris dataset."""

import glob
import os
from collections.abc import Callable

Expand All @@ -16,18 +17,13 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import (
Path,
check_integrity,
download_radiant_mlhub_collection,
extract_archive,
)
from .utils import Path, which


class NASAMarineDebris(NonGeoDataset):
"""NASA Marine Debris dataset.
The `NASA Marine Debris <https://mlhub.earth/data/nasa_marine_debris>`__
The `NASA Marine Debris <https://beta.source.coop/repositories/nasa/marine-debris/>`__
dataset is a dataset for detection of floating marine debris in satellite imagery.
Dataset features:
Expand All @@ -52,26 +48,19 @@ class NASAMarineDebris(NonGeoDataset):
This dataset requires the following additional library to be installed:
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
imagery and labels from the Radiant Earth MLHub
* `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
dataset from Source Cooperative.
.. versionadded:: 0.2
"""

collection_ids = ['nasa_marine_debris_source', 'nasa_marine_debris_labels']
directories = ['nasa_marine_debris_source', 'nasa_marine_debris_labels']
filenames = ['nasa_marine_debris_source.tar.gz', 'nasa_marine_debris_labels.tar.gz']
md5s = ['fe8698d1e68b3f24f0b86b04419a797d', 'd8084f5a72778349e07ac90ec1e1d990']
class_label = 'marine_debris'
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa-marine-debris'

def __init__(
self,
root: Path = 'data',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
api_key: str | None = None,
checksum: bool = False,
verbose: bool = False,
) -> None:
"""Initialize a new NASA Marine Debris Dataset instance.
Expand All @@ -80,21 +69,18 @@ def __init__(
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
checksum: if True, check the MD5 of the downloaded files (may be slow)
verbose: if True, print messages when new tiles are loaded
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.transforms = transforms
self.download = download
self.api_key = api_key
self.checksum = checksum
self.verbose = verbose

self._verify()
self.files = self._load_files()

self.source = sorted(glob.glob(os.path.join(self.root, 'source', '*.tif')))
self.labels = sorted(glob.glob(os.path.join(self.root, 'labels', '*.npy')))

def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
Expand All @@ -105,15 +91,21 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
Returns:
data and labels at that index
"""
image = self._load_image(self.files[index]['image'])
boxes = self._load_target(self.files[index]['target'])
sample = {'image': image, 'boxes': boxes}
with rasterio.open(self.source[index]) as source:
image = torch.from_numpy(source.read()).float()

labels = np.load(self.labels[index])

# Boxes contain unnecessary value of 1 after xyxy coords
boxes = torch.from_numpy(labels[:, :4])

# Filter invalid boxes
w_check = (sample['boxes'][:, 2] - sample['boxes'][:, 0]) > 0
h_check = (sample['boxes'][:, 3] - sample['boxes'][:, 1]) > 0
w_check = (boxes[:, 2] - boxes[:, 0]) > 0
h_check = (boxes[:, 3] - boxes[:, 1]) > 0
indices = w_check & h_check
sample['boxes'] = sample['boxes'][indices]
boxes = boxes[indices]

sample = {'image': image, 'boxes': boxes}

if self.transforms is not None:
sample = self.transforms(sample)
Expand All @@ -126,100 +118,28 @@ def __len__(self) -> int:
Returns:
length of the dataset
"""
return len(self.files)

def _load_image(self, path: Path) -> Tensor:
"""Load a single image.
Args:
path: path to the image
Returns:
the image
"""
with rasterio.open(path) as f:
array = f.read()
tensor = torch.from_numpy(array).float()
return tensor

def _load_target(self, path: Path) -> Tensor:
"""Load the target bounding boxes for a single image.
Args:
path: path to the labels
Returns:
the target boxes
"""
array = np.load(path)
# boxes contain unecessary value of 1 after xyxy coords
array = array[:, :4]
tensor = torch.from_numpy(array)
return tensor

def _load_files(self) -> list[dict[str, str]]:
"""Load a image and label files.
Returns:
list of dicts containing image and label files
"""
image_root = os.path.join(self.root, self.directories[0])
target_root = os.path.join(self.root, self.directories[1])
image_folders = sorted(
f for f in os.listdir(image_root) if not f.endswith('json')
)

files = []
for folder in image_folders:
files.append(
{
'image': os.path.join(image_root, folder, 'image_geotiff.tif'),
'target': os.path.join(
target_root,
folder.replace('source', 'labels'),
'pixel_bounds.npy',
),
}
)
return files
return len(self.source)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the files already exist
exists = [
os.path.exists(os.path.join(self.root, directory))
for directory in self.directories
]
if all(exists):
return

# Check if zip file already exists (if so then extract)
exists = []
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if os.path.exists(filepath):
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError('Dataset checksum mismatch.')
exists.append(True)
extract_archive(filepath)
else:
exists.append(False)

# Check if the directories already exist
dirs = ['source', 'labels']
exists = [os.path.exists(os.path.join(self.root, d)) for d in dirs]
if all(exists):
return

# Check if the user requested to download the dataset
if not self.download:
raise DatasetNotFoundError(self)

# Download and extract the dataset
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError('Dataset checksum mismatch.')
extract_archive(filepath)
# Download the dataset
self._download()

def _download(self) -> None:
"""Download the dataset."""
os.makedirs(self.root, exist_ok=True)
azcopy = which('azcopy')
azcopy('sync', self.url, self.root, '--recursive=true')

def plot(
self,
Expand Down

0 comments on commit 900e8a1

Please sign in to comment.