diff --git a/pyproject.toml b/pyproject.toml
index 3a2c2319ee..0b6e5467f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,8 @@ dependencies = [
     # einops 0.3+ required for einops.repeat
     "einops>=0.3",
     # fiona 1.8.21+ required for Python 3.10 wheels
-    "fiona>=1.8.21",
+    # fiona 1.9+ required for fiona.listdir
+    "fiona>=1.9",
     # kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
     "kornia>=0.7.3",
     # lightly 1.4.5+ required for LARS optimizer
diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old
index a6e91f70fe..369e09d612 100644
--- a/requirements/min-reqs.old
+++ b/requirements/min-reqs.old
@@ -3,7 +3,7 @@ setuptools==61.0.0
 
 # install
 einops==0.3.0
-fiona==1.8.21
+fiona==1.9.0
 kornia==0.7.3
 lightly==1.4.5
 lightning[pytorch-extra]==2.0.0
diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py
index 0620243837..5db62a5345 100644
--- a/tests/datasets/test_geo.py
+++ b/tests/datasets/test_geo.py
@@ -1,11 +1,11 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
-
 import math
 import os
 import pickle
+import shutil
 import sys
-from collections.abc import Iterable
+from collections.abc import Generator, Iterable
 from pathlib import Path
 from typing import Any
 
@@ -13,6 +13,7 @@
 import torch
 import torch.nn as nn
 from _pytest.fixtures import SubRequest
+from _pytest.tmpdir import TempPathFactory
 from rasterio.crs import CRS
 from rasterio.enums import Resampling
 from torch.utils.data import ConcatDataset
@@ -84,6 +85,26 @@ def __len__(self) -> int:
         return 2
 
 
+@pytest.fixture(scope='module')
+def module_tmp_path(tmp_path_factory: TempPathFactory) -> Path:
+    # The default fixture is scoped per function
+    return tmp_path_factory.mktemp('module_tmp')
+
+
+@pytest.fixture(scope='module')
+def temp_archive(
+    request: SubRequest, module_tmp_path: Path
+) -> Generator[tuple[str, str], None, None]:
+    # Runs before tests
+    dir_not_zipped = request.param
+    dir_zipped = shutil.make_archive(
+        module_tmp_path / dir_not_zipped, 'zip', root_dir=dir_not_zipped
+    )
+    yield dir_not_zipped, dir_zipped
+    # Runs after tests
+    os.remove(dir_zipped)
+
+
 class TestGeoDataset:
     @pytest.fixture(scope='class')
     def dataset(self) -> GeoDataset:
@@ -178,29 +199,25 @@ def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
         with pytest.warns(UserWarning, match='Path was ignored.'):
             assert len(CustomGeoDataset(paths=paths).files) == 0
 
-    def test_files_property_for_virtual_files(self) -> None:
-        # Tests only a subset of schemes and combinations.
-        paths = [
-            'file://directory/file.tif',
-            'zip://archive.zip!folder/file.tif',
-            'az://azure_bucket/prefix/file.tif',
-            '/vsiaz/azure_bucket/prefix/file.tif',
-            'zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif',
-            '/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif',
-        ]
-        assert len(CustomGeoDataset(paths=paths).files) == len(paths)
-
-    def test_files_property_ordered(self) -> None:
+    def test_files_property_ordered(self, tmp_path: Path) -> None:
         """Ensure that the list of files is ordered."""
-        paths = ['file://file3.tif', 'file://file1.tif', 'file://file2.tif']
-        assert CustomGeoDataset(paths=paths).files == sorted(paths)
-
-    def test_files_property_deterministic(self) -> None:
+        files = ['file3.tif', 'file1.tif', 'file2.tif']
+        paths = [tmp_path / fake_file for fake_file in files]
+        for fake_file in paths:
+            fake_file.touch()
+        str_paths = [str(fake_file) for fake_file in paths]
+        assert CustomGeoDataset(paths=paths).files == sorted(str_paths)
+
+    def test_files_property_deterministic(self, tmp_path: Path) -> None:
         """Ensure that the list of files is consistent regardless of their
         original order.
         """
-        paths1 = ['file://file3.tif', 'file://file1.tif', 'file://file2.tif']
-        paths2 = ['file://file2.tif', 'file://file3.tif', 'file://file1.tif']
+        files = ['file3.tif', 'file1.tif', 'file2.tif']
+        paths1 = [tmp_path / fake_file for fake_file in files]
+        paths2 = paths1[::-1]  # reverse order
+        for fake_file in paths1:
+            fake_file.touch()
         assert (
             CustomGeoDataset(paths=paths1).files == CustomGeoDataset(paths=paths2).files
         )
@@ -213,6 +230,86 @@ def test_files_property_mix_str_and_pathlib(self, tmp_path: Path) -> None:
         ds = CustomGeoDataset(paths=[str(foo), bar])
         assert ds.files == [str(bar), str(foo)]
 
+    @pytest.mark.parametrize(
+        'temp_archive', [os.path.join('tests', 'data', 'vector')], indirect=True
+    )
+    def test_zipped_file(self, temp_archive: tuple[str, str]) -> None:
+        _, dir_zipped = temp_archive
+        filename = 'vector_2024.geojson'
+
+        specific_file_zipped = f'{dir_zipped}!{filename}'
+
+        files_found = CustomGeoDataset(paths=f'zip://{specific_file_zipped}').files
+        assert len(files_found) == 1
+        assert str(files_found[0]).endswith(filename)
+
+    @pytest.mark.parametrize(
+        'temp_archive', [os.path.join('tests', 'data', 'vector')], indirect=True
+    )
+    def test_zipped_file_non_existing(self, temp_archive: tuple[str, str]) -> None:
+        _, dir_zipped = temp_archive
+        with pytest.warns(UserWarning, match='Path was ignored.'):
+            files = CustomGeoDataset(
+                paths=f'zip://{dir_zipped}!/non_existing_file.tif'
+            ).files
+        assert len(files) == 0
+
+    @pytest.mark.parametrize(
+        'temp_archive',
+        [
+            os.path.join(
+                'tests',
+                'data',
+                'sentinel2',
+                'S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE',
+            )
+        ],
+        indirect=True,
+    )
+    def test_zipped_specific_file_dir(self, temp_archive: tuple[str, str]) -> None:
+        _, dir_zipped = temp_archive
+
+        filepath_within_dir = (
+            'GRANULE/L2A_T26EMU_A035569_20220414T110747'
+            '/IMG_DATA/R60m/T26EMU_20220414T110751_B02_60m.jp2'
+        )
+
+        files_found = CustomGeoDataset(
+            paths=f'zip://{dir_zipped}!/{filepath_within_dir}'
+        ).files
+        assert len(files_found) == 1
+        assert str(files_found[0]).endswith(filepath_within_dir)
+
+    @pytest.mark.parametrize(
+        'temp_archive',
+        [
+            os.path.join(
+                'tests',
+                'data',
+                'sentinel2',
+                'S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE',
+            )
+        ],
+        indirect=True,
+    )
+    def test_zipped_directory(self, temp_archive: tuple[str, str]) -> None:
+        dir_not_zipped, dir_zipped = temp_archive
+        bands = Sentinel2.rgb_bands
+        transforms = nn.Identity()
+        cache = False
+
+        files_not_zipped = Sentinel2(
+            paths=dir_not_zipped, bands=bands, transforms=transforms, cache=cache
+        ).files
+
+        files_zipped = Sentinel2(
+            paths=f'zip://{dir_zipped}', bands=bands, transforms=transforms, cache=cache
+        ).files
+
+        basenames_not_zipped = [Path(path).stem for path in files_not_zipped]
+        basenames_zipped = [Path(path).stem for path in files_zipped]
+        assert basenames_zipped == basenames_not_zipped
 
 
 class TestRasterDataset:
     naip_dir = os.path.join('tests', 'data', 'naip')
diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py
index 198065a708..4929b28634 100644
--- a/torchgeo/datasets/geo.py
+++ b/torchgeo/datasets/geo.py
@@ -6,7 +6,6 @@
 import abc
 import fnmatch
 import functools
-import glob
 import os
 import pathlib
 import re
@@ -37,11 +36,11 @@
 from .utils import (
     BoundingBox,
     Path,
+    _list_directory_recursive,
     array_to_tensor,
     concat_samples,
     disambiguate_timestamp,
     merge_samples,
-    path_is_vsi,
 )
 
 
@@ -308,13 +307,14 @@ def files(self) -> list[str]:
         # Using set to remove any duplicates if directories are overlapping
         files: set[str] = set()
         for path in paths:
-            if os.path.isdir(path):
-                pathname = os.path.join(path, '**', self.filename_glob)
-                files |= set(glob.iglob(pathname, recursive=True))
-            elif (os.path.isfile(path) or path_is_vsi(path)) and fnmatch.fnmatch(
-                str(path), f'*{self.filename_glob}'
+            if os.path.isfile(path) and fnmatch.fnmatch(
+                str(path), os.path.join('*', self.filename_glob)
             ):
                 files.add(str(path))
+            elif files_found := set(
+                _list_directory_recursive(path, self.filename_glob)
+            ):
+                files |= files_found
             elif not hasattr(self, 'download'):
                 warnings.warn(
                     f"Could not find any relevant files for provided path '{path}'. "
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
index 237b479326..0f092919c2 100644
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ -8,6 +8,8 @@
 
 import collections
 import contextlib
+import fnmatch
+import glob
 import importlib
 import os
 import pathlib
@@ -19,9 +21,11 @@
 from datetime import datetime, timedelta
 from typing import Any, TypeAlias, cast, overload
 
+import fiona
 import numpy as np
 import rasterio
 import torch
+from fiona.errors import FionaValueError
 from torch import Tensor
 from torchvision.datasets.utils import (
     check_integrity,
@@ -607,18 +611,22 @@ def percentile_normalization(
     return img_normalized
 
 
-def path_is_vsi(path: Path) -> bool:
-    """Checks if the given path is pointing to a Virtual File System.
+def _path_is_gdal_vsi(path: Path) -> bool:
+    """Checks if the given path has a GDAL Virtual File System Interface (VSI) prefix.
+
+    Such a path points into an Apache-style Virtual File System (VFS) supported by
+    GDAL and related libraries (rasterio and fiona).
 
     .. note::
        Does not check if the path exists, or if it is a dir or file.
 
-    VSI can for instance be Cloud Storage Blobs or zip-archives.
+    A VFS can for instance be a cloud storage blob or a zip archive.
     They will start with a prefix indicating this.
     For examples of these, see references for the two accepted syntaxes.
 
     * https://gdal.org/user/virtual_file_systems.html
     * https://rasterio.readthedocs.io/en/latest/topics/datasets.html
+    * https://commons.apache.org/proper/commons-vfs/filesystems.html
 
     Args:
         path: a directory or file
@@ -631,6 +639,74 @@
     return '://' in str(path) or str(path).startswith('/vsi')
 
 
+def _listdir_vfs_recursive(root: Path) -> list[str]:
+    """Lists all files in a Virtual File System (VFS) directory recursively.
+
+    Args:
+        root: directory to list. The path must contain the prefix for the VFS
+            (e.g., '/vsiaz/' or 'az://' for Azure Blob Storage, or
+            '/vsizip/' or 'zip://' for zip archives).
+
+    Returns:
+        A list of all file paths under the root VFS directory and its
+        subdirectories.
+
+    Raises:
+        FileNotFoundError: If root does not exist.
+
+    .. versionadded:: 0.7
+    """
+    dirs = [str(root)]
+    files = []
+    while dirs:
+        dir = dirs.pop()
+        try:
+            subdirs = fiona.listdir(dir)
+            # Don't use os.path.join here because VSI URIs require forward
+            # slashes, even on Windows.
+            dirs.extend([f'{dir}/{subdir}' for subdir in subdirs])
+        except FionaValueError as e:
+            if 'is not a directory' in str(e):
+                files.append(dir)
+            else:
+                raise FileNotFoundError(
+                    f'No such file or directory: {dir}'
+                ) from e
+    return files
+
+
+def _list_directory_recursive(root: Path, filename_glob: str) -> list[str]:
+    """Lists files in a directory recursively, matching the given glob expression.
+
+    Also supports GDAL Virtual File Systems (VFS).
+
+    Args:
+        root: directory to list. For a VFS, the path must contain the prefix,
+            e.g. '/vsiaz/' or 'az://' for Azure Blob Storage.
+        filename_glob: filename pattern to filter filenames
+
+    Returns:
+        A list of all file paths matching filename_glob in the root directory or its
+        subdirectories.
+
+    .. versionadded:: 0.7
+    """
+    files: list[str]
+    if _path_is_gdal_vsi(root):
+        # fnmatch.filter expects a list of strings
+        all_files: list[str] = []
+        try:
+            all_files = _listdir_vfs_recursive(root)
+        except FileNotFoundError:
+            # To match the behaviour of glob.glob, silently return an empty
+            # list for a non-existing root.
+            pass
+        # Prefix the glob with a wildcard so the pattern matches the end of
+        # the full path rather than only the filename.
+        files = fnmatch.filter(all_files, f'*{filename_glob}')
+    else:
+        pathname = os.path.join(root, '**', filename_glob)
+        files = glob.glob(pathname, recursive=True)
+    return files
+
+
 def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor:
     """Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`.
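Reviewer note, not part of the patch: a minimal sketch of how the new helpers compose, assuming a local directory 'data/' containing .tif files (both the directory name and the glob are hypothetical). A plain path takes the glob.glob branch of _list_directory_recursive, while a 'zip://' path is detected by _path_is_gdal_vsi and walked recursively with fiona.listdir (fiona 1.9+) via _listdir_vfs_recursive, which is exactly how the reworked files property in geo.py resolves zipped paths.

import shutil

from torchgeo.datasets.utils import _list_directory_recursive

# Zip the (hypothetical) local directory 'data/'; returns e.g. 'data.zip'.
archive = shutil.make_archive('data', 'zip', root_dir='data')

# Plain directory: falls back to glob.glob(os.path.join('data', '**', '*.tif')).
print(_list_directory_recursive('data', '*.tif'))

# VFS path: the 'zip://' prefix triggers the fiona.listdir-based recursive walk.
print(_list_directory_recursive(f'zip://{archive}', '*.tif'))

Both calls return a sorted-insensitive list of matching file paths; the GeoDataset.files property deduplicates and orders them afterwards.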