Skip to content

Commit

Permalink
Revision datasets.kitti_dataset.KittiDataset (#896)
Browse files Browse the repository at this point in the history
Co-authored-by: otaj <[email protected]>
  • Loading branch information
lijm1358 and otaj authored Oct 5, 2022
1 parent 9dd72c4 commit 79c6e24
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 34 deletions.
67 changes: 33 additions & 34 deletions pl_bolts/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
import os
from typing import Callable, Optional, Tuple

import numpy as np
from torch.utils.data import Dataset

from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
from PIL import Image
else: # pragma: no cover
warn_missing_pkg("PIL", pypi_name="Pillow")

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
KITTI_LABELS = tuple(range(-1, 34))
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


@under_review()
class KittiDataset(Dataset):
"""
Note:
You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015
There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
in `valid_labels`.
The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
(250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
`len(valid_labels)` (since that is the number of valid classes), so it can be used properly by
the loss function when comparing with the output.
"""KITTI Dataset for sematic segmentation.
You need to have downloaded the Kitti semantic dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015
There are 34 classes, however not all of them are useful for training (e.g. railings on highways).
Useful classes (the pixel values of these classes) are stored in `valid_labels`; all remaining
labels are stored in `void_labels`.
The class ids and valid labels (`ignoreInEval`) can be found here:
https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
Args:
data_dir (str): path to load the data from, e.g. '/path/to/folder/with/data_semantics/'
img_size (tuple): image dimensions (width, height)
valid_labels (tuple): useful classes to include
transform (callable, optional): A function/transform that takes in the numpy array and transforms it.
"""

IMAGE_PATH = os.path.join("training", "image_2")
Expand All @@ -40,23 +43,15 @@ def __init__(
self,
data_dir: str,
img_size: tuple = (1242, 376),
void_labels: list = DEFAULT_VOID_LABELS,
valid_labels: list = DEFAULT_VALID_LABELS,
transform=None,
valid_labels: Tuple[int] = DEFAULT_VALID_LABELS,
transform: Optional[Callable] = None,
):
"""
Args:
data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
img_size: image dimensions (width, height)
void_labels: useless classes to be excluded from training
valid_labels: useful classes to include
"""
if not _PIL_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `PIL` which is not installed yet.")

self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
self.void_labels = tuple(label for label in KITTI_LABELS if label not in self.valid_labels)
self.ignore_index = 250
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
self.transform = transform
Expand All @@ -67,35 +62,39 @@ def __init__(
self.img_list = self.get_filenames(self.img_path)
self.mask_list = self.get_filenames(self.mask_path)

def __len__(self):
def __len__(self) -> int:
return len(self.img_list)

def __getitem__(self, idx):
def __getitem__(self, idx: int):
img = Image.open(self.img_list[idx])
img = img.resize(self.img_size)
img = np.array(img)

mask = Image.open(self.mask_list[idx]).convert("L")
mask = Image.open(self.mask_list[idx])
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)

if self.transform:
if self.transform is not None:
img = self.transform(img)

return img, mask

def encode_segmap(self, mask):
"""Sets void classes to zero so they won't be considered for training."""
"""Sets all pixels of the mask with any of the `void_labels` to `ignore_index` (250 by default).
It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels) - 1` (one index
per valid class), so it can be used properly by the loss function when comparing with the output.
"""
for voidc in self.void_labels:
mask[mask == voidc] = self.ignore_index
for validc in self.valid_labels:
mask[mask == validc] = self.class_map[validc]
# remove extra idxs from updated dataset
mask[mask > 18] = self.ignore_index
mask[mask > 33] = self.ignore_index
return mask

def get_filenames(self, path):
def get_filenames(self, path: str):
"""Returns a list of absolute paths to images inside given `path`"""
files_list = list()
for filename in os.listdir(path):
Expand Down
44 changes: 44 additions & 0 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os

import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader, Dataset
Expand All @@ -7,12 +10,20 @@
BinaryEMNIST,
BinaryMNIST,
DummyDataset,
KittiDataset,
RandomDataset,
RandomDictDataset,
RandomDictStringDataset,
)
from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
from PIL import Image
else: # pragma: no cover
warn_missing_pkg("PIL", pypi_name="Pillow")


@pytest.mark.parametrize("batch_size,num_samples", [(16, 100), (1, 0)])
Expand Down Expand Up @@ -161,3 +172,36 @@ def test_binary_emnist_dataset(datadir, split):
assert torch.allclose(img.min(), torch.tensor(0.0))
assert torch.allclose(img.max(), torch.tensor(1.0))
assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0]))


def test_kitti_dataset(datadir, catch_warnings):
    """Test KittiDataset on a randomly generated image/mask pair.

    A fake ``data_semantics/training`` tree holding one RGB image and one
    single-channel label mask is written under ``datadir``. Both files are one
    pixel larger in each dimension than the size asserted below, so the
    dataset's resize step is exercised. The sample is then loaded through a
    ``DataLoader`` and checked for shape, image value range and the set of
    encoded target classes.
    """
    kitti_dir = os.path.join(datadir, "data_semantics")
    training_image_dir = os.path.join(kitti_dir, "training", "image_2")
    training_mask_dir = os.path.join(kitti_dir, "training", "semantic")

    # `exist_ok=True` creates any missing parent (including `kitti_dir`) and is
    # a no-op when the directory is already there, so no `exists` pre-check.
    os.makedirs(training_image_dir, exist_ok=True)
    os.makedirs(training_mask_dir, exist_ok=True)

    img_rand = np.random.rand(377, 1243, 3) * 255
    img_rand = Image.fromarray(img_rand.astype("uint8")).convert("RGB")
    img_rand.save(os.path.join(training_image_dir, "000000_10.png"))

    # Scale by 34 so the truncated uint8 labels span the full 0..33 id range.
    # With a factor of 33 the highest id (33, encoded as class 18) would never
    # be written to the mask and could only show up as a resize artifact.
    mask_rand = np.random.rand(377, 1243) * 34
    mask_rand = Image.fromarray(mask_rand.astype("uint8")).convert("L")
    mask_rand.save(os.path.join(training_mask_dir, "000000_10.png"))

    dl = DataLoader(KittiDataset(data_dir=kitti_dir, transform=transform_lib.ToTensor()))
    img, target = next(iter(dl))
    # 19 valid classes are encoded as 0..18; everything else becomes 250.
    target_idx = list(range(19)) + [250]

    assert img.size() == torch.Size([1, 3, 376, 1242])
    assert target.size() == torch.Size([1, 376, 1242])

    assert torch.allclose(img.min(), torch.tensor(0.0), atol=0.01)
    assert torch.allclose(img.max(), torch.tensor(1.0), atol=0.01)
    assert torch.equal(torch.unique(target), torch.tensor(target_idx).to(dtype=torch.uint8))

0 comments on commit 79c6e24

Please sign in to comment.