diff --git a/src/featureforest/utils/data.py b/src/featureforest/utils/data.py index c7aace6..dfa73a8 100644 --- a/src/featureforest/utils/data.py +++ b/src/featureforest/utils/data.py @@ -247,7 +247,7 @@ def is_image_rgb(image_data: ndarray) -> bool: Returns: bool: is image RGB(A)? """ - return image_data.shape[-1] >= 3 + return image_data.shape[-1] in [3, 4] def is_stacked(image_data: ndarray) -> bool: