From eed722d7b104f7389d0f208cf7182ca268356920 Mon Sep 17 00:00:00 2001
From: Dimitris Mantas <75796651+DimitrisMantas@users.noreply.github.com>
Date: Mon, 26 Feb 2024 23:45:52 +0100
Subject: [PATCH] Implement deterministic GeoDataset (#1908)

---
 tests/datasets/test_geo.py | 15 +++++++++++++++
 torchgeo/datasets/geo.py   |  5 +++--
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py
index 8202c7a92b0..df41223615c 100644
--- a/tests/datasets/test_geo.py
+++ b/tests/datasets/test_geo.py
@@ -171,6 +171,21 @@ def test_files_property_for_virtual_files(self) -> None:
         ]
         assert len(CustomGeoDataset(paths=paths).files) == len(paths)
 
+    def test_files_property_ordered(self) -> None:
+        """Ensure that the list of files is ordered."""
+        paths = ["file://file3.tif", "file://file1.tif", "file://file2.tif"]
+        assert CustomGeoDataset(paths=paths).files == sorted(paths)
+
+    def test_files_property_deterministic(self) -> None:
+        """Ensure that the list of files is consistent regardless of their original
+        order.
+        """
+        paths1 = ["file://file3.tif", "file://file1.tif", "file://file2.tif"]
+        paths2 = ["file://file2.tif", "file://file3.tif", "file://file1.tif"]
+        assert (
+            CustomGeoDataset(paths=paths1).files == CustomGeoDataset(paths=paths2).files
+        )
+
 
 class TestRasterDataset:
     @pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))
diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py
index 9f90e8b75bf..662161b8193 100644
--- a/torchgeo/datasets/geo.py
+++ b/torchgeo/datasets/geo.py
@@ -285,7 +285,7 @@ def res(self, new_res: float) -> None:
         self._res = new_res
 
     @property
-    def files(self) -> set[str]:
+    def files(self) -> list[str]:
         """A list of all files in the dataset.
 
         Returns:
@@ -314,7 +314,8 @@ def files(self) -> set[str]:
                     UserWarning,
                 )
 
-        return files
+        # Sort the output to enforce deterministic behavior.
+        return sorted(files)
 
 
 class RasterDataset(GeoDataset):