Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MMFlood dataset #2450

Merged
merged 17 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ L8 Biome

.. autoclass:: L8BiomeDataModule

MMFlood
^^^^^^^^

.. autoclass:: MMFloodDataModule

NAIP
^^^^

Expand Down
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ Landsat
.. autoclass:: Landsat2
.. autoclass:: Landsat1

MMFlood
^^^^^^^
.. autoclass:: MMFlood

NAIP
^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30"
`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5
`Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30
`MMFlood`_,"Imagery,DEM,Masks","Sentinel, MapZen/TileZen, OpenStreetMap",CC-BY-4.0,"2,147x2,313",20
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2
`NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10
`NLCD`_,Masks,Landsat,"public domain",-,30
Expand Down
18 changes: 18 additions & 0 deletions tests/conf/mmflood.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 3
num_classes: 2
num_filters: 1
data:
class_path: MMFloodDataModule
init_args:
batch_size: 1
dict_kwargs:
root: 'tests/data/mmflood'
patch_size: 8
normalization: 'median'
include_dem: True
1 change: 1 addition & 0 deletions tests/data/mmflood/activations.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"EMSR000": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR000_00"]}, "EMSR001": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR001_00"]}, "EMSR003": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "val", "delineations": ["EMSR003_00"]}, "EMSR004": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "test", "delineations": ["EMSR004_00"]}}
Binary file added tests/data/mmflood/activations.tar.000.gz.part
Binary file not shown.
Binary file added tests/data/mmflood/activations.tar.001.gz.part
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
125 changes: 125 additions & 0 deletions tests/data/mmflood/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import json
import os
import tarfile

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_data(path: str, filename: str, height: int, width: int) -> None:
MAX_VALUE = 1000.0
MIN_VALUE = 0.0
RANGE = MAX_VALUE - MIN_VALUE
FOLDERS = ['s1_raw', 'DEM', 'mask']
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
profile = {
'driver': 'GTiff',
'dtype': 'float32',
'nodata': None,
'crs': CRS.from_epsg(4326),
'transform': Affine(
0.0001287974837883981,
0.0,
14.438064999669106,
0.0,
-8.989523639880024e-05,
45.71617928533084,
),
'blockysize': 1,
'tiled': False,
'interleave': 'pixel',
'height': height,
'width': width,
}
data = {
's1_raw': np.random.rand(2, height, width).astype(np.float32) * RANGE
- MIN_VALUE,
'DEM': np.random.rand(1, height, width).astype(np.float32) * RANGE - MIN_VALUE,
'mask': np.random.randint(low=0, high=2, size=(1, height, width)).astype(
np.uint8
),
}

os.makedirs(os.path.join(path, 'hydro'), exist_ok=True)

for folder in FOLDERS:
folder_path = os.path.join(path, folder)
os.makedirs(folder_path, exist_ok=True)
filepath = os.path.join(folder_path, filename)
profile2 = profile.copy()
profile2['count'] = 2 if folder == 's1_raw' else 1
with rasterio.open(filepath, mode='w', **profile2) as src:
src.write(data[folder])

return
lccol marked this conversation as resolved.
Show resolved Hide resolved


def generate_tar_gz(src: str, dst: str) -> None:
with tarfile.open(dst, 'w:gz') as tar:
tar.add(src, arcname=src)
return


def split_tar(path: str, dst: str, nparts: int) -> None:
fstats = os.stat(path)
size = fstats.st_size
chunk = size // nparts

with open(path, 'rb') as fp:
for idx in range(nparts):
part_path = os.path.join(dst, f'activations.tar.{idx:03}.gz.part')

bytes_to_write = chunk if idx < nparts - 1 else size - fp.tell()
with open(part_path, 'wb') as dst_fp:
dst_fp.write(fp.read(bytes_to_write))

return


def generate_folders_and_metadata(datapath: str, metadatapath: str) -> None:
folders_splits = [
('EMSR000', 'train'),
('EMSR001', 'train'),
('EMSR003', 'val'),
('EMSR004', 'test'),
]
num_files = {'EMSR000': 3, 'EMSR001': 2, 'EMSR003': 2, 'EMSR004': 1}
metadata = {}
for folder, split in folders_splits:
data = {}
data['title'] = 'Test flood'
data['type'] = 'Flood'
data['country'] = 'N/A'
data['start'] = '2014-11-06T17:57:00'
data['end'] = '2015-01-29T12:47:04'
data['lat'] = 45.82427031690563
data['lon'] = 14.484407562009336
data['subset'] = split
data['delineations'] = [f'{folder}_00']

dst_folder = os.path.join(datapath, f'{folder}-0')
for idx in range(num_files[folder]):
generate_data(
dst_folder, filename=f'{folder}-{idx}.tif', height=16, width=16
)

metadata[folder] = data

generate_tar_gz(src='activations', dst='activations.tar.gz')
split_tar(path='activations.tar.gz', dst='.', nparts=2)
os.remove('activations.tar.gz')
with open(os.path.join(metadatapath, 'activations.json'), 'w') as fp:
json.dump(metadata, fp)

return


if __name__ == '__main__':
datapath = os.path.join(os.getcwd(), 'activations')
metadatapath = os.getcwd()

generate_folders_and_metadata(datapath, metadatapath)
100 changes: 100 additions & 0 deletions tests/datasets/test_mmflood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
MMFlood,
UnionDataset,
)


class TestMMFlood:
@pytest.fixture(params=product([True, False], ['train', 'val', 'test']))
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> MMFlood:
dataset_root = os.path.join('tests', 'data', 'mmflood/')
url = os.path.join(dataset_root)
lccol marked this conversation as resolved.
Show resolved Hide resolved

monkeypatch.setattr(MMFlood, 'url', url)
monkeypatch.setattr(MMFlood, '_nparts', 2)

include_dem, split = request.param
root = tmp_path
return MMFlood(
root,
split=split,
include_dem=include_dem,
transforms=nn.Identity(),
download=True,
checksum=True,
)

def test_getitem(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x['crs'], CRS)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)

# If DEM is included, check if 3 channels are present, 2 otherwise
if dataset.include_dem:
assert x['image'].size(0) == 3
else:
assert x['image'].size(0) == 2
return

def test_len(self, dataset: MMFlood) -> None:
if dataset.split == 'train':
assert len(dataset) == 5
elif dataset.split == 'val':
assert len(dataset) == 2
else:
assert len(dataset) == 1

def test_and(self, dataset: MMFlood) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: MMFlood) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: MMFlood) -> None:
MMFlood(root=dataset.root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
MMFlood(tmp_path)

def test_plot(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle='Test')
plt.close()

def test_plot_prediction(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
x['prediction'] = x['mask'].clone()
dataset.plot(x, suptitle='Prediction')
plt.close()

def test_invalid_query(self, dataset: MMFlood) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
dataset[query]
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class TestSemanticSegmentationTask:
'landcoverai',
'landcoverai100',
'loveda',
'mmflood',
'naipchesapeake',
'potsdam2d',
'sen12ms_all',
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
from .loveda import LoveDADataModule
from .mmflood import MMFloodDataModule
from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
Expand Down Expand Up @@ -87,6 +88,7 @@
'LandCoverAI100DataModule',
'LandCoverAIDataModule',
'LoveDADataModule',
'MMFloodDataModule',
'MisconfigurationException',
'NAIPChesapeakeDataModule',
'NASAMarineDebrisDataModule',
Expand Down
Loading
Loading