Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign PatchWSIDataset #4152

Merged
merged 17 commits into from
Apr 22, 2022
8 changes: 8 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,11 @@ OpenSlideWSIReader
~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.OpenSlideWSIReader
:members:

Whole slide image datasets
--------------------------

PatchWSIDataset
~~~~~~~~~~~~~~~
.. autoclass:: monai.data.PatchWSIDataset
:members:
1 change: 1 addition & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,5 @@
worker_init_fn,
zoom_affine,
)
from .wsi_datasets import PatchWSIDataset
from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, WSIReader
135 changes: 135 additions & 0 deletions monai/data/wsi_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np

from monai.data import Dataset
from monai.data.wsi_reader import BaseWSIReader, WSIReader
from monai.transforms import apply_transform
from monai.utils import ensure_tuple_rep

__all__ = ["PatchWSIDataset"]


class PatchWSIDataset(Dataset):
    """
    This dataset extracts patches from whole slide images (without loading the whole image).
    It also reads labels for each patch and provides each patch with its associated class labels.

    Args:
        data: the list of input samples including image, location, and label (see the note below for more details).
        size: the size of patch to be extracted from the whole slide image.
            If it is not provided, each sample must define its own ``"size"``.
        level: the level at which the patches are to be extracted. If it is provided, it overrides any
            ``"level"`` defined in the samples; otherwise the ``"level"`` of each sample is used (default to 0).
        transform: transforms to be executed on input data.
        reader: the module to be used for loading whole slide imaging,
            - if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
            - if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
            - if `reader` is an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

    Raises:
        ValueError: when `reader` is not a string, a `BaseWSIReader` subclass, or a `BaseWSIReader` instance.
        ValueError: when a patch is requested and neither `size` nor the sample's ``"size"`` is defined.

    Note:
        The input data has the following form as an example:

        .. code-block:: python

            [
                {"image": "path/to/image1.tiff", "location": [200, 500], "label": 0},
                {"image": "path/to/image2.tiff", "location": [100, 700], "label": 1}
            ]

    """

    def __init__(
        self,
        data: List,
        size: Optional[Union[int, Tuple[int, int]]] = None,
        level: Optional[int] = None,
        transform: Optional[Callable] = None,
        reader="cuCIM",
        **kwargs,
    ):
        super().__init__(data, transform)

        # Ensure patch size is a two dimensional tuple (None defers to the size of each sample)
        self.size = None if size is None else ensure_tuple_rep(size, 2)

        # A dataset-level value overrides the level of every sample if it is not None
        self.level = level
        # The reader itself is created with level 0 if level is not provided
        reader_level = 0 if level is None else level

        # Setup the WSI reader
        self.wsi_reader: Union[WSIReader, BaseWSIReader]
        # The backend name is only known when the reader is given as a string
        self.backend = ""
        if isinstance(reader, str):
            self.backend = reader.lower()
            self.wsi_reader = WSIReader(backend=self.backend, level=reader_level, **kwargs)
        elif inspect.isclass(reader) and issubclass(reader, BaseWSIReader):
            self.wsi_reader = reader(level=reader_level, **kwargs)
        elif isinstance(reader, BaseWSIReader):
            # An already constructed reader is used as is (level and kwargs are not applied to it)
            self.wsi_reader = reader
        else:
            raise ValueError(f"Unsupported reader type: {reader}.")

        # Initialize an empty whole slide image object dict (image path -> opened WSI object)
        self.wsi_object_dict: Dict = {}

    def _get_wsi_object(self, sample: Dict):
        """Return the whole slide image object of the sample, reading it only if it is not stored yet."""
        image_path = sample["image"]
        if image_path not in self.wsi_object_dict:
            self.wsi_object_dict[image_path] = self.wsi_reader.read(image_path)
        return self.wsi_object_dict[image_path]

    def _get_label(self, sample: Dict):
        """Return the label of the sample as a float32 numpy array."""
        return np.array(sample["label"], dtype=np.float32)

    def _get_location(self, sample: Dict):
        """Return the sample location shifted back by half of the patch size in each dimension."""
        size = self._get_size(sample)
        return [sample["location"][i] - patch_dim // 2 for i, patch_dim in enumerate(size)]

    def _get_level(self, sample: Dict):
        """Return the dataset-level `level` if it is set, otherwise the sample's level (default to 0)."""
        if self.level is None:
            return sample.get("level", 0)
        return self.level

    def _get_size(self, sample: Dict):
        """Return the dataset-level patch size if it is set, otherwise the sample's size as a 2-tuple.

        Raises:
            ValueError: when neither the dataset nor the sample defines a patch size.
        """
        if self.size is not None:
            return self.size
        size = sample.get("size")
        if size is None:
            # Fail early with a clear message instead of a TypeError on `None // 2` later on
            raise ValueError(
                "The patch size must be provided either via the `size` argument of the dataset "
                "or the 'size' key of each sample."
            )
        return ensure_tuple_rep(size, 2)

    def _get_data(self, sample: Dict):
        """Extract the patch of the sample with the WSI reader and return the result of `get_data`."""
        # Don't store OpenSlide objects to avoid issues with OpenSlide internal cache
        if self.backend == "openslide":
            self.wsi_object_dict = {}
        wsi_obj = self._get_wsi_object(sample)
        location = self._get_location(sample)
        level = self._get_level(sample)
        size = self._get_size(sample)
        return self.wsi_reader.get_data(wsi=wsi_obj, location=location, size=size, level=level)

    def _transform(self, index: int):
        """Build the patch dictionary (image, label, metadata) for `index` and apply the transform, if any."""
        # Get a single entry of data
        sample: Dict = self.data[index]
        # Extract patch image and associated metadata
        image, metadata = self._get_data(sample)
        # Get the label
        label = self._get_label(sample)

        # Put all patch information together and apply transforms
        patch = {"image": image, "label": label, "metadata": metadata}
        return apply_transform(self.transform, patch) if self.transform else patch
169 changes: 169 additions & 0 deletions tests/test_patch_wsi_dataset_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from unittest import skipUnless

import numpy as np
from numpy.testing import assert_array_equal
from parameterized import parameterized

from monai.data import PatchWSIDataset
from monai.data.wsi_reader import CuCIMWSIReader, OpenSlideWSIReader
from monai.utils import optional_import
from tests.utils import download_url_or_skip_test, testing_data_config

# Probe the optional backends; each `optional_import` call is unpacked into (object, availability flag).
cucim, has_cucim = optional_import("cucim")
# cucim only counts as available when the imported module also exposes `CuImage`.
has_cucim = has_cucim and hasattr(cucim, "CuImage")
openslide, has_osl = optional_import("openslide")
imwrite, has_tiff = optional_import("tifffile", name="imwrite")
_, has_codec = optional_import("imagecodecs")
# tifffile only counts as available when imagecodecs can be imported as well.
has_tiff = has_tiff and has_codec

# Key of the whole slide test image in the testing data config, and the URL looked up for it.
FILE_KEY = "wsi_img"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
# Local target path: <this directory>/testing_data/temp_<basename of the URL>.tiff
base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension)

# Each test case is [dataset keyword arguments, expected output]; the expected output holds the
# expected "image" and "label" arrays (a list of them for the multi-sample cases).

# Dataset-level size as a tuple, level given in the sample.
TEST_CASE_0 = [
{"data": [{"image": FILE_PATH, "location": [0, 0], "label": [1], "level": 0}], "size": (1, 1)},
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
]

# Dataset-level size as a tuple, dataset-level level of 1.
TEST_CASE_0_L1 = [
{"data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "size": (1, 1), "level": 1},
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
]

# NOTE(review): identical to TEST_CASE_0_L1 (level is 1 here too); the name suggests level 2
# may have been intended -- confirm before changing, as the expected values would need checking.
TEST_CASE_0_L2 = [
{"data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "size": (1, 1), "level": 1},
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
]
# Size given in the sample only; no dataset-level size or level.
TEST_CASE_1 = [
{"data": [{"image": FILE_PATH, "location": [0, 0], "size": 1, "label": [1]}]},
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
]

# Dataset-level size as a single int, dataset-level level of 0.
TEST_CASE_2 = [
{"data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}], "size": 1, "level": 0},
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
]

# Nested (multi-dimensional) label.
TEST_CASE_3 = [
{"data": [{"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]}], "size": 1},
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
]

# Two samples at the same location with different labels, dataset-level size.
TEST_CASE_4 = [
{
"data": [
{"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]},
{"image": FILE_PATH, "location": [0, 0], "label": [[[1, 0], [0, 0]]]},
],
"size": 1,
},
[
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1, 0], [0, 0]]])},
],
]

# Two samples at different locations, with size and level given per sample.
TEST_CASE_5 = [
{
"data": [
{"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]], "size": 1, "level": 1},
{"image": FILE_PATH, "location": [100, 100], "label": [[[1, 0], [0, 0]]], "size": 1, "level": 1},
]
},
[
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
{"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "label": np.array([[[1, 0], [0, 0]]])},
],
]


@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
def setUpModule():  # noqa: N802
    """Fetch the whole slide test image to `FILE_PATH` via `download_url_or_skip_test`.

    The hash type and hash value passed along are looked up in the testing data
    config under the same `FILE_KEY` entry as the URL.
    """
    download_url_or_skip_test(
        FILE_URL,
        FILE_PATH,
        hash_type=testing_data_config("images", FILE_KEY, "hash_type"),
        hash_val=testing_data_config("images", FILE_KEY, "hash_val"),
    )


class PatchWSIDatasetTests:
    """Container for the backend-agnostic test case.

    NOTE(review): nesting `Tests` inside this plain class presumably keeps the base case
    (which has `backend = None`) from being collected on its own -- confirm with the test runner.
    Concrete subclasses set `backend` in `setUpClass`.
    """

    class Tests(unittest.TestCase):
        # Name of the WSI backend under test ("cucim" or "openslide"); set by subclasses.
        backend = None

        def _get_reader_class(self):
            """Return the reader class that matches `self.backend`.

            Raises:
                ValueError: when `self.backend` is neither "openslide" nor "cucim".
            """
            if self.backend == "openslide":
                return OpenSlideWSIReader
            if self.backend == "cucim":
                return CuCIMWSIReader
            # f-string so that the offending backend actually shows up in the message
            raise ValueError(f"Unsupported backend: {self.backend}")

        def _assert_sample_equal(self, sample, expected):
            """Check that the label and image of `sample` match `expected` in shape and in value."""
            self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
            self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
            self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
            self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))

        @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
        def test_read_patches_str(self, input_parameters, expected):
            """The reader is given as a backend name."""
            dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
            self._assert_sample_equal(dataset[0], expected)

        @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
        def test_read_patches_class(self, input_parameters, expected):
            """The reader is given as a reader class."""
            reader = self._get_reader_class()
            dataset = PatchWSIDataset(reader=reader, **input_parameters)
            self._assert_sample_equal(dataset[0], expected)

        @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
        def test_read_patches_object(self, input_parameters, expected):
            """The reader is given as an already constructed reader instance."""
            reader_class = self._get_reader_class()
            # The instance is built with the dataset-level level of the test case (0 if absent)
            reader = reader_class(level=input_parameters.get("level", 0))
            dataset = PatchWSIDataset(reader=reader, **input_parameters)
            self._assert_sample_equal(dataset[0], expected)

        @parameterized.expand([TEST_CASE_4, TEST_CASE_5])
        def test_read_patches_str_multi(self, input_parameters, expected):
            """Several samples in one dataset, with the reader given as a backend name."""
            dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
            for i in range(len(dataset)):
                # Fetch each sample once instead of re-reading the patch for every assertion
                self._assert_sample_equal(dataset[i], expected[i])


@skipUnless(has_cucim, "Requires cucim")
class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests):
    """Run the shared `PatchWSIDataset` tests with the "cucim" backend."""

    @classmethod
    def setUpClass(cls):
        # Select the backend that the inherited tests pass to the dataset
        cls.backend = "cucim"


# The skip reason names the dependency that is actually checked (openslide, via `has_osl`)
@skipUnless(has_osl, "Requires openslide")
class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests):
    """Run the shared `PatchWSIDataset` tests with the "openslide" backend."""

    @classmethod
    def setUpClass(cls):
        # Select the backend that the inherited tests pass to the dataset
        cls.backend = "openslide"


if __name__ == "__main__":
    # Run the tests of this module when the file is executed as a script
    unittest.main()