diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 20ce1bfcbac..8b5149ade43 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -462,6 +462,7 @@ xView2
 ^^^^^^
 
 .. autoclass:: XView2
+.. autoclass:: XView2DistShift
 
 ZueriCrop
 ^^^^^^^^^
diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py
index 7689acf5f78..dc8774d0933 100644
--- a/tests/datasets/test_xview2.py
+++ b/tests/datasets/test_xview2.py
@@ -12,7 +12,7 @@
 from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import DatasetNotFoundError, XView2
+from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift
 
 
 class TestXView2:
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 6a15fabdf76..f84760ce865 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -136,7 +136,7 @@
 from .vaihingen import Vaihingen2D
 from .vhr10 import VHR10
 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
-from .xview import XView2
+from .xview import XView2, XView2DistShift
 from .zuericrop import ZueriCrop
 
 __all__ = (
@@ -258,6 +258,7 @@
     'VHR10',
     'WesternUSALiveFuelMoisture',
     'XView2',
+    'XView2DistShift',
     'ZueriCrop',
     # Base classes
     'GeoDataset',
diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index 5716c06f593..bc9ce8d5992 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -269,4 +269,179 @@ def plot(
         if suptitle is not None:
             plt.suptitle(suptitle)
 
         return fig
+
+
+class XView2DistShift(XView2):
+    """A variant of XView2 that rebuilds the train/test splits along disaster boundaries.
+
+    One disaster serves as the in-distribution (ID) training set and another as
+    the out-of-distribution (OOD) test set, so a model can be trained on one
+    disaster type and evaluated on a different, unseen one. This makes it
+    possible to test how well models trained on one disaster generalize to
+    others.
+ """ + + classes = ["background", "building"] + + # List of disaster names + valid_disasters = [ + "hurricane-harvey", + "socal-fire", + "hurricane-matthew", + "mexico-earthquake", + "guatemala-volcano", + "santa-rosa-wildfire", + "palu-tsunami", + "hurricane-florence", + "hurricane-michael", + "midwest-flooding", + ] + + def __init__( + self, + root: str = "data", + split: str = "train", + id_ood_disaster: list[dict[str, str]] = [ + {"disaster_name": "hurricane-matthew", "pre-post": "post"}, + {"disaster_name": "mexico-earthquake", "pre-post": "post"}, + ], + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + checksum: bool = False, + ) -> None: + """Initialize the XView2DistShift dataset instance. + + Args: + root: Root directory where the dataset is located. + split: One of "train" or "test". + id_ood_disaster: List containing in-distribution and out-of-distribution disaster names. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. + DatasetNotFoundError: If dataset is not found. + """ + assert split in ["train", "test"], "Split must be either 'train' or 'test'." + # Validate that the disasters are valid + + if ( + id_ood_disaster[0]["disaster_name"] not in self.valid_disasters + or id_ood_disaster[1]["disaster_name"] not in self.valid_disasters + ): + raise ValueError( + f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}" + ) + + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + self._verify() + + # Load all files and compute basenames and disasters only once + self.all_files = self._initialize_files(root) + + # Split logic by disaster and pre-post type + self.files = self._load_split_files_by_disaster_and_type( + self.all_files, id_ood_disaster[0], id_ood_disaster[1] + ) + print( + f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files." 
+        )
+
+    def __getitem__(self, index: int) -> dict[str, Tensor]:
+        """Return an item from the dataset at the given index.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data and label at that index
+        """
+        file_info = self.files[self.split][index]
+
+        image = self._load_image(file_info['image'])
+        mask = self._load_target(file_info['mask']).long()
+
+        # Collapse damage levels to a binary mask: minor-damage (2) -> building (1);
+        # major-damage (3) and destroyed (4) -> background (0)
+        mask[mask == 2] = 1
+        mask[(mask == 3) | (mask == 4)] = 0
+
+        sample = {'image': image, 'mask': mask}
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the number of samples in the selected split."""
+        return len(self.files[self.split])
+
+    def _initialize_files(self, root: str) -> list[dict[str, str]]:
+        """Collect image and mask paths across both original splits."""
+        all_files = []
+        for split_info in self.metadata.values():
+            image_root = os.path.join(root, split_info['directory'], 'images')
+            mask_root = os.path.join(root, split_info['directory'], 'targets')
+            for img in glob.glob(os.path.join(image_root, '*.png')):
+                # Filenames look like 'mexico-earthquake_00000001_pre_disaster.png'
+                basename_parts = os.path.basename(img).split('_')
+                event_name = basename_parts[0]  # e.g. 'mexico-earthquake'
+                sample_number = basename_parts[1]  # e.g. '00000001'
+                basename = f'{event_name}_{sample_number}'
+                # Pair each image with the target of the same acquisition
+                pre_post = 'pre' if 'pre_disaster' in img else 'post'
+                all_files.append(
+                    {
+                        'image': img,
+                        'mask': os.path.join(
+                            mask_root, f'{basename}_{pre_post}_disaster_target.png'
+                        ),
+                        'basename': basename,
+                    }
+                )
+        return all_files
+
+    def _load_split_files_by_disaster_and_type(
+        self,
+        files: list[dict[str, str]],
+        id_disaster: dict[str, str],
+        ood_disaster: dict[str, str],
+    ) -> dict[str, list[dict[str, str]]]:
+        """Split file paths into an ID train set and an OOD test set.
+
+        Args:
+            files: list of file dicts produced by :meth:`_initialize_files`
+            id_disaster: in-distribution disaster and acquisition type, e.g.
+                ``{'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}``;
+                pass ``'pre-post': 'both'`` to keep both acquisitions
+            ood_disaster: out-of-distribution disaster and acquisition type,
+                in the same format
+
+        Returns:
+            a dictionary with 'train' (ID) and 'test' (OOD) file lists
+        """
+        train_files = []
+        test_files = []
+
+        for file_info in files:
+            disaster_name = file_info['basename'].split('_')[0]
+            pre_post = 'pre' if 'pre_disaster' in file_info['image'] else 'post'
+
+            # In-distribution (ID) training files
+            if disaster_name == id_disaster['disaster_name']:
+                if id_disaster.get('pre-post') in ('both', pre_post):
+                    train_files.append(file_info)
+
+            # Out-of-distribution (OOD) test files
+            if disaster_name == ood_disaster['disaster_name']:
+                if ood_disaster.get('pre-post') in ('both', pre_post):
+                    test_files.append(file_info)
+
+        return {'train': train_files, 'test': test_files}
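
Usage sketch (not part of the patch): a minimal example of building the ID/OOD splits with the new class. It assumes the xView2 archives have already been downloaded and extracted under `data/`; the disaster pairing and batch size are arbitrary choices for illustration.

```python
from torch.utils.data import DataLoader

from torchgeo.datasets import XView2DistShift

# ID disaster first, OOD disaster second, matching the id_ood_disaster default.
splits = [
    {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
    {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
]

# split='train' indexes the ID files, split='test' the OOD files.
train_ds = XView2DistShift(root='data', split='train', id_ood_disaster=splits)
test_ds = XView2DistShift(root='data', split='test', id_ood_disaster=splits)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
sample = train_ds[0]
print(sample['image'].shape)   # (3, H, W) image tensor
print(sample['mask'].unique())  # binary mask: tensor([0, 1])
```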
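The damage-class collapse in `__getitem__` can be sanity-checked in isolation. This toy snippet applies the same two in-place remaps to a tensor containing all five xView2 class ids:

```python
import torch

# xView2 target ids: 0=background, 1=no-damage, 2=minor, 3=major, 4=destroyed
mask = torch.tensor([0, 1, 2, 3, 4])

# Same remap as XView2DistShift.__getitem__: minor-damage joins 'building',
# while major-damage and destroyed fall back to 'background'.
mask[mask == 2] = 1
mask[(mask == 3) | (mask == 4)] = 0

print(mask)  # tensor([0, 1, 1, 0, 0])
```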
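`_initialize_files` relies on the xBD naming scheme, in which each PNG is named `<disaster>_<sample>_<pre|post>_disaster.png`. A quick check of the basename/target derivation, using a hypothetical path:

```python
import os

# Hypothetical path following the xBD naming scheme used by the dataset.
img = 'data/train/images/mexico-earthquake_00000001_pre_disaster.png'

parts = os.path.basename(img).split('_')
basename = f'{parts[0]}_{parts[1]}'  # 'mexico-earthquake_00000001'
pre_post = 'pre' if 'pre_disaster' in img else 'post'
target = f'{basename}_{pre_post}_disaster_target.png'

assert target == 'mexico-earthquake_00000001_pre_disaster_target.png'
```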