Skip to content

Commit

Permalink
Refactor with respect to other branch
Browse files Browse the repository at this point in the history
  • Loading branch information
adriantre committed Jun 22, 2023
1 parent 91a7080 commit 6624dee
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
9 changes: 4 additions & 5 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
BoundingBox,
concat_samples,
disambiguate_timestamp,
listdir_vsi_recursive,
list_directory_recursive,
merge_samples,
)

Expand Down Expand Up @@ -341,7 +341,6 @@ def __init__(
bands: Optional[Sequence[str]] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
cache: bool = True,
vsi: bool = False,
) -> None:
"""Initialize a new Dataset instance.
Expand All @@ -356,7 +355,6 @@ def __init__(
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
vsi: if True, will support GDAL Virtual File Systems
Raises:
FileNotFoundError: if no files are found in ``root``
Expand All @@ -376,8 +374,9 @@ def __init__(
if ext:
filespaths.append(dir_or_file)
else:
pathname = os.path.join(dir_or_file, "**", self.filename_glob)
filespaths.extend(glob.iglob(pathname, recursive=True))
filespaths.extend(
list_directory_recursive(dir_or_file, self.filename_glob)
)

# Populate the dataset index
i = 0
Expand Down
40 changes: 30 additions & 10 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import bz2
import collections
import contextlib
import fnmatch
import glob
import gzip
import lzma
import os
Expand Down Expand Up @@ -45,7 +47,7 @@
"draw_semantic_segmentation_masks",
"rgb_to_mask",
"percentile_normalization",
"listdir_vsi_recursive",
"list_directory_recursive",
)


Expand Down Expand Up @@ -742,18 +744,17 @@ def percentile_normalization(
return img_normalized


def listdir_vsi_recursive(root: str) -> list[str]:
"""Walk directory and return list of all files in subdirectories.
def _path_is_vsi(path: str) -> bool:
from rasterio._path import SCHEMES

Also supports listing filenames within a GDAL Virtual File System like
cloud buckets. https://gdal.org/user/virtual_file_systems.html
prefix = path.split("://")[0]
schemes = prefix.split("+")
is_apache_vfs_scheme = set(schemes).issubset(set(SCHEMES))
is_gdal_vsi = path.startswith("/vsi")
return is_gdal_vsi or is_apache_vfs_scheme

Args:
root: root directory or blob in bucket

Returns
List of absolute filepaths withing the root
"""
def _listdir_vsi_recursive(root: str) -> list[str]:
dirs = [root]
files = []
while dirs:
Expand All @@ -764,3 +765,22 @@ def listdir_vsi_recursive(root: str) -> list[str]:
except FionaValueError:
files.append(dir)
return files


def list_directory_recursive(root: str, filename_glob: str) -> list[str]:
    """List files in a directory recursively, filtered by filename.

    Also supports GDAL virtual file systems (vsi).

    Args:
        root: directory to list. For vsi these can start with
            e.g. /vsiaz or az:// for azure blob storage
        filename_glob: filename pattern to filter filenames

    Returns:
        paths of the files found below ``root`` whose filename matches
        ``filename_glob``
    """
    if _path_is_vsi(root):
        # Match the pattern against the final path component only, so that
        # the vsi branch agrees with the local branch below (where the glob
        # is applied to the filename inside any subdirectory). Matching the
        # whole path would reject patterns anchored at the start of the
        # filename (e.g. "T*_B02.tif") and could accept files merely because
        # a parent directory name happens to match.
        return [
            path
            for path in _listdir_vsi_recursive(root)
            if fnmatch.fnmatch(path.rsplit("/", 1)[-1], filename_glob)
        ]

    pathname = os.path.join(root, "**", filename_glob)
    return list(glob.iglob(pathname, recursive=True))

0 comments on commit 6624dee

Please sign in to comment.