diff --git a/albumentations/__init__.py b/albumentations/__init__.py index 00f0d5d55..a12ef7bfa 100644 --- a/albumentations/__init__.py +++ b/albumentations/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.4.0" +__version__ = "1.4.1" from .augmentations import * from .core.composition import * diff --git a/albumentations/augmentations/mixing/transforms.py b/albumentations/augmentations/mixing/transforms.py index 639c04f76..0155721ac 100644 --- a/albumentations/augmentations/mixing/transforms.py +++ b/albumentations/augmentations/mixing/transforms.py @@ -1,5 +1,6 @@ import random -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Sequence, Tuple, Union +import types +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Sequence, Tuple, Union from warnings import warn import numpy as np @@ -61,8 +62,8 @@ class MixUp(ReferenceBasedTransform): def __init__( self, - reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[ReferenceImage]]] = None, - read_fn: Callable[[ReferenceImage], Dict[str, Any]] = lambda x: {"image": x, "mask": None, "class_label": None}, + reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]] = None, + read_fn: Callable[[ReferenceImage], Any] = lambda x: {"image": x, "mask": None, "class_label": None}, alpha: float = 0.4, always_apply: bool = False, p: float = 0.5, @@ -79,8 +80,13 @@ def __init__( if reference_data is None: warn("No reference data provided for MixUp. This transform will act as a no-op.") # Create an empty generator - elif isinstance(reference_data, Iterable) and not isinstance(reference_data, str): - self.reference_data = reference_data + self.reference_data: List[Any] = [] + elif ( + isinstance(reference_data, types.GeneratorType) + or isinstance(reference_data, Iterable) + and not isinstance(reference_data, str) + ): + self.reference_data = reference_data # type: ignore[assignment] else: msg = "reference_data must be a list, tuple, generator, or None." raise TypeError(msg) @@ -120,19 +126,28 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]: return "reference_data", "alpha" def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]: - if self.reference_data and isinstance(self.reference_data, Sequence): - mix_idx = random.randint(0, len(self.reference_data) - 1) - mix_data = self.reference_data[mix_idx] - elif self.reference_data and isinstance(self.reference_data, Iterator): + mix_data = None + # Check if reference_data is not empty and is a sequence (list, tuple, np.array) + if isinstance(self.reference_data, Sequence) and not isinstance(self.reference_data, (str, bytes)): + if len(self.reference_data) > 0: # Additional check to ensure it's not empty + mix_idx = random.randint(0, len(self.reference_data) - 1) + mix_data = self.reference_data[mix_idx] + # Check if reference_data is an iterator or generator + elif isinstance(self.reference_data, Iterator): try: - mix_data = next(self.reference_data) # Get the next item from the iterator + mix_data = next(self.reference_data) # Attempt to get the next item except StopIteration: warn( "Reference data iterator/generator has been exhausted. " "Further mixing augmentations will not be applied.", RuntimeWarning, ) - return {"mix_data": None, "mix_coef": 1} - mix_coef = beta(self.alpha, self.alpha) if mix_data else 1 + return {"mix_data": {}, "mix_coef": 1} - return {"mix_data": self.read_fn(mix_data) if mix_data else None, "mix_coef": mix_coef} + # If mix_data is None or empty after the above checks, return default values + if mix_data is None: + return {"mix_data": {}, "mix_coef": 1} + + # If mix_data is not None, calculate mix_coef and apply read_fn + mix_coef = beta(self.alpha, self.alpha) # Assuming beta is defined elsewhere + return {"mix_data": self.read_fn(mix_data), "mix_coef": mix_coef} diff --git a/tests/test_mixing.py b/tests/test_mixing.py index 32375c65b..f57b79433 100644 --- a/tests/test_mixing.py +++ b/tests/test_mixing.py @@ -22,14 +22,23 @@ def complex_read_fn_image(x): [(A.MixUp, { "reference_data": [{"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)}], "read_fn": lambda x: x}), + (A.MixUp, { + "reference_data": [1], + "read_fn": lambda x: {"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)}}, + ), + (A.MixUp, { + "reference_data": np.array([1]), + "read_fn": lambda x: {"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)}}, + ), + (A.MixUp, { + "reference_data": None, + }), (A.MixUp, { "reference_data": image_generator(), "read_fn": lambda x: x}), (A.MixUp, { "reference_data": complex_image_generator(), - "read_fn": complex_read_fn_image})] - -) + "read_fn": complex_read_fn_image})] ) def test_image_only(augmentation_cls, params, image): aug = augmentation_cls(p=1, **params) data = aug(image=image) @@ -40,7 +49,13 @@ def test_image_only(augmentation_cls, params, image): [(A.MixUp, { "reference_data": [{"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8), "global_label": np.array([0, 0, 1])}], - "read_fn": lambda x: x})] + "read_fn": lambda x: x}), + (A.MixUp, { + "reference_data": [1], + "read_fn": lambda x: {"image": np.ones((100, 100, 3)).astype(np.uint8), + "global_label": np.array([0, 0, 1])}}, + ), + ] ) def test_image_global_label(augmentation_cls, params, image, global_label): aug = augmentation_cls(p=1, **params) @@ -49,8 +64,13 @@ def test_image_global_label(augmentation_cls, params, image, global_label): assert data["image"].dtype == np.uint8 - mix_coeff_image = find_mix_coef(data["image"], image, aug.reference_data[0]["image"]) - mix_coeff_label = find_mix_coef(data["global_label"], global_label, aug.reference_data[0]["global_label"]) + reference_item = params["read_fn"](aug.reference_data[0]) + + reference_image = reference_item["image"] + reference_global_label = reference_item["global_label"] + + mix_coeff_image = find_mix_coef(data["image"], image, reference_image) + mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_global_label) assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01) assert 0 <= mix_coeff_image <= 1