diff --git a/src/pl_bolts/datamodules/cityscapes_datamodule.py b/src/pl_bolts/datamodules/cityscapes_datamodule.py
index 351314d83..3f59ec2ad 100644
--- a/src/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/src/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -4,7 +4,6 @@
 from torch.utils.data import DataLoader
 
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
-from pl_bolts.utils.stability import under_review
 from pl_bolts.utils.warnings import warn_missing_pkg
 
 if _TORCHVISION_AVAILABLE:
@@ -14,7 +13,6 @@
     warn_missing_pkg("torchvision")
 
 
-@under_review()
 class CityscapesDataModule(LightningDataModule):
     """
     .. figure:: https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/muenster00-1024x510.png
@@ -85,7 +83,7 @@ def __init__(
             data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
                 are located
             quality_mode: the quality mode to use, either 'fine' or 'coarse'
-            target_type: targets to use, either 'instance' or 'semantic'
+            target_type: targets to use, can be 'instance', 'semantic', 'color', or 'polygon'.
             num_workers: how many workers to use for loading data
             batch_size: number of examples per training/eval step
             seed: random seed to be used for train/val/test splits
@@ -101,8 +99,10 @@ def __init__(
                 "You want to use CityScapes dataset loaded from `torchvision` which is not installed yet."
             )
 
-        if target_type not in ["instance", "semantic"]:
-            raise ValueError(f'Only "semantic" and "instance" target types are supported. Got {target_type}.')
+        if target_type not in ["instance", "semantic", "color", "polygon"]:
+            raise ValueError(
+                f'Only "instance", "semantic", "color", "polygon" target types are supported. Got {target_type}.'
+            )
 
         self.dims = (3, 1024, 2048)
         self.data_dir = data_dir
@@ -121,10 +121,7 @@ def __init__(
 
     @property
    def num_classes(self) -> int:
-        """
-        Return:
-            30
-        """
+        """Returns the number of classes."""
         return 30
 
     def train_dataloader(self) -> DataLoader:
@@ -151,6 +148,33 @@ def train_dataloader(self) -> DataLoader:
             pin_memory=self.pin_memory,
         )
 
+    def train_extra_dataloader(self) -> DataLoader:
+        """Cityscapes extra train dataset.
+
+        Only supported in coarse quality mode.
+        """
+        transforms = self.train_transforms or self._default_transforms()
+        target_transforms = self.target_transforms or self._default_target_transforms()
+
+        dataset = Cityscapes(
+            self.data_dir,
+            split="train_extra",
+            target_type=self.target_type,
+            mode=self.quality_mode,
+            transform=transforms,
+            target_transform=target_transforms,
+            **self.extra_args,
+        )
+
+        return DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+            drop_last=self.drop_last,
+            pin_memory=self.pin_memory,
+        )
+
     def val_dataloader(self) -> DataLoader:
         """Cityscapes val set."""
         transforms = self.val_transforms or self._default_transforms()
@@ -176,7 +200,10 @@ def val_dataloader(self) -> DataLoader:
         )
 
     def test_dataloader(self) -> DataLoader:
-        """Cityscapes test set."""
+        """Cityscapes test set.
+
+        Only supported in fine quality mode.
+        """
         transforms = self.test_transforms or self._default_transforms()
         target_transforms = self.target_transforms or self._default_target_transforms()
 
@@ -208,5 +235,8 @@ def _default_transforms(self) -> Callable:
             ]
         )
 
-    def _default_target_transforms(self) -> Callable:
+    def _default_target_transforms(self) -> Optional[Callable]:
+        if self.target_type == "polygon":
+            return None
+
         return transform_lib.Compose([transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())])
diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py
index b18aa6816..607c8492e 100644
--- a/tests/datamodules/test_datamodules.py
+++ b/tests/datamodules/test_datamodules.py
@@ -24,7 +24,7 @@ def test_dev_datasets(datadir):
     pass
 
 
-def _create_synth_Cityscapes_dataset(path_dir):
+def _create_synth_Cityscapes_dataset(path_dir, img_size=(2048, 1024)):
     """Create synthetic dataset with random images, just to simulate that the dataset have been already
     downloaded."""
     non_existing_citites = ["dummy_city_1", "dummy_city_2"]
@@ -40,32 +40,26 @@ def _create_synth_Cityscapes_dataset(path_dir):
             image_name = f"{base_name}_leftImg8bit.png"
             instance_target_name = f"{base_name}_gtFine_instanceIds.png"
             semantic_target_name = f"{base_name}_gtFine_labelIds.png"
-            Image.new("RGB", (2048, 1024)).save(images_dir / split / city / image_name)
-            Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / instance_target_name)
-            Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / semantic_target_name)
+            color_target_name = f"{base_name}_gtFine_color.png"
+            Image.new("RGB", img_size).save(images_dir / split / city / image_name)
+            Image.new("L", img_size).save(fine_labels_dir / split / city / instance_target_name)
+            Image.new("L", img_size).save(fine_labels_dir / split / city / semantic_target_name)
+            Image.new("RGBA", img_size).save(fine_labels_dir / split / city / color_target_name)
 
 
-def test_cityscapes_datamodule(datadir):
+@pytest.mark.parametrize(
+    ("target_type", "target_size"),
+    [("semantic", (1024, 2048)), ("instance", (1024, 2048)), ("color", (4, 1024, 2048))],
+)
+def test_cityscapes_datamodule(datadir, catch_warnings, target_type: str, target_size: tuple, batch_size: int = 1):
     _create_synth_Cityscapes_dataset(datadir)
 
-    batch_size = 1
-    target_types = ["semantic", "instance"]
-    for target_type in target_types:
-        dm = CityscapesDataModule(datadir, num_workers=0, batch_size=batch_size, target_type=target_type)
-        loader = dm.train_dataloader()
-        img, mask = next(iter(loader))
-        assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
-        assert mask.size() == torch.Size([batch_size, 1024, 2048])
+    dm = CityscapesDataModule(datadir, num_workers=0, batch_size=batch_size, target_type=target_type)
 
-        loader = dm.val_dataloader()
-        img, mask = next(iter(loader))
-        assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
-        assert mask.size() == torch.Size([batch_size, 1024, 2048])
-
-        loader = dm.test_dataloader()
-        img, mask = next(iter(loader))
-        assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
-        assert mask.size() == torch.Size([batch_size, 1024, 2048])
+    for loader in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
+        img, mask = next(iter(loader))
+        assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
+        assert mask.size() == torch.Size([batch_size, *target_size])
 
 
 @pytest.mark.parametrize(("val_split", "train_len"), [(0.2, 48_000), (5_000, 55_000)])
@@ -78,17 +72,9 @@ def test_vision_data_module(datadir, val_split, catch_warnings, train_len):
 def test_data_modules(datadir, catch_warnings, dm_cls):
     """Test datamodules train, val, and test dataloaders outputs have correct shape."""
     dm = _create_dm(dm_cls, datadir)
-    train_loader = dm.train_dataloader()
-    img, _ = next(iter(train_loader))
-    assert img.size() == torch.Size([2, *dm.dims])
-
-    val_loader = dm.val_dataloader()
-    img, _ = next(iter(val_loader))
-    assert img.size() == torch.Size([2, *dm.dims])
-
-    test_loader = dm.test_dataloader()
-    img, _ = next(iter(test_loader))
-    assert img.size() == torch.Size([2, *dm.dims])
+    for loader in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
+        img, _ = next(iter(loader))
+        assert img.size() == torch.Size([2, *dm.dims])
 
 
 def _create_dm(dm_cls, datadir, **kwargs):
@@ -112,17 +98,9 @@ def test_sr_datamodule(datadir):
 def test_emnist_datamodules(datadir, catch_warnings, dm_cls, split):
     """Test BinaryEMNIST and EMNIST datamodules download data and have the correct shape."""
     dm = _create_dm(dm_cls, datadir, split=split)
-    train_loader = dm.train_dataloader()
-    img, _ = next(iter(train_loader))
-    assert img.size() == torch.Size([2, *dm.dims])
-
-    val_loader = dm.val_dataloader()
-    img, _ = next(iter(val_loader))
-    assert img.size() == torch.Size([2, *dm.dims])
-
-    test_loader = dm.test_dataloader()
-    img, _ = next(iter(test_loader))
-    assert img.size() == torch.Size([2, *dm.dims])
+    for loader in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
+        img, _ = next(iter(loader))
+        assert img.size() == torch.Size([2, *dm.dims])
 
 
 @pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
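Usage sketch (not part of the patch): a minimal example of how the reworked datamodule might be driven, assuming Cityscapes has been extracted locally and that CityscapesDataModule is importable from pl_bolts.datamodules; the ./data/cityscapes path and batch size are illustrative placeholders, not values taken from the diff.

from pl_bolts.datamodules import CityscapesDataModule

# Fine annotations with the newly supported "color" target type.
dm = CityscapesDataModule(
    data_dir="./data/cityscapes",  # assumed layout: leftImg8bit/ and gtFine/ live here
    quality_mode="fine",
    target_type="color",  # now accepted alongside "instance", "semantic", "polygon"
    batch_size=2,
    num_workers=0,
)
img, target = next(iter(dm.train_dataloader()))
# Per the updated test, "color" targets come back as (batch, 4, 1024, 2048) RGBA masks.

# Coarse annotations expose the extra training split added in this patch.
dm_coarse = CityscapesDataModule(
    data_dir="./data/cityscapes",  # assumed to also contain gtCoarse/ and leftImg8bit/train_extra/
    quality_mode="coarse",
    target_type="semantic",
    batch_size=2,
    num_workers=0,
)
extra_loader = dm_coarse.train_extra_dataloader()  # coarse quality mode only, per its docstring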