diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index 61a10a95a6..063aa6560a 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -1,10 +1,10 @@ import os +from typing import Callable, Optional, Tuple import numpy as np from torch.utils.data import Dataset from pl_bolts.utils import _PIL_AVAILABLE -from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _PIL_AVAILABLE: @@ -12,25 +12,28 @@ else: # pragma: no cover warn_missing_pkg("PIL", pypi_name="Pillow") -DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) +KITTI_LABELS = tuple(range(-1, 34)) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) -@under_review() class KittiDataset(Dataset): - """ - Note: - You need to have downloaded the Kitti dataset first and provide the path to where it is saved. - You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 - - There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These - useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored - in `valid_labels`. - - The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index` - (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and - `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by - the loss function when comparing with the output. + """KITTI Dataset for sematic segmentation. + + You need to have downloaded the Kitti semantic dataset first and provide the path to where it is saved. + You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 + + There are 34 classes, however not all of them are useful for training (e.g. railings on highways). + Useful classes (the pixel values of these classes) are stored in `valid_labels`, other labels + except useful classes are stored in `void_labels`. + + The class id and valid labels(`ignoreInEval`) can be found in here: + https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py + + Args: + data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' + img_size (tuple): image dimensions (width, height) + valid_labels (tuple): useful classes to include + transform (callable, optional): A function/transform that takes in the numpy array and transforms it. """ IMAGE_PATH = os.path.join("training", "image_2") @@ -40,23 +43,15 @@ def __init__( self, data_dir: str, img_size: tuple = (1242, 376), - void_labels: list = DEFAULT_VOID_LABELS, - valid_labels: list = DEFAULT_VALID_LABELS, - transform=None, + valid_labels: Tuple[int] = DEFAULT_VALID_LABELS, + transform: Optional[Callable] = None, ): - """ - Args: - data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' - img_size: image dimensions (width, height) - void_labels: useless classes to be excluded from training - valid_labels: useful classes to include - """ if not _PIL_AVAILABLE: # pragma: no cover raise ModuleNotFoundError("You want to use `PIL` which is not installed yet.") self.img_size = img_size - self.void_labels = void_labels self.valid_labels = valid_labels + self.void_labels = tuple(label for label in KITTI_LABELS if label not in self.valid_labels) self.ignore_index = 250 self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels)))) self.transform = transform @@ -67,35 +62,39 @@ def __init__( self.img_list = self.get_filenames(self.img_path) self.mask_list = self.get_filenames(self.mask_path) - def __len__(self): + def __len__(self) -> int: return len(self.img_list) - def __getitem__(self, idx): + def __getitem__(self, idx: int): img = Image.open(self.img_list[idx]) img = img.resize(self.img_size) img = np.array(img) - mask = Image.open(self.mask_list[idx]).convert("L") + mask = Image.open(self.mask_list[idx]) mask = mask.resize(self.img_size) mask = np.array(mask) mask = self.encode_segmap(mask) - if self.transform: + if self.transform is not None: img = self.transform(img) return img, mask def encode_segmap(self, mask): - """Sets void classes to zero so they won't be considered for training.""" + """Sets all pixels of the mask with any of the `void_labels` to `ignore_index` (250 by default). + + It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels)` (the number of + valid classes), so it can be used properly by the loss function when comparing with the output. + """ for voidc in self.void_labels: mask[mask == voidc] = self.ignore_index for validc in self.valid_labels: mask[mask == validc] = self.class_map[validc] # remove extra idxs from updated dataset - mask[mask > 18] = self.ignore_index + mask[mask > 33] = self.ignore_index return mask - def get_filenames(self, path): + def get_filenames(self, path: str): """Returns a list of absolute paths to images inside given `path`""" files_list = list() for filename in os.listdir(path): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 882a6cbe1c..7842323b5e 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,3 +1,6 @@ +import os + +import numpy as np import pytest import torch from torch.utils.data import DataLoader, Dataset @@ -7,12 +10,20 @@ BinaryEMNIST, BinaryMNIST, DummyDataset, + KittiDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset, ) from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset from pl_bolts.datasets.sr_mnist_dataset import SRMNIST +from pl_bolts.utils import _PIL_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _PIL_AVAILABLE: + from PIL import Image +else: # pragma: no cover + warn_missing_pkg("PIL", pypi_name="Pillow") @pytest.mark.parametrize("batch_size,num_samples", [(16, 100), (1, 0)]) @@ -161,3 +172,36 @@ def test_binary_emnist_dataset(datadir, split): assert torch.allclose(img.min(), torch.tensor(0.0)) assert torch.allclose(img.max(), torch.tensor(1.0)) assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0])) + + +def test_kitti_dataset(datadir, catch_warnings): + """Test KittiDataset with random generated image.""" + kitti_dir = os.path.join(datadir, "data_semantics") + training_image_dir = os.path.join(kitti_dir, "training/image_2") + training_mask_dir = os.path.join(kitti_dir, "training/semantic") + + if not os.path.exists(kitti_dir): + os.makedirs(kitti_dir) + if not os.path.exists(training_image_dir): + os.makedirs(training_image_dir) + if not os.path.exists(training_mask_dir): + os.makedirs(training_mask_dir) + + img_rand = np.random.rand(377, 1243, 3) * 255 + img_rand = Image.fromarray(img_rand.astype("uint8")).convert("RGB") + img_rand.save(os.path.join(training_image_dir, "000000_10.png")) + + mask_rand = np.random.rand(377, 1243) * 33 + mask_rand = Image.fromarray(mask_rand.astype("uint8")).convert("L") + mask_rand.save(os.path.join(training_mask_dir, "000000_10.png")) + + dl = DataLoader(KittiDataset(data_dir=kitti_dir, transform=transform_lib.ToTensor())) + img, target = next(iter(dl)) + target_idx = list(range(0, 19)) + [250] + + assert img.size() == torch.Size([1, 3, 376, 1242]) + assert target.size() == torch.Size([1, 376, 1242]) + + assert torch.allclose(img.min(), torch.tensor(0.0), atol=0.01) + assert torch.allclose(img.max(), torch.tensor(1.0), atol=0.01) + assert torch.equal(torch.unique(target), torch.tensor(target_idx).to(dtype=torch.uint8))