diff --git a/albumentations/__init__.py b/albumentations/__init__.py index d22cf383a..e4d50b21c 100644 --- a/albumentations/__init__.py +++ b/albumentations/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.4.5" +__version__ = "1.4.6" from .augmentations import * from .core.composition import * diff --git a/albumentations/core/composition.py b/albumentations/core/composition.py index 2edac0cfd..7a4e9dfc3 100644 --- a/albumentations/core/composition.py +++ b/albumentations/core/composition.py @@ -1,7 +1,7 @@ import random import warnings from collections import defaultdict -from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Sequence, Union, cast import cv2 import numpy as np @@ -65,9 +65,7 @@ def __init__(self, transforms: TransformsSeqType, p: float): self.replay_mode = False self.applied_in_replay = False self._additional_targets: Dict[str, str] = {} - self._available_keys: Set[str] = set() self.processors: Dict[str, Union[BboxProcessor, KeypointsProcessor]] = {} - self._set_keys() def __iter__(self) -> Iterator[TransformType]: return iter(self.transforms) @@ -88,10 +86,6 @@ def __repr__(self) -> str: def additional_targets(self) -> Dict[str, str]: return self._additional_targets - @property - def available_keys(self) -> Set[str]: - return self._available_keys - def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str: args = {k: v for k, v in self.to_dict_private().items() if not (k.startswith("__") or k == "transforms")} repr_string = self.__class__.__name__ + "([" @@ -133,22 +127,11 @@ def add_targets(self, additional_targets: Optional[Dict[str, str]]) -> None: f"Trying to overwrite existed additional targets. " f"Key={k} Exists={self._additional_targets[k]} New value: {v}", ) - self._additional_targets.update(additional_targets) + self._additional_targets.update(additional_targets) for t in self.transforms: t.add_targets(additional_targets) for proc in self.processors.values(): proc.add_targets(additional_targets) - self._set_keys() - - def _set_keys(self) -> None: - """Set _available_keys""" - for t in self.transforms: - self._available_keys.update(t.available_keys) - if self.processors: - self._available_keys.update(["labels"]) - for proc in self.processors.values(): - if proc.params.label_fields: - self._available_keys.update(proc.params.label_fields) def set_deterministic(self, flag: bool, save_key: str = "replay") -> None: for t in self.transforms: @@ -209,7 +192,6 @@ def __init__( self._disable_check_args_for_transforms(self.transforms) self.is_check_shapes = is_check_shapes - self._always_apply = get_always_apply(self.transforms) # transforms list that always apply self._check_each_transform = tuple( # processors that checks after each transform proc for proc in self.processors.values() if getattr(proc.params, "check_each_transform", False) ) @@ -229,22 +211,18 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[s if args: msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)" raise KeyError(msg) + if self.is_check_args: + self._check_args(**data) if not isinstance(force_apply, (bool, int)): msg = "force_apply must have bool or int type" raise TypeError(msg) need_to_run = force_apply or random.random() < self.p - if not need_to_run and not self._always_apply: - return data - - transforms = self.transforms if need_to_run else self._always_apply - - if self.is_check_args: - self._check_args(**data) for p in self.processors.values(): p.ensure_data_valid(data) + transforms = self.transforms if need_to_run else get_always_apply(self.transforms) for p in self.processors.values(): p.preprocess(data) @@ -308,9 +286,6 @@ def _check_args(self, **kwargs: Any) -> None: check_keypoints_param = ["keypoints"] shapes = [] for data_name, data in kwargs.items(): - if data_name not in self._available_keys and data_name not in ["mask", "masks"]: - msg = f"Key {data_name} is not in available keys." - raise ValueError(msg) internal_data_name = self._additional_targets.get(data_name, data_name) if internal_data_name in checked_single: if not isinstance(data, np.ndarray): @@ -518,7 +493,6 @@ def __init__( super().__init__(transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes) self.set_deterministic(True, save_key=save_key) self.save_key = save_key - self._available_keys.add(save_key) def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Dict[str, Any]: kwargs[self.save_key] = defaultdict(dict) diff --git a/albumentations/core/transforms_interface.py b/albumentations/core/transforms_interface.py index 8504f89fc..82760d517 100644 --- a/albumentations/core/transforms_interface.py +++ b/albumentations/core/transforms_interface.py @@ -1,6 +1,6 @@ import random from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast from warnings import warn import cv2 @@ -43,12 +43,8 @@ class CombinedMeta(SerializableMeta, ValidatedTransformMeta): class BasicTransform(Serializable, metaclass=CombinedMeta): - _targets: Union[Tuple[Targets, ...], Targets] # targets that this transform can work on - _available_keys: Set[str] # targets that this transform, as string, lower-cased - _key2func: Dict[ - str, - Callable[..., Any], - ] # mapping for targets (plus additional targets) and methods for which they depend + # `_targets` defines the types of targets (e.g., image, mask) that the transform can be applied to. + _targets: Union[Tuple[Targets, ...], Targets] call_backup = None interpolation: int fill_value: ColorType @@ -68,8 +64,6 @@ def __init__(self, always_apply: bool = False, p: float = 0.5): self._additional_targets: Dict[str, str] = {} # replay mode params self.params: Dict[Any, Any] = {} - self._key2func = {} - self._set_keys() def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any: if args: @@ -103,11 +97,12 @@ def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) - params = self.update_params(params, **kwargs) res = {} for key, arg in kwargs.items(): - if key in self._key2func and arg is not None: - target_function = self._key2func[key] - res[key] = target_function(arg, **params) + if arg is not None: + target_function = self._get_target_function(key) + target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])} + res[key] = target_function(arg, **dict(params, **target_dependencies)) else: - res[key] = arg + res[key] = None return res def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform": @@ -130,6 +125,14 @@ def __repr__(self) -> str: state.update(self.get_transform_init_args()) return f"{self.__class__.__name__}({format_args(state)})" + def _get_target_function(self, key: str) -> Callable[..., Any]: + """Returns function to process target""" + transform_key = key + if key in self._additional_targets: + transform_key = self._additional_targets.get(key, key) + + return self.targets.get(transform_key, lambda x, **p: x) + def apply(self, img: np.ndarray, *args: Any, **params: Any) -> np.ndarray: """Apply transform on image.""" raise NotImplementedError @@ -146,23 +149,6 @@ def targets(self) -> Dict[str, Callable[..., Any]]: # >> {"masks": self.apply_to_masks} raise NotImplementedError - def _set_keys(self) -> None: - """Set _available_keys""" - if not hasattr(self, "_targets"): - self._available_keys = set() - else: - self._available_keys = { - target.value.lower() - for target in (self._targets if isinstance(self._targets, tuple) else [self._targets]) - } - self._available_keys.update(self.targets.keys()) - self._key2func = {key: self.targets[key] for key in self._available_keys if key in self.targets} - - @property - def available_keys(self) -> Set[str]: - """Returns set of available keys""" - return self._available_keys - def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: """Update parameters with transform specific params""" if hasattr(self, "interpolation"): @@ -174,6 +160,10 @@ def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any] params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]}) return params + @property + def target_dependence(self) -> Dict[str, Any]: + return {} + def add_targets(self, additional_targets: Dict[str, str]) -> None: """Add targets to transform them the same way as one of existing targets ex: {'target_image': 'image'} @@ -184,16 +174,7 @@ def add_targets(self, additional_targets: Dict[str, str]) -> None: additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'} """ - for k, v in additional_targets.items(): - if k in self._additional_targets and v != self._additional_targets[k]: - raise ValueError( - f"Trying to overwrite existed additional targets. " - f"Key={k} Exists={self._additional_targets[k]} New value: {v}", - ) - if v in self._available_keys: - self._additional_targets[k] = v - self._key2func[k] = self.targets[v] - self._available_keys.add(k) + self._additional_targets = {**self._additional_targets, **additional_targets} @property def targets_as_params(self) -> List[str]: diff --git a/albumentations/pytorch/transforms.py b/albumentations/pytorch/transforms.py index 853032fad..d5ac1f8ff 100644 --- a/albumentations/pytorch/transforms.py +++ b/albumentations/pytorch/transforms.py @@ -4,7 +4,6 @@ import torch from albumentations.core.transforms_interface import BasicTransform -from albumentations.core.types import Targets __all__ = ["ToTensorV2"] @@ -24,8 +23,6 @@ class ToTensorV2(BasicTransform): """ - _targets = (Targets.IMAGE, Targets.MASK) - def __init__(self, transpose_mask: bool = False, always_apply: bool = True, p: float = 1.0): super().__init__(always_apply=always_apply, p=p) self.transpose_mask = transpose_mask @@ -54,3 +51,6 @@ def apply_to_masks(self, masks: List[np.ndarray], **params: Any) -> List[torch.T def get_transform_init_args_names(self) -> Tuple[str, ...]: return ("transpose_mask",) + + def get_params_dependent_on_targets(self, params: Any) -> Dict[str, Any]: + return {} diff --git a/tests/test_core.py b/tests/test_core.py index f75ee2b1f..9f569458e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -53,8 +53,8 @@ def test_one_or_other(): def test_compose(): - first = MagicMock(available_keys={"image"}) - second = MagicMock(available_keys={"image"}) + first = MagicMock() + second = MagicMock() augmentation = Compose([first, second], p=1) image = np.ones((8, 8)) augmentation(image=image) @@ -70,8 +70,8 @@ def oneof_always_apply_crash(): def test_always_apply(): - first = MagicMock(always_apply=True, available_keys={"image"}) - second = MagicMock(always_apply=False, available_keys={"image"}) + first = MagicMock(always_apply=True) + second = MagicMock(always_apply=False) augmentation = Compose([first, second], p=0) image = np.ones((8, 8)) augmentation(image=image) @@ -80,7 +80,7 @@ def test_always_apply(): def test_one_of(): - transforms = [Mock(p=1, available_keys={"image"}) for _ in range(10)] + transforms = [Mock(p=1) for _ in range(10)] augmentation = OneOf(transforms, p=1) image = np.ones((8, 8)) augmentation(image=image) @@ -90,7 +90,7 @@ def test_one_of(): @pytest.mark.parametrize("N", [1, 2, 5, 10]) @pytest.mark.parametrize("replace", [True, False]) def test_n_of(N, replace): - transforms = [Mock(p=1, side_effect=lambda **kw: {"image": kw["image"]}, available_keys={"image"}) for _ in range(10)] + transforms = [Mock(p=1, side_effect=lambda **kw: {"image": kw["image"]}) for _ in range(10)] augmentation = SomeOf(transforms, N, p=1, replace=replace) image = np.ones((8, 8)) augmentation(image=image) @@ -100,7 +100,7 @@ def test_n_of(N, replace): def test_sequential(): - transforms = [Mock(side_effect=lambda **kw: kw, available_keys={"image"}) for _ in range(10)] + transforms = [Mock(side_effect=lambda **kw: kw) for _ in range(10)] augmentation = Sequential(transforms, p=1) image = np.ones((8, 8)) augmentation(image=image) @@ -254,13 +254,13 @@ def test_named_args(): ], ) def test_targets_type_check(targets, additional_targets, err_message): - aug = Compose([A.NoOp()], additional_targets=additional_targets) + aug = Compose([], additional_targets=additional_targets) with pytest.raises(TypeError) as exc_info: aug(**targets) assert str(exc_info.value) == err_message - aug = Compose([A.NoOp()]) + aug = Compose([]) aug.add_targets(additional_targets) with pytest.raises(TypeError) as exc_info: aug(**targets) @@ -353,7 +353,7 @@ def test_check_each_transform(targets, bbox_params, keypoint_params, expected): @pytest.mark.parametrize("image", IMAGES) def test_bbox_params_is_not_set(image, bboxes): - t = Compose([A.NoOp(p=1.0)]) + t = Compose([]) with pytest.raises(ValueError) as exc_info: t(image=image, bboxes=bboxes) assert str(exc_info.value) == "bbox_params must be specified for bbox transformations" @@ -394,7 +394,7 @@ def test_choice_inner_compositions(transforms): "transforms", [ Compose([ChannelShuffle(p=1)], p=1), - # Compose([ChannelShuffle(p=0)], p=0), # p=0, never calls, no process for data + Compose([ChannelShuffle(p=0)], p=0), ], ) def test_contiguous_output(transforms): @@ -421,7 +421,7 @@ def test_contiguous_output(transforms): ], ) def test_compose_image_mask_equal_size(targets): - transforms = Compose([A.NoOp()]) + transforms = Compose([]) with pytest.raises(ValueError) as exc_info: transforms(**targets) @@ -432,7 +432,7 @@ def test_compose_image_mask_equal_size(targets): "of Compose class (do it only if you are sure about your data consistency)." ) # test after disabling shapes check - transforms = Compose([A.NoOp()], is_check_shapes=False) + transforms = Compose([], is_check_shapes=False) transforms(**targets) @@ -493,39 +493,88 @@ def test_sequential_multiple_transformations(image, aug): assert np.array_equal(result['mask'], mask) -def test_compose_non_available_keys() -> None: - """Check that non available keys raises error, except `mask` and `masks`""" - transform = A.Compose( - [MagicMock(available_keys={"image"}),], - ) - image = np.empty([10, 10, 3], dtype=np.uint8) - mask = np.empty([10, 10], dtype=np.uint8) - _res = transform(image=image, mask=mask) - _res = transform(image=image, masks=[mask]) - with pytest.raises(ValueError) as exc_info: - _res = transform(image=image, image_2=mask) - - expected_msg = "Key image_2 is not in available keys." - assert str(exc_info.value) == expected_msg +@pytest.mark.parametrize( + "transforms", + [ + [ # image only + A.Blur(p=1), + A.MedianBlur(p=1), + A.ToGray(p=1), + A.CLAHE(p=1), + A.RandomBrightnessContrast(p=1), + A.RandomGamma(p=1), + A.ImageCompression(quality_lower=75, p=1), + ], + [ # with dual + A.Blur(p=1), + A.MedianBlur(p=1), + A.ToGray(p=1), + A.CLAHE(p=1), + A.RandomBrightnessContrast(p=1), + A.RandomGamma(p=1), + A.ImageCompression(quality_lower=75, p=1), + A.Crop(x_max=50, y_max=50), + ] + ] +) +@pytest.mark.parametrize( + ["compose_args", "args"], + [ + [ + {}, + {"image": np.empty([100, 100, 3], dtype=np.uint8)} + ], + [ + {}, + { + "image": np.empty([100, 100, 3], dtype=np.uint8), + "mask": np.empty([100, 100, 3], dtype=np.uint8), + } + ], + [ + {}, + { + "image": np.empty([100, 100, 3], dtype=np.uint8), + "masks": [np.empty([100, 100, 3], dtype=np.uint8)] * 3, + } + ], + [ + dict(bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"])), + { + "image": np.empty([100, 100, 3], dtype=np.uint8), + "bboxes": [[0.5, 0.5, 0.1, 0.1]], + "class_labels": [1], + } + ], + [ + dict(keypoint_params=A.KeypointParams(format="xy", label_fields=["class_labels"])), + { + "image": np.empty([100, 100, 3], dtype=np.uint8), + "keypoints": [[10, 20]], + "class_labels": [1], + } + ], + [ + dict( + bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels_1"]), + keypoint_params=A.KeypointParams(format="xy", label_fields=["class_labels_2"]) + ), + { + "image": np.empty([100, 100, 3], dtype=np.uint8), + "mask": np.empty([100, 100, 3], dtype=np.uint8), + "bboxes": [[0.5, 0.5, 0.1, 0.1]], + "class_labels_1": [1], + "keypoints": [[10, 20]], + "class_labels_2": [1], + } + ], + ] +) +def test_common_pipeline_validity(transforms: list, compose_args: dict, args: dict): + # Just check that everything is fine - no errors + pipeline = A.Compose(transforms, **compose_args) -def test_compose_without_keys() -> None: - """Check that absent of key not raises error""" - image = np.empty([10, 10, 3], dtype=np.uint8) - keypoints = [[1, 1], [7, 7]] - bboxes = [[0, 0, 7, 7, 0],] - transform = A.Compose( - [A.NoOp(),], - keypoint_params=A.KeypointParams(format="xy"), - bbox_params=A.BboxParams(format="pascal_voc"), - ) - res = transform(image=image, keypoints=keypoints, bboxes=bboxes) - assert "keypoints" in res - assert "bboxes" in res - res = transform(image=image) - assert "keypoints" not in res - assert "bboxes" not in res - res = transform(image=image, keypoints=[]) - assert res["keypoints"] == [] - res = transform(image=image, bboxes=[]) - assert res["bboxes"] == [] + res = pipeline(**args) + for k in args.keys(): + assert k in res