From da2399b47997044d9956555bf229cdd453bc493d Mon Sep 17 00:00:00 2001
From: Burak
Date: Tue, 9 Apr 2024 15:15:00 +0200
Subject: [PATCH 1/4] minor typo in custom_raster_dataset.ipynb

---
 docs/tutorials/custom_raster_dataset.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb
index 7401e580edb..e4da8499114 100644
--- a/docs/tutorials/custom_raster_dataset.ipynb
+++ b/docs/tutorials/custom_raster_dataset.ipynb
@@ -345,7 +345,7 @@
     "\n",
     "### `rgb_bands`\n",
     "\n",
-    "If your data is a multispectral iamge, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.\n",
+    "If your data is a multispectral image, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.\n",
     "\n",
     "Putting this all together into a single class, we get:"
   ]

From 62919bfcf08f73d74ce9553d617319b57437fc49 Mon Sep 17 00:00:00 2001
From: Burak <68427259+burakekim@users.noreply.github.com>
Date: Mon, 18 Nov 2024 15:46:50 +0100
Subject: [PATCH 2/4] xview2distshift dataset
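
Add XView2DistShift, an XView2 variant that rebuilds the train/test
splits around user-selected disasters: one in-domain (ID) disaster
forms the training set and one out-of-domain (OOD) disaster forms the
test set, so that generalization under distribution shift can be
measured. Targets are collapsed to a binary background/building mask,
and samples are returned on CPU like the other non-geo datasets.

A minimal usage sketch (the root path and the two disaster choices are
illustrative, not requirements):

    from torchgeo.datasets import XView2DistShift

    id_ood = [
        {"disaster_name": "hurricane-matthew", "pre-post": "post"},  # ID -> train
        {"disaster_name": "mexico-earthquake", "pre-post": "post"},  # OOD -> test
    ]
    train_ds = XView2DistShift(root="data", split="train", id_ood_disaster=id_ood)
    test_ds = XView2DistShift(root="data", split="test", id_ood_disaster=id_ood)
    sample = train_ds[0]  # {"image": ..., "mask": ...}, mask values in {0, 1}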
---
 docs/api/datasets.rst         |   1 +
 torchgeo/datasets/__init__.py |   3 +-
 torchgeo/datasets/xview.py    | 169 ++++++++++++++++++++++++++++++++++
 3 files changed, 172 insertions(+), 1 deletion(-)

diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 20ce1bfcbac..8b5149ade43 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -462,6 +462,7 @@ xView2
 ^^^^^^
 
 .. autoclass:: XView2
+.. autoclass:: XView2DistShift
 
 ZueriCrop
 ^^^^^^^^^

diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 6a15fabdf76..f84760ce865 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -136,7 +136,7 @@
 from .vaihingen import Vaihingen2D
 from .vhr10 import VHR10
 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
-from .xview import XView2
+from .xview import XView2, XView2DistShift
 from .zuericrop import ZueriCrop
 
 __all__ = (
@@ -258,6 +258,7 @@
     'VHR10',
     'WesternUSALiveFuelMoisture',
     'XView2',
+    'XView2DistShift',
     'ZueriCrop',
     # Base classes
     'GeoDataset',

diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index 5716c06f593..98fe3ceb534 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -270,3 +270,172 @@ def plot(
             plt.suptitle(suptitle)
 
         return fig
+
+class XView2DistShift(XView2):
+    """
+    A subclass of the XView2 dataset designed to reformat the original train/test splits
+    based on specific in-domain (ID) and out-of-domain (OOD) disasters.
+
+    This class allows for the selection of particular disasters to be used as the
+    training set (in-domain) and test set (out-of-domain). The dataset can be split
+    according to the disaster names specified by the user, enabling the model to train
+    on one disaster type and evaluate on a different, out-of-domain disaster. The goal
+    is to test the ability of models trained on one disaster to generalize to others.
+    """
+
+    classes = ["background", "building"]
+
+    # List of disaster names
+    valid_disasters = [
+        'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake',
+        'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence',
+        'hurricane-michael', 'midwest-flooding'
+    ]
+
+    def __init__(
+        self,
+        root: Path = "data",
+        split: str = "train",
+        id_ood_disaster: List[Dict[str, str]] = [{"disaster_name": "hurricane-matthew", "pre-post": "post"}, {"disaster_name": "mexico-earthquake", "pre-post": "post"}],
+        transforms: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] | None = None,
+        checksum: bool = False,
+        **kwargs
+    ) -> None:
+        """Initialize the XView2DistShift dataset instance.
+
+        Args:
+            root: Root directory where the dataset is located.
+            split: One of "train" or "test".
+            id_ood_disaster: List containing in-distribution and out-of-distribution disaster names.
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: If *split* is invalid.
+            ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters.
+            DatasetNotFoundError: If dataset is not found.
+        """
+        assert split in ["train", "test"], "Split must be either 'train' or 'test'."
+        # Validate that the disasters are valid
+
+        if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters:
+            raise ValueError(f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}")
+
+        self.root = root
+        self.split = split
+        self.transforms = transforms
+        self.checksum = checksum
+
+        self._verify()
+
+        # Load all files and compute basenames and disasters only once
+        self.all_files = self._initialize_files(root)
+
+        # Split logic by disaster and pre-post type
+        self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1])
+        print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.")
+
+    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
+        """Get an item from the dataset at the given index."""
+        file_info = (
+            self.files["train"][index]
+            if self.split == "train"
+            else self.files["test"][index])
+
+        image = self._load_image(file_info["image"])
+        mask = self._load_target(file_info["mask"]).long()
+        # Collapse the four damage levels to a binary building mask: minor-damage (2)
+        # joins building (1); major-damage (3) and destroyed (4) become background (0)
+        mask[mask == 2] = 1
+        mask[(mask == 3) | (mask == 4)] = 0
+
+        sample = {"image": image, "mask": mask}
+
+        if self.transforms:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the total number of samples in the dataset."""
+        return (
+            len(self.files["train"])
+            if self.split == "train"
+            else len(self.files["test"])
+        )
+
+    def _initialize_files(self, root: Path) -> List[Dict[str, str]]:
+        """Initialize the dataset by loading file paths and computing basenames with sample numbers."""
+        all_files = []
+        for split in self.metadata.keys():
+            image_root = os.path.join(root, split, "images")
+            mask_root = os.path.join(root, split, "targets")
+            images = glob.glob(os.path.join(image_root, "*.png"))
+
+            # Extract basenames while preserving the event-name and sample number
+            for img in images:
+                basename_parts = os.path.basename(img).split("_")
+                event_name = basename_parts[0]  # e.g., mexico-earthquake
+                sample_number = basename_parts[1]  # e.g., 00000001
+                basename = (
+                    f"{event_name}_{sample_number}"  # e.g., mexico-earthquake_00000001
+                )
+
+
+                file_info = {
+                    "image": img,
+                    # Both pre and post images share the same pre-disaster
+                    # building-footprint target
+                    "mask": os.path.join(
+                        mask_root, f"{basename}_pre_disaster_target.png"
+                    ),
+                    "basename": basename,
+                }
+                all_files.append(file_info)
+        return all_files
+
+    def _load_split_files_by_disaster_and_type(
+        self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str]
+    ) -> Dict[str, List[Dict[str, str]]]:
+        """
+        Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name
+        and pre-post disaster type.
+
+        Args:
+            files: List of file paths with their corresponding information.
+            id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {"disaster_name": "guatemala-volcano", "pre-post": "pre"}).
+            ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {"disaster_name": "mexico-earthquake", "pre-post": "post"}).
+
+        Returns:
+            A dictionary containing 'train' (ID) and 'test' (OOD) file lists, plus the
+            list of encountered disaster names under 'disasters'.
+        """
+        train_files = []
+        test_files = []
+        disaster_list = []
+
+        for file_info in files:
+            basename = file_info["basename"]
+            disaster_name = basename.split("_")[0]  # Extract disaster name from basename
+            pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post")  # Identify pre/post type
+
+            disaster_list.append(disaster_name)
+
+            # Filter for in-domain (ID) training set
+            if disaster_name == id_disaster["disaster_name"]:
+                if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post:
+                    image = (
+                        file_info["image"].replace("post_disaster", "pre_disaster")
+                        if pre_post == "pre"
+                        else file_info["image"]
+                    )
+                    mask = (
+                        file_info["mask"].replace("post_disaster", "pre_disaster")
+                        if pre_post == "pre"
+                        else file_info["mask"]
+                    )
+                    train_files.append(dict(image=image, mask=mask))
+
+            # Filter for out-of-domain (OOD) test set
+            if disaster_name == ood_disaster["disaster_name"]:
+                if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post:
+                    test_files.append(file_info)
+
+        return {"train": train_files, "test": test_files, "disasters":disaster_list}
\ No newline at end of file

From 5985f44c451b99329e8e62b50a90f7c2fb2f724e Mon Sep 17 00:00:00 2001
From: Burak <68427259+burakekim@users.noreply.github.com>
Date: Mon, 18 Nov 2024 16:00:56 +0100
Subject: [PATCH 3/4] test xview2
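
Import the new XView2DistShift class in the xView2 test module so it is
covered by the import-level checks. Dedicated test cases are still to
be added; as a sketch, the disaster-name validation can already be
exercised without any data on disk (hypothetical test, not part of this
series):

    import pytest

    from torchgeo.datasets import XView2DistShift

    def test_invalid_disaster() -> None:
        # Disaster names are validated before any files are read, so this
        # check needs no dataset on disk.
        with pytest.raises(ValueError, match='Invalid disaster names'):
            XView2DistShift(
                id_ood_disaster=[
                    {'disaster_name': 'not-a-disaster', 'pre-post': 'post'},
                    {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
                ]
            )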
---
 tests/datasets/test_xview2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py
index 7689acf5f78..35e02e27c28 100644
--- a/tests/datasets/test_xview2.py
+++ b/tests/datasets/test_xview2.py
@@ -12,7 +12,7 @@
 from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import DatasetNotFoundError, XView2
+from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift
 
 
 class TestXView2:

From a23344e69567d6f435e8f95018a6716ae0fe5aa4 Mon Sep 17 00:00:00 2001
From: Burak
Date: Mon, 18 Nov 2024 17:03:23 +0100
Subject: [PATCH 4/4] formatting

---
 tests/datasets/test_xview2.py |   2 +-
 torchgeo/datasets/xview.py    | 211 +++++++++++++++++++---------------
 2 files changed, 121 insertions(+), 92 deletions(-)

diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py
index 35e02e27c28..dc8774d0933 100644
--- a/tests/datasets/test_xview2.py
+++ b/tests/datasets/test_xview2.py
@@ -14,7 +14,6 @@
 from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift
 
-
 class TestXView2:
     @pytest.fixture(params=['train', 'test'])
     def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2:
@@ -27,6 +26,7 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2:
             'md5': '373e61d55c1b294aa76b94dbbd81332b',
             'directory': 'train',
         },
+
         'test': {
             'filename': 'test_images_labels_targets.tar.gz',
             'md5': 'bc6de81c956a3bada38b5b4e246266a1',

diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index 98fe3ceb534..bc9ce8d5992 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -19,6 +19,7 @@
 from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive
 
+
 class XView2(NonGeoDataset):
     """xView2 dataset.
 
@@ -50,24 +51,24 @@ class XView2(NonGeoDataset):
     """
 
     metadata = {
-        'train': {
-            'filename': 'train_images_labels_targets.tar.gz',
-            'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16',
-            'directory': 'train',
+        "train": {
+            "filename": "train_images_labels_targets.tar.gz",
+            "md5": "a20ebbfb7eb3452785b63ad02ffd1e16",
+            "directory": "train",
         },
-        'test': {
-            'filename': 'test_images_labels_targets.tar.gz',
-            'md5': '1b39c47e05d1319c17cc8763cee6fe0c',
-            'directory': 'test',
+        "test": {
+            "filename": "test_images_labels_targets.tar.gz",
+            "md5": "1b39c47e05d1319c17cc8763cee6fe0c",
+            "directory": "test",
         },
     }
-    classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed']
-    colormap = ['green', 'blue', 'orange', 'red']
+    classes = ["background", "no-damage", "minor-damage", "major-damage", "destroyed"]
+    colormap = ["green", "blue", "orange", "red"]
 
     def __init__(
         self,
-        root: str = 'data',
-        split: str = 'train',
+        root: str = "data",
+        split: str = "train",
         transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
         checksum: bool = False,
     ) -> None:
@@ -105,14 +106,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
             data and label at that index
         """
         files = self.files[index]
-        image1 = self._load_image(files['image1'])
-        image2 = self._load_image(files['image2'])
-        mask1 = self._load_target(files['mask1'])
-        mask2 = self._load_target(files['mask2'])
+        image1 = self._load_image(files["image1"])
+        image2 = self._load_image(files["image2"])
+        mask1 = self._load_target(files["mask1"])
+        mask2 = self._load_target(files["mask2"])
 
         image = torch.stack(tensors=[image1, image2], dim=0)
         mask = torch.stack(tensors=[mask1, mask2], dim=0)
-        sample = {'image': image, 'mask': mask}
+        sample = {"image": image, "mask": mask}
 
         if self.transforms is not None:
             sample = self.transforms(sample)
@@ -138,17 +139,17 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]:
         files = []
-        directory = self.metadata[split]['directory']
-        image_root = os.path.join(root, directory, 'images')
-        mask_root = os.path.join(root, directory, 'targets')
-        images = glob.glob(os.path.join(image_root, '*.png'))
+        directory = self.metadata[split]["directory"]
+        image_root = os.path.join(root, directory, "images")
+        mask_root = os.path.join(root, directory, "targets")
+        images = glob.glob(os.path.join(image_root, "*.png"))
         basenames = [os.path.basename(f) for f in images]
-        basenames = ['_'.join(f.split('_')[:-2]) for f in basenames]
+        basenames = ["_".join(f.split("_")[:-2]) for f in basenames]
         for name in sorted(set(basenames)):
-            image1 = os.path.join(image_root, f'{name}_pre_disaster.png')
-            image2 = os.path.join(image_root, f'{name}_post_disaster.png')
-            mask1 = os.path.join(mask_root, f'{name}_pre_disaster_target.png')
-            mask2 = os.path.join(mask_root, f'{name}_post_disaster_target.png')
+            image1 = os.path.join(image_root, f"{name}_pre_disaster.png")
+            image2 = os.path.join(image_root, f"{name}_post_disaster.png")
+            mask1 = os.path.join(mask_root, f"{name}_pre_disaster_target.png")
+            mask2 = os.path.join(mask_root, f"{name}_post_disaster_target.png")
             files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2))
         return files
 
@@ -163,7 +164,7 @@ def _load_image(self, path: str) -> Tensor:
         """
         filename = os.path.join(path)
         with Image.open(filename) as img:
-            array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB'))
+            array: np.typing.NDArray[np.int_] = np.array(img.convert("RGB"))
             tensor = torch.from_numpy(array)
             # Convert from HxWxC to CxHxW
             tensor = tensor.permute((2, 0, 1))
@@ -180,7 +181,7 @@ def _load_target(self, path: str) -> Tensor:
         """
         filename = os.path.join(path)
         with Image.open(filename) as img:
-            array: np.typing.NDArray[np.int_] = np.array(img.convert('L'))
+            array: np.typing.NDArray[np.int_] = np.array(img.convert("L"))
             tensor = torch.from_numpy(array)
             tensor = tensor.to(torch.long)
             return tensor
@@ -190,10 +191,10 @@ def _verify(self) -> None:
         # Check if the files already exist
         exists = []
         for split_info in self.metadata.values():
-            for directory in ['images', 'targets']:
+            for directory in ["images", "targets"]:
                 exists.append(
                     os.path.exists(
-                        os.path.join(self.root, split_info['directory'], directory)
+                        os.path.join(self.root, split_info["directory"], directory)
                     )
                 )
@@ -203,10 +204,10 @@ def _verify(self) -> None:
         # Check if .tar.gz files already exists (if so then extract)
         exists = []
         for split_info in self.metadata.values():
-            filepath = os.path.join(self.root, split_info['filename'])
+            filepath = os.path.join(self.root, split_info["filename"])
             if os.path.isfile(filepath):
-                if self.checksum and not check_integrity(filepath, split_info['md5']):
-                    raise RuntimeError('Dataset found, but corrupted.')
+                if self.checksum and not check_integrity(filepath, split_info["md5"]):
+                    raise RuntimeError("Dataset found, but corrupted.")
                 exists.append(True)
                 extract_archive(filepath)
             else:
@@ -237,70 +238,78 @@ def plot(
         """
         ncols = 2
         image1 = draw_semantic_segmentation_masks(
-            sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap
+            sample["image"][0], sample["mask"][0], alpha=alpha, colors=self.colormap
         )
         image2 = draw_semantic_segmentation_masks(
-            sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap
+            sample["image"][1], sample["mask"][1], alpha=alpha, colors=self.colormap
         )
-        if 'prediction' in sample:  # NOTE: this assumes predictions are made for post
+        if "prediction" in sample:  # NOTE: this assumes predictions are made for post
             ncols += 1
             image3 = draw_semantic_segmentation_masks(
-                sample['image'][1],
-                sample['prediction'],
+                sample["image"][1],
+                sample["prediction"],
                 alpha=alpha,
                 colors=self.colormap,
             )
 
         fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
         axs[0].imshow(image1)
-        axs[0].axis('off')
+        axs[0].axis("off")
         axs[1].imshow(image2)
-        axs[1].axis('off')
+        axs[1].axis("off")
         if ncols > 2:
             axs[2].imshow(image3)
-            axs[2].axis('off')
+            axs[2].axis("off")
 
         if show_titles:
-            axs[0].set_title('Pre disaster')
-            axs[1].set_title('Post disaster')
+            axs[0].set_title("Pre disaster")
+            axs[1].set_title("Post disaster")
             if ncols > 2:
-                axs[2].set_title('Predictions')
+                axs[2].set_title("Predictions")
 
         if suptitle is not None:
             plt.suptitle(suptitle)
 
         return fig
-
+
+
 class XView2DistShift(XView2):
-    """
-    A subclass of the XView2 dataset designed to reformat the original train/test splits
-    based on specific in-domain (ID) and out-of-domain (OOD) disasters.
+    """A subclass of the XView2 dataset designed to reformat the original train/test splits.
 
     This class allows for the selection of particular disasters to be used as the
     training set (in-domain) and test set (out-of-domain). The dataset can be split
     according to the disaster names specified by the user, enabling the model to train
     on one disaster type and evaluate on a different, out-of-domain disaster. The goal
     is to test the ability of models trained on one disaster to generalize to others.
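+
+    Example (illustrative root path; any pair of names from ``valid_disasters`` works):
+
+        >>> ds = XView2DistShift(
+        ...     root="data",
+        ...     split="train",
+        ...     id_ood_disaster=[
+        ...         {"disaster_name": "santa-rosa-wildfire", "pre-post": "both"},
+        ...         {"disaster_name": "palu-tsunami", "pre-post": "post"},
+        ...     ],
+        ... )
+        >>> sample = ds[0]  # {"image": ..., "mask": ...} with a binary building mask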
     """
 
     classes = ["background", "building"]
 
     # List of disaster names
-    valid_disasters = [
-        'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake',
-        'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence',
-        'hurricane-michael', 'midwest-flooding'
-    ]
+    valid_disasters = [
+        "hurricane-harvey",
+        "socal-fire",
+        "hurricane-matthew",
+        "mexico-earthquake",
+        "guatemala-volcano",
+        "santa-rosa-wildfire",
+        "palu-tsunami",
+        "hurricane-florence",
+        "hurricane-michael",
+        "midwest-flooding",
+    ]
 
     def __init__(
         self,
-        root: Path = "data",
+        root: str = "data",
         split: str = "train",
-        id_ood_disaster: List[Dict[str, str]] = [{"disaster_name": "hurricane-matthew", "pre-post": "post"}, {"disaster_name": "mexico-earthquake", "pre-post": "post"}],
-        transforms: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] | None = None,
+        id_ood_disaster: list[dict[str, str]] = [
+            {"disaster_name": "hurricane-matthew", "pre-post": "post"},
+            {"disaster_name": "mexico-earthquake", "pre-post": "post"},
+        ],
+        transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = None,
         checksum: bool = False,
-        **kwargs
     ) -> None:
@@ -311,7 +320,7 @@ def __init__(
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             checksum: if True, check the MD5 of the downloaded files (may be slow)
-
+
         Raises:
             AssertionError: If *split* is invalid.
             ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters.
@@ -319,9 +328,14 @@ def __init__(
         """
         assert split in ["train", "test"], "Split must be either 'train' or 'test'."
         # Validate that the disasters are valid
-
-        if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters:
-            raise ValueError(f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}")
+        if (
+            id_ood_disaster[0]["disaster_name"] not in self.valid_disasters
+            or id_ood_disaster[1]["disaster_name"] not in self.valid_disasters
+        ):
+            raise ValueError(
+                f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}"
+            )
 
         self.root = root
         self.split = split
@@ -332,23 +346,28 @@ def __init__(
 
         # Load all files and compute basenames and disasters only once
         self.all_files = self._initialize_files(root)
 
         # Split logic by disaster and pre-post type
-        self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1])
-        print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.")
+        self.files = self._load_split_files_by_disaster_and_type(
+            self.all_files, id_ood_disaster[0], id_ood_disaster[1]
+        )
+        print(
+            f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files."
+        )
 
-    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
+    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
         """Get an item from the dataset at the given index."""
         file_info = (
             self.files["train"][index]
             if self.split == "train"
-            else self.files["test"][index])
+            else self.files["test"][index]
+        )
 
         image = self._load_image(file_info["image"])
         mask = self._load_target(file_info["mask"]).long()
@@ -364,14 +383,14 @@ def __len__(self) -> int:
             else len(self.files["test"])
         )
 
-    def _initialize_files(self, root: Path) -> List[Dict[str, str]]:
+    def _initialize_files(self, root: str) -> list[dict[str, str]]:
         """Initialize the dataset by loading file paths and computing basenames with sample numbers."""
         all_files = []
         for split in self.metadata.keys():
             image_root = os.path.join(root, split, "images")
             mask_root = os.path.join(root, split, "targets")
             images = glob.glob(os.path.join(image_root, "*.png"))
-
+
             # Extract basenames while preserving the event-name and sample number
             for img in images:
                 basename_parts = os.path.basename(img).split("_")
@@ -381,7 +400,6 @@ def _initialize_files(self, root: str) -> list[dict[str, str]]:
                 basename = (
                     f"{event_name}_{sample_number}"  # e.g., mexico-earthquake_00000001
                 )
-
 
                 file_info = {
                     "image": img,
@@ -393,11 +411,12 @@ def _initialize_files(self, root: str) -> list[dict[str, str]]:
         return all_files
 
     def _load_split_files_by_disaster_and_type(
-        self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str]
-    ) -> Dict[str, List[Dict[str, str]]]:
-        """
-        Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name
-        and pre-post disaster type.
+        self,
+        files: list[dict[str, str]],
+        id_disaster: dict[str, str],
+        ood_disaster: dict[str, str],
+    ) -> dict[str, list[dict[str, str]]]:
+        """Return the filepaths for the train (ID) and test (OOD) sets based on disaster name and pre-post disaster type.
 
         Args:
             files: List of file paths with their corresponding information.
@@ -410,17 +429,24 @@ def _load_split_files_by_disaster_and_type(
         train_files = []
         test_files = []
         disaster_list = []
-
+
         for file_info in files:
             basename = file_info["basename"]
-            disaster_name = basename.split("_")[0]  # Extract disaster name from basename
-            pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post")  # Identify pre/post type
-
-            disaster_list.append(disaster_name)
-
+            disaster_name = basename.split("_")[
+                0
+            ]  # Extract disaster name from basename
+            pre_post = (
+                "pre" if "pre_disaster" in file_info["image"] else "post"
+            )  # Identify pre/post type
+
+            disaster_list.append(disaster_name)
+
             # Filter for in-domain (ID) training set
             if disaster_name == id_disaster["disaster_name"]:
-                if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post:
+                if (
+                    id_disaster.get("pre-post") == "both"
+                    or id_disaster["pre-post"] == pre_post
+                ):
                     image = (
                         file_info["image"].replace("post_disaster", "pre_disaster")
                         if pre_post == "pre"
                         else file_info["image"]
                     )
                     mask = (
                         file_info["mask"].replace("post_disaster", "pre_disaster")
                         if pre_post == "pre"
                         else file_info["mask"]
                     )
                     train_files.append(dict(image=image, mask=mask))
 
             # Filter for out-of-domain (OOD) test set
             if disaster_name == ood_disaster["disaster_name"]:
-                if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post:
+                if (
+                    ood_disaster.get("pre-post") == "both"
+                    or ood_disaster["pre-post"] == pre_post
+                ):
                     test_files.append(file_info)
 
-        return {"train": train_files, "test": test_files, "disasters":disaster_list}
\ No newline at end of file
+        return {"train": train_files, "test": test_files, "disasters": disaster_list}