I/O Bench: add new dataset #1972

Merged: 19 commits, Apr 19, 2024
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
@@ -31,6 +31,11 @@ NAIP

.. autoclass:: NAIPChesapeakeDataModule

I/O Bench
^^^^^^^^^

.. autoclass:: IOBenchDataModule

Sentinel
^^^^^^^^

7 changes: 6 additions & 1 deletion docs/api/datasets.rst
@@ -108,6 +108,11 @@ iNaturalist

.. autoclass:: INaturalist

I/O Bench
^^^^^^^^^

.. autoclass:: IOBench

L7 Irish
^^^^^^^^

@@ -176,7 +181,7 @@ South Africa Crop Type

.. autoclass:: SouthAfricaCropType

South America Soybean
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: SouthAmericaSoybean
1 change: 1 addition & 0 deletions docs/api/geo_datasets.csv
@@ -15,6 +15,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`GBIF`_,Points,Citizen Scientists,"CC0-1.0 OR CC-BY-4.0 OR CC-BY-NC-4.0",-,-
`GlobBiomass`_,Masks,Landsat,"CC-BY-4.0","45,000x45,000",100
`iNaturalist`_,Points,Citizen Scientists,-,-,-
`I/O Bench`_,"Imagery, Masks",Landsat,"CC-BY-4.0","8,000x8,000",30
`L7 Irish`_,"Imagery, Masks",Landsat,"CC0-1.0","8,400x7,500","15, 30"
`L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30"
`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5
21 changes: 21 additions & 0 deletions docs/user/contributing.rst
@@ -166,3 +166,24 @@ A major component of TorchGeo is the large collection of :mod:`torchgeo.datasets
* Add the dataset metadata to either ``docs/api/geo_datasets.csv`` or ``docs/api/non_geo_datasets.csv``

A good way to get started is by looking at some of the existing implementations that are most closely related to the dataset that you are implementing (e.g. if you are implementing a semantic segmentation dataset, looking at the LandCover.ai dataset implementation would be a good starting point).

I/O Benchmarking
----------------

For PRs that may affect ``GeoDataset`` sampling speed, you can measure the performance impact as follows. On the main branch (before) and on your PR branch (after), run the following commands:

.. code-block:: console

   $ python -m torchgeo fit --config tests/conf/io_raw.yaml
   $ python -m torchgeo fit --config tests/conf/io_preprocessed.yaml

This will download a small (~1 GB) dataset consisting of a single Landsat 9 scene and a matching CDL file, then profile how quickly the various samplers iterate over both raw data (the original downloaded files) and preprocessed data (same CRS and resolution, target-aligned pixels, cloud-optimized GeoTIFFs). The important output to look for is the total time taken by ``train_dataloader_next`` (RandomGeoSampler) and ``val_next`` (GridGeoSampler). With these numbers, you can add a table to your PR like:

====== ============ ========== ===================== ===================
       raw (random) raw (grid) preprocessed (random) preprocessed (grid)
====== ============ ========== ===================== ===================
before 17.223       10.974     15.685                4.6075
after  17.360       11.032     9.613                 4.6673
====== ============ ========== ===================== ===================

In this example, we see a 60% speed-up for RandomGeoSampler on preprocessed data. All other numbers are more or less the same across multiple runs.
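As a sanity check on the quoted figure, the percent speed-up for any before/after pair in the table can be computed with a small helper (hypothetical; not part of this PR):

```python
def speedup_pct(before_s: float, after_s: float) -> float:
    """Percent speed-up when a step that took before_s seconds now takes after_s."""
    return (before_s / after_s - 1.0) * 100.0


# preprocessed (random), before vs. after, from the table above
print(round(speedup_pct(15.685, 9.613)))  # 63, i.e. roughly the 60% quoted above
```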
13 changes: 13 additions & 0 deletions tests/conf/io_preprocessed.yaml
@@ -0,0 +1,13 @@
model:
  class_path: IOBenchTask
data:
  class_path: IOBenchDataModule
  dict_kwargs:
    root: "data/io"
    split: "preprocessed"
    download: true
    checksum: true
trainer:
  max_epochs: 1
  num_sanity_val_steps: 0
  profiler: "simple"
13 changes: 13 additions & 0 deletions tests/conf/io_raw.yaml
@@ -0,0 +1,13 @@
model:
  class_path: IOBenchTask
data:
  class_path: IOBenchDataModule
  dict_kwargs:
    root: "data/io"
    split: "raw"
    download: true
    checksum: true
trainer:
  max_epochs: 1
  num_sanity_val_steps: 0
  profiler: "simple"
9 changes: 9 additions & 0 deletions tests/conf/iobench.yaml
@@ -0,0 +1,9 @@
model:
  class_path: IOBenchTask
data:
  class_path: IOBenchDataModule
  init_args:
    batch_size: 2
    patch_size: 16
  dict_kwargs:
    root: "tests/data/iobench"
56 changes: 56 additions & 0 deletions tests/data/iobench/data.py
@@ -0,0 +1,56 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 16

np.random.seed(0)


def create_file(path: str, dtype: str) -> None:
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "width": SIZE,
        "height": SIZE,
        "count": 1,
        "crs": CRS.from_epsg(32616),
        "transform": Affine(30.0, 0.0, 229800.0, 0.0, -30.0, 4585230.0),
        "compress": "lzw",
        "predictor": 2,
    }

    Z = np.random.randint(size=(SIZE, SIZE), low=0, high=np.iinfo(dtype).max)

    with rasterio.open(path, "w", **profile) as src:
        src.write(Z, 1)


bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "QA_AEROSOL"]

os.makedirs(os.path.join("preprocessed", "cdl"), exist_ok=True)
os.makedirs(os.path.join("preprocessed", "landsat"), exist_ok=True)

create_file(os.path.join("preprocessed", "cdl", "2023_30m_cdls.tif"), "uint8")
for band in bands:
    path = f"LC09_L2SP_023032_20230620_20230622_02_T1_SR_{band}.TIF"
    path = os.path.join("preprocessed", "landsat", path)
    create_file(path, "uint16")

# Compress data
shutil.make_archive("preprocessed", "gztar", ".", "preprocessed")

# Compute checksums
with open("preprocessed.tar.gz", "rb") as f:
    md5 = hashlib.md5(f.read()).hexdigest()
    print(md5)
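The checksum step above reads the whole archive into memory, which is fine for this tiny test fixture. For larger archives it could be streamed in chunks instead; a sketch of that variant (an assumption on my part, not what the PR ships):

```python
import hashlib


def md5_chunked(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through MD5 in 1 MiB chunks rather than reading it whole."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()
```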
Binary file added tests/data/iobench/preprocessed.tar.gz
92 changes: 92 additions & 0 deletions tests/datasets/test_iobench.py
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from pytest import MonkeyPatch
from rasterio.crs import CRS

import torchgeo.datasets.utils
from torchgeo.datasets import (
    BoundingBox,
    DatasetNotFoundError,
    IntersectionDataset,
    IOBench,
    RGBBandsMissingError,
    UnionDataset,
)


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestIOBench:
    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> IOBench:
        monkeypatch.setattr(torchgeo.datasets.iobench, "download_url", download_url)
        md5 = "e82398add7c35896a31c4398c608ef83"
        url = os.path.join("tests", "data", "iobench", "{}.tar.gz")
        monkeypatch.setattr(IOBench, "url", url)
        monkeypatch.setitem(IOBench.md5s, "preprocessed", md5)
        root = str(tmp_path)
        transforms = nn.Identity()
        return IOBench(root, transforms=transforms, download=True, checksum=True)

    def test_getitem(self, dataset: IOBench) -> None:
        x = dataset[dataset.bounds]
        assert isinstance(x, dict)
        assert isinstance(x["crs"], CRS)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)

    def test_and(self, dataset: IOBench) -> None:
        ds = dataset & dataset
        assert isinstance(ds, IntersectionDataset)

    def test_or(self, dataset: IOBench) -> None:
        ds = dataset | dataset
        assert isinstance(ds, UnionDataset)

    def test_plot(self, dataset: IOBench) -> None:
        x = dataset[dataset.bounds]
        dataset.plot(x, suptitle="Test")
        plt.close()

    def test_already_extracted(self, dataset: IOBench) -> None:
        IOBench(dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        pathname = os.path.join("tests", "data", "iobench", "*.tar.gz")
        root = str(tmp_path)
        for tarfile in glob.iglob(pathname):
            shutil.copy(tarfile, root)
        IOBench(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            IOBench(str(tmp_path))

    def test_invalid_query(self, dataset: IOBench) -> None:
        query = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[query]

    def test_rgb_bands_absent_plot(self, dataset: IOBench) -> None:
        with pytest.raises(
            RGBBandsMissingError, match="Dataset does not contain some of the RGB bands"
        ):
            ds = IOBench(dataset.root, bands=["SR_B1", "SR_B2", "SR_B3"])
            x = ds[ds.bounds]
            ds.plot(x, suptitle="Test")
            plt.close()
38 changes: 38 additions & 0 deletions tests/trainers/test_iobench.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest

from torchgeo.datamodules import MisconfigurationException
from torchgeo.main import main


class TestClassificationTask:
    @pytest.mark.parametrize("name", ["iobench"])
    def test_trainer(self, name: str, fast_dev_run: bool) -> None:
        config = os.path.join("tests", "conf", name + ".yaml")

        args = [
            "--config",
            config,
            "--trainer.accelerator",
            "cpu",
            "--trainer.fast_dev_run",
            str(fast_dev_run),
            "--trainer.max_epochs",
            "1",
            "--trainer.log_every_n_steps",
            "1",
        ]

        main(["fit"] + args)
        try:
            main(["test"] + args)
        except MisconfigurationException:
            pass
        try:
            main(["predict"] + args)
        except MisconfigurationException:
            pass
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
@@ -17,6 +17,7 @@
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .iobench import IOBenchDataModule
from .l7irish import L7IrishDataModule
from .l8biome import L8BiomeDataModule
from .landcoverai import LandCoverAIDataModule
@@ -50,6 +51,7 @@
# GeoDataset
"AgriFieldNetDataModule",
"ChesapeakeCVPRDataModule",
"IOBenchDataModule",
"L7IrishDataModule",
"L8BiomeDataModule",
"NAIPChesapeakeDataModule",
65 changes: 65 additions & 0 deletions torchgeo/datamodules/iobench.py
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""I/O benchmark datamodule."""

from typing import Any

from ..datasets import IOBench
from ..samplers import GridGeoSampler, RandomGeoSampler
from .geo import GeoDataModule


class IOBenchDataModule(GeoDataModule):
    """LightningDataModule implementation for the I/O benchmark dataset.

    .. versionadded:: 0.6
    """

    def __init__(
        self,
        batch_size: int = 32,
        patch_size: int | tuple[int, int] = 256,
        length: int | None = None,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new IOBenchDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.IOBench`.
        """
        super().__init__(
            IOBench,
            batch_size=batch_size,
            patch_size=patch_size,
            length=length,
            num_workers=num_workers,
            **kwargs,
        )

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.dataset = IOBench(**self.kwargs)

        if stage in ["fit"]:
            self.train_sampler = RandomGeoSampler(
                self.dataset, self.patch_size, self.length
            )
        if stage in ["fit", "validate"]:
            self.val_sampler = GridGeoSampler(
                self.dataset, self.patch_size, self.patch_size
            )
        if stage in ["test"]:
            self.test_sampler = GridGeoSampler(
                self.dataset, self.patch_size, self.patch_size
            )
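For intuition: the GridGeoSampler configured above uses a stride equal to the patch size, i.e. non-overlapping tiles. A pure-Python toy of the window origins such a grid visits in pixel space (illustrative only; not torchgeo's actual sampler, which works in CRS coordinates):

```python
def grid_origins(
    width: int, height: int, patch: int, stride: int
) -> list[tuple[int, int]]:
    """Top-left corners of patches tiling a width x height raster on a grid."""
    xs = range(0, width - patch + 1, stride)
    ys = range(0, height - patch + 1, stride)
    return [(x, y) for y in ys for x in xs]


# With stride == patch size (as in setup above), tiles do not overlap:
print(len(grid_origins(64, 64, 16, 16)))  # 16 windows in a 4x4 grid
```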
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -60,6 +60,7 @@
from .idtrees import IDTReeS
from .inaturalist import INaturalist
from .inria import InriaAerialImageLabeling
from .iobench import IOBench
from .l7irish import L7Irish
from .l8biome import L8Biome
from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo
@@ -166,6 +167,7 @@
"GBIF",
"GlobBiomass",
"INaturalist",
"IOBench",
"L7Irish",
"L8Biome",
"LandCoverAIBase",