diff --git a/sscma/datasets/dataset_wrappers.py b/sscma/datasets/dataset_wrappers.py index 6949b18e..2c324cdb 100644 --- a/sscma/datasets/dataset_wrappers.py +++ b/sscma/datasets/dataset_wrappers.py @@ -1,8 +1,11 @@ +from typing import Union + from torch.utils.data.dataset import ConcatDataset as _ConcatDataset from sscma.registry import DATASETS +from mmdet.datasets import CocoDataset -@DATASETS.register_module(_ConcatDataset) +@DATASETS.register_module() class SemiDataset(_ConcatDataset): """ For merging real labeled and pseudo-labeled datasets in semi-supervised. @@ -13,12 +16,12 @@ class SemiDataset(_ConcatDataset): """ def __init__(self, sup_dataset: dict, unsup_dataset: dict, **kwargs) -> None: - self._sup_dataset = DATASETS.build(sup_dataset) - self._unsup_dataset = DATASETS.build(unsup_dataset) + self._sup_dataset: CocoDataset = DATASETS.build(sup_dataset) + self._unsup_dataset: CocoDataset = DATASETS.build(unsup_dataset) super(SemiDataset, self).__init__((self._sup_dataset, self._unsup_dataset)) - self.CLASSES = self.sup_dataset.CLASSES + self.CLASSES: Union[list, tuple] = self.sup_dataset.METAINFO['classes'] @property def sup_dataset(self):