Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision datasets.kitti_dataset.KittiDataset #896

Merged
merged 7 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_test-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
python --version
pip --version
# python -m pip install --upgrade --user pip
pip install --requirement requirements/devel.txt --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install --requirement requirements/devel.txt --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip list
shell: bash

Expand Down
67 changes: 33 additions & 34 deletions pl_bolts/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
import os
from typing import Callable, Optional, Tuple

import numpy as np
from torch.utils.data import Dataset

from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
from PIL import Image
else: # pragma: no cover
warn_missing_pkg("PIL", pypi_name="Pillow")

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
KITTI_LABELS = tuple(range(-1, 34))
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


@under_review()
class KittiDataset(Dataset):
"""
Note:
You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
in `valid_labels`.

The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
(250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
`len(valid_labels)` (since that is the number of valid classes), so it can be used properly by
the loss function when comparing with the output.
"""KITTI Dataset for sematic segmentation.

You need to have downloaded the Kitti semantic dataset first and provide the path to where it is saved.
You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015

There are 34 classes, however not all of them are useful for training (e.g. railings on highways).
Useful classes (the pixel values of these classes) are stored in `valid_labels`, other labels
except useful classes are stored in `void_labels`.

The class id and valid labels(`ignoreInEval`) can be found in here:
https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py

Args:
data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
img_size (tuple): image dimensions (width, height)
valid_labels (tuple): useful classes to include
transform (callable, optional): A function/transform that takes in the numpy array and transforms it.
"""

IMAGE_PATH = os.path.join("training", "image_2")
Expand All @@ -40,23 +43,15 @@ def __init__(
self,
data_dir: str,
img_size: tuple = (1242, 376),
void_labels: list = DEFAULT_VOID_LABELS,
valid_labels: list = DEFAULT_VALID_LABELS,
transform=None,
valid_labels: Tuple[int] = DEFAULT_VALID_LABELS,
transform: Optional[Callable] = None,
):
"""
Args:
data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
img_size: image dimensions (width, height)
void_labels: useless classes to be excluded from training
valid_labels: useful classes to include
"""
if not _PIL_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `PIL` which is not installed yet.")

self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
self.void_labels = tuple(label for label in KITTI_LABELS if label not in self.valid_labels)
self.ignore_index = 250
self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
self.transform = transform
Expand All @@ -67,35 +62,39 @@ def __init__(
self.img_list = self.get_filenames(self.img_path)
self.mask_list = self.get_filenames(self.mask_path)

def __len__(self):
def __len__(self) -> int:
return len(self.img_list)

def __getitem__(self, idx):
def __getitem__(self, idx: int):
img = Image.open(self.img_list[idx])
img = img.resize(self.img_size)
img = np.array(img)

mask = Image.open(self.mask_list[idx]).convert("L")
mask = Image.open(self.mask_list[idx])
mask = mask.resize(self.img_size)
mask = np.array(mask)
mask = self.encode_segmap(mask)

if self.transform:
if self.transform is not None:
img = self.transform(img)

return img, mask

def encode_segmap(self, mask):
"""Sets void classes to zero so they won't be considered for training."""
"""Sets all pixels of the mask with any of the `void_labels` to `ignore_index` (250 by default).

It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels)` (the number of
valid classes), so it can be used properly by the loss function when comparing with the output.
"""
for voidc in self.void_labels:
mask[mask == voidc] = self.ignore_index
for validc in self.valid_labels:
mask[mask == validc] = self.class_map[validc]
# remove extra idxs from updated dataset
mask[mask > 18] = self.ignore_index
mask[mask > 33] = self.ignore_index
return mask

def get_filenames(self, path):
def get_filenames(self, path: str):
"""Returns a list of absolute paths to images inside given `path`"""
files_list = list()
for filename in os.listdir(path):
Expand Down
44 changes: 44 additions & 0 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os

import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader, Dataset
Expand All @@ -7,12 +10,20 @@
BinaryEMNIST,
BinaryMNIST,
DummyDataset,
KittiDataset,
RandomDataset,
RandomDictDataset,
RandomDictStringDataset,
)
from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
from PIL import Image
else: # pragma: no cover
warn_missing_pkg("PIL", pypi_name="Pillow")


@pytest.mark.parametrize("batch_size,num_samples", [(16, 100), (1, 0)])
Expand Down Expand Up @@ -164,3 +175,36 @@ def test_binary_emnist_dataset(datadir, split):
assert torch.allclose(img.min(), torch.tensor(0.0))
assert torch.allclose(img.max(), torch.tensor(1.0))
assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0]))


def test_kitti_dataset(datadir, catch_warnings):
"""Test KittiDataset with random generated image."""
kitti_dir = os.path.join(datadir, "data_semantics")
training_image_dir = os.path.join(kitti_dir, "training/image_2")
training_mask_dir = os.path.join(kitti_dir, "training/semantic")

if not os.path.exists(kitti_dir):
os.makedirs(kitti_dir)
if not os.path.exists(training_image_dir):
os.makedirs(training_image_dir)
if not os.path.exists(training_mask_dir):
os.makedirs(training_mask_dir)

img_rand = np.random.rand(377, 1243, 3) * 255
img_rand = Image.fromarray(img_rand.astype("uint8")).convert("RGB")
img_rand.save(os.path.join(training_image_dir, "000000_10.png"))

mask_rand = np.random.rand(377, 1243) * 33
mask_rand = Image.fromarray(mask_rand.astype("uint8")).convert("L")
mask_rand.save(os.path.join(training_mask_dir, "000000_10.png"))

dl = DataLoader(KittiDataset(data_dir=kitti_dir, transform=transform_lib.ToTensor()))
img, target = next(iter(dl))
target_idx = list(range(0, 19)) + [250]

assert img.size() == torch.Size([1, 3, 376, 1242])
assert target.size() == torch.Size([1, 376, 1242])

assert torch.allclose(img.min(), torch.tensor(0.0), atol=0.01)
assert torch.allclose(img.max(), torch.tensor(1.0), atol=0.01)
assert torch.equal(torch.unique(target), torch.tensor(target_idx).to(dtype=torch.uint8))