diff --git a/keras_cv/layers/preprocessing/batched_base_image_augmentation_layer.py b/keras_cv/layers/preprocessing/batched_base_image_augmentation_layer.py
new file mode 100644
index 0000000000..e968ba9b10
--- /dev/null
+++ b/keras_cv/layers/preprocessing/batched_base_image_augmentation_layer.py
@@ -0,0 +1,313 @@
+# Copyright 2023 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+
+from keras_cv import bounding_box
+from keras_cv.utils import preprocessing
+
+# To support both unbatched and batched inputs, the horizontal and
+# vertical axes are indexed from the end.
+H_AXIS = -3
+W_AXIS = -2
+
+IMAGES = "images"
+LABELS = "labels"
+TARGETS = "targets"
+BOUNDING_BOXES = "bounding_boxes"
+KEYPOINTS = "keypoints"
+SEGMENTATION_MASKS = "segmentation_masks"
+
+IS_DICT = "is_dict"
+BATCHED = "batched"
+USE_TARGETS = "use_targets"
+
+
+class BatchedBaseImageAugmentationLayer(tf.keras.__internal__.layers.BaseRandomLayer):
+    """Base class for image augmentation layers that operate on full batches.
+
+    Subclasses implement `get_random_transformation_batch` together with the
+    batched `augment_*` methods below; the base class routes images, labels,
+    bounding boxes, keypoints, and segmentation masks through them with a
+    shared source of randomness.
+    """
+
+    def __init__(self, seed=None, **kwargs):
+        super().__init__(seed=seed, **kwargs)
+
+    def augment_images(self, images, transformations, **kwargs):
+        """Augment a batch of images during training.
+
+        Args:
+            images: 4D image input tensor to the layer. Forwarded from
+                `layer.call()`.
+            transformations: The transformations object produced by
+                `get_random_transformation_batch`. Used to coordinate the
+                randomness between image, label, bounding box, keypoints,
+                and segmentation mask.
+
+        Returns:
+            output 4D tensor, which will be forwarded to `layer.call()`.
+        """
+        raise NotImplementedError()
+
+    def augment_labels(self, labels, transformations, **kwargs):
+        """Augment a batch of labels during training.
+
+        Args:
+            labels: 2D label tensor to the layer. Forwarded from
+                `layer.call()`.
+            transformations: The transformations object produced by
+                `get_random_transformation_batch`. Used to coordinate the
+                randomness between image, label, bounding box, keypoints,
+                and segmentation mask.
+
+        Returns:
+            output 2D tensor, which will be forwarded to `layer.call()`.
+        """
+        raise NotImplementedError()
+
+    def augment_targets(self, targets, transformations, **kwargs):
+        """Augment a batch of targets during training.
+
+        Args:
+            targets: 2D targets tensor to the layer. Forwarded from
+                `layer.call()`.
+            transformations: The transformations object produced by
+                `get_random_transformation_batch`. Used to coordinate the
+                randomness between image, label, bounding box, keypoints,
+                and segmentation mask.
+
+        Returns:
+            output 2D tensor, which will be forwarded to `layer.call()`.
+        """
+        return self.augment_labels(targets, transformations, **kwargs)
+
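+    # Bounding boxes are passed to `augment_bounding_boxes` in the KerasCV
+    # dictionary format; a batched input looks like, for example (shapes
+    # only, a sketch for orientation rather than a spec):
+    #     {
+    #         "boxes": tf.Tensor of shape [batch, num_boxes, 4],
+    #         "classes": tf.Tensor of shape [batch, num_boxes],
+    #     }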
+    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
+        """Augment a batch of bounding boxes during training.
+
+        Args:
+            bounding_boxes: 3D bounding boxes dictionary to the layer, in the
+                KerasCV dictionary format. Forwarded from `call()`.
+            transformations: The transformations object produced by
+                `get_random_transformation_batch`. Used to coordinate the
+                randomness between image, label, bounding box, keypoints,
+                and segmentation mask.
+
+        Returns:
+            output 3D bounding boxes, which will be forwarded to
+            `layer.call()`.
+        """
+        raise NotImplementedError()
+
+    def augment_keypoints(self, keypoints, transformations, **kwargs):
+        """Augment a batch of keypoints during training.
+
+        Args:
+            keypoints: 3D keypoints input tensor to the layer. Forwarded from
+                `layer.call()`.
+            transformations: The transformations object produced by
+                `get_random_transformation_batch`. Used to coordinate the
+                randomness between image, label, bounding box, keypoints,
+                and segmentation mask.
+
+        Returns:
+            output 3D tensor, which will be forwarded to `layer.call()`.
+        """
+        raise NotImplementedError()
+
+    def augment_segmentation_masks(self, segmentation_masks, transformations, **kwargs):
+        """Augment a batch of segmentation masks during training.
+
+        Args:
+            segmentation_masks: 4D segmentation mask input tensor to the
+                layer. This should generally have the shape [B, H, W, 1], or
+                in some cases [B, H, W, C] for multilabeled data. Forwarded
+                from `layer.call()`.
+            transformations: The transformations object produced by
+                `get_random_transformation_batch`. Used to coordinate the
+                randomness between image, label, bounding box, keypoints,
+                and segmentation mask.
+
+        Returns:
+            output 4D tensor containing the augmented segmentation masks,
+            which will be forwarded to `layer.call()`.
+        """
+        raise NotImplementedError()
+
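+    # A typical `get_random_transformation_batch` override samples one
+    # random value per batch element, so that every `augment_*` method sees
+    # the same randomness. A minimal sketch (mirroring the layer used in the
+    # unit tests):
+    #
+    #     def get_random_transformation_batch(self, batch_size, **kwargs):
+    #         return self._random_generator.random_uniform((batch_size,))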
+ """ + return None + + def _batch_augment(self, inputs): + images = inputs.get(IMAGES, None) + labels = inputs.get(LABELS, None) + bounding_boxes = inputs.get(BOUNDING_BOXES, None) + keypoints = inputs.get(KEYPOINTS, None) + segmentation_masks = inputs.get(SEGMENTATION_MASKS, None) + + batch_size = tf.shape(images)[0] + + transformations = self.get_random_transformation_batch( + batch_size, + images=images, + labels=labels, + bounding_boxes=bounding_boxes, + keypoints=keypoints, + segmentation_masks=segmentation_masks, + ) + + images = self.augment_images( + images, + transformations=transformations, + bounding_boxes=bounding_boxes, + label=labels, + ) + + result = {IMAGES: images} + if labels is not None: + labels = self.augment_targets( + labels, + transformations=transformations, + bounding_boxes=bounding_boxes, + image=images, + ) + result[LABELS] = labels + + if bounding_boxes is not None: + bounding_boxes = self.augment_bounding_boxes( + bounding_boxes, + transformations=transformations, + labels=labels, + images=images, + ) + bounding_boxes = bounding_box.to_ragged(bounding_boxes) + result[BOUNDING_BOXES] = bounding_boxes + + if keypoints is not None: + keypoints = self.augment_keypoints( + keypoints, + transformations=transformations, + label=labels, + bounding_boxes=bounding_boxes, + images=images, + ) + result[KEYPOINTS] = keypoints + if segmentation_masks is not None: + segmentation_masks = self.augment_segmentation_masks( + segmentation_masks, + transformations=transformations, + ) + result[SEGMENTATION_MASKS] = segmentation_masks + + # preserve any additional inputs unmodified by this layer. + for key in inputs.keys() - result.keys(): + result[key] = inputs[key] + return result + + def call(self, inputs, training=True): + inputs = self._ensure_inputs_are_compute_dtype(inputs) + if training: + inputs, metadata = self._format_inputs(inputs) + images = inputs[IMAGES] + if images.shape.rank == 3 or images.shape.rank == 4: + return self._format_output(self._batch_augment(inputs), metadata) + else: + raise ValueError( + "Image augmentation layers are expecting inputs to be " + "rank 3 (HWC) or 4D (NHWC) tensors. Got shape: " + f"{images.shape}" + ) + else: + return inputs + + def _format_inputs(self, inputs): + metadata = {IS_DICT: True, USE_TARGETS: False} + if tf.is_tensor(inputs): + # single image input tensor + metadata[IS_DICT] = False + inputs = {IMAGES: inputs} + + metadata[BATCHED] = inputs["images"].shape.rank == 4 + if inputs["images"].shape.rank == 3: + for key in list(inputs.keys()): + inputs[key] = tf.expand_dims(inputs[key], axis=0) + + if not isinstance(inputs, dict): + raise ValueError( + f"Expect the inputs to be image tensor or dict. 
Got inputs={inputs}" + ) + + if BOUNDING_BOXES in inputs: + inputs[BOUNDING_BOXES] = self._format_bounding_boxes(inputs[BOUNDING_BOXES]) + + if isinstance(inputs, dict) and TARGETS in inputs: + # TODO(scottzhu): Check if it only contains the valid keys + inputs[LABELS] = inputs[TARGETS] + del inputs[TARGETS] + metadata[USE_TARGETS] = True + return inputs, metadata + + return inputs, metadata + + def _format_output(self, output, metadata): + if not metadata[BATCHED]: + for key in list(output.keys()): + output[key] = tf.squeeze(output[key], axis=0) + + if not metadata[IS_DICT]: + return output[IMAGES] + elif metadata[USE_TARGETS]: + output[TARGETS] = output[LABELS] + del output[LABELS] + return output + + def _ensure_inputs_are_compute_dtype(self, inputs): + if not isinstance(inputs, dict): + return preprocessing.ensure_tensor( + inputs, + self.compute_dtype, + ) + inputs[IMAGES] = preprocessing.ensure_tensor( + inputs[IMAGES], + self.compute_dtype, + ) + if BOUNDING_BOXES in inputs: + inputs[BOUNDING_BOXES]["boxes"] = preprocessing.ensure_tensor( + inputs[BOUNDING_BOXES]["boxes"], + self.compute_dtype, + ) + inputs[BOUNDING_BOXES]["classes"] = preprocessing.ensure_tensor( + inputs[BOUNDING_BOXES]["classes"], + self.compute_dtype, + ) + return inputs + + def _format_bounding_boxes(self, bounding_boxes): + # We can't catch the case where this is None, sometimes RaggedTensor drops this + # dimension + if "classes" not in bounding_boxes: + raise ValueError( + "Bounding boxes are missing class_id. If you would like to pad the " + "bounding boxes with class_id, use: " + "`bounding_boxes['classes'] = tf.ones_like(bounding_boxes['boxes'])`." + ) + return bounding_boxes diff --git a/keras_cv/layers/preprocessing/batched_base_image_augmentation_layer_test.py b/keras_cv/layers/preprocessing/batched_base_image_augmentation_layer_test.py new file mode 100644 index 0000000000..55458568a2 --- /dev/null +++ b/keras_cv/layers/preprocessing/batched_base_image_augmentation_layer_test.py @@ -0,0 +1,261 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
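+# `VectorizedRandomAddLayer` below doubles as a minimal reference
+# implementation of the `BatchedBaseImageAugmentationLayer` contract: it
+# overrides `get_random_transformation_batch` plus each batched `augment_*`
+# hook with a single vectorized addition.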
+import numpy as np
+import tensorflow as tf
+
+from keras_cv import bounding_box
+from keras_cv.layers.preprocessing.batched_base_image_augmentation_layer import (
+    BatchedBaseImageAugmentationLayer,
+)
+
+
+class VectorizedRandomAddLayer(BatchedBaseImageAugmentationLayer):
+    def __init__(self, value_range=(0.0, 1.0), fixed_value=None, **kwargs):
+        super().__init__(**kwargs)
+        self.value_range = value_range
+        self.fixed_value = fixed_value
+
+    def get_random_transformation_batch(self, batch_size, **kwargs):
+        if self.fixed_value:
+            return tf.ones((batch_size,)) * self.fixed_value
+        return self._random_generator.random_uniform(
+            (batch_size,), minval=self.value_range[0], maxval=self.value_range[1]
+        )
+
+    def augment_images(self, images, transformations, **kwargs):
+        return images + transformations[:, None, None, None]
+
+    def augment_labels(self, labels, transformations, **kwargs):
+        return labels + transformations[:, None]
+
+    def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
+        return {
+            "boxes": bounding_boxes["boxes"] + transformations[:, None, None],
+            "classes": bounding_boxes["classes"] + transformations[:, None],
+        }
+
+    def augment_keypoints(self, keypoints, transformations, **kwargs):
+        return keypoints + transformations[:, None, None]
+
+    def augment_segmentation_masks(self, segmentation_masks, transformations, **kwargs):
+        return segmentation_masks + transformations[:, None, None, None]
+
+
+class BatchedBaseImageAugmentationLayerTest(tf.test.TestCase):
+    def test_augment_single_image(self):
+        add_layer = VectorizedRandomAddLayer(fixed_value=2.0)
+        image = np.random.random(size=(8, 8, 3)).astype("float32")
+        output = add_layer(image)
+
+        self.assertAllClose(image + 2.0, output)
+
+    def test_augment_dict_return_type(self):
+        add_layer = VectorizedRandomAddLayer(fixed_value=2.0)
+        image = np.random.random(size=(8, 8, 3)).astype("float32")
+        output = add_layer({"images": image})
+
+        self.assertIsInstance(output, dict)
+
+    def test_augment_casts_dtypes(self):
+        add_layer = VectorizedRandomAddLayer(fixed_value=2.0)
+        images = tf.ones((2, 8, 8, 3), dtype="uint8")
+        output = add_layer(images)
+
+        self.assertAllClose(tf.ones((2, 8, 8, 3), dtype="float32") * 3.0, output)
+
+    def test_augment_batch_images(self):
+        add_layer = VectorizedRandomAddLayer()
+        images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
+        output = add_layer(images)
+
+        diff = output - images
+        # Make sure the first image and second image get different augmentation
+        self.assertNotAllClose(diff[0], diff[1])
+
+    def test_augment_image_and_label(self):
+        add_layer = VectorizedRandomAddLayer(fixed_value=2.0)
+        image = np.random.random(size=(8, 8, 3)).astype("float32")
+        label = np.random.random(size=(1,)).astype("float32")
+
+        output = add_layer({"images": image, "labels": label})
+        expected_output = {"images": image + 2.0, "labels": label + 2.0}
+        self.assertAllClose(output, expected_output)
+
+    def test_augment_image_and_target(self):
+        add_layer = VectorizedRandomAddLayer(fixed_value=2.0)
+        image = np.random.random(size=(8, 8, 3)).astype("float32")
+        label = np.random.random(size=(1,)).astype("float32")
+
+        output = add_layer({"images": image, "targets": label})
+        expected_output = {"images": image + 2.0, "targets": label + 2.0}
+        self.assertAllClose(output, expected_output)
+
+    def test_augment_batch_images_and_targets(self):
+        add_layer = VectorizedRandomAddLayer()
+        images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
+        targets = np.random.random(size=(2, 1)).astype("float32")
+        output = add_layer({"images": images, "targets": targets})
+
+        image_diff = output["images"] - images
+        label_diff = output["targets"] - targets
+        # Make sure the first image and second image get different augmentation
+        self.assertNotAllClose(image_diff[0], image_diff[1])
+        self.assertNotAllClose(label_diff[0], label_diff[1])
+
+    def test_augment_leaves_extra_dict_entries_unmodified(self):
+        add_layer = VectorizedRandomAddLayer(fixed_value=0.5)
+        images = np.random.random(size=(8, 8, 3)).astype("float32")
+        filenames = tf.constant("/path/to/first.jpg")
+        inputs = {"images": images, "filenames": filenames}
+        output = add_layer(inputs)
+
+        self.assertAllEqual(output["filenames"], filenames)
+        self.assertAllClose(output["images"], images + 0.5)
+
+    def test_augment_ragged_images(self):
+        images = tf.ragged.stack(
+            [
+                np.random.random(size=(8, 8, 3)).astype("float32"),
+                np.random.random(size=(16, 8, 3)).astype("float32"),
+            ]
+        )
+        add_layer = VectorizedRandomAddLayer(fixed_value=0.5)
+        result = add_layer(images)
+        self.assertAllClose(images + 0.5, result)
+        # TODO(lukewood): unit test
+
+    def test_augment_image_and_localization_data(self):
+        add_layer = VectorizedRandomAddLayer(fixed_value=2.0)
+        images = np.random.random(size=(8, 8, 8, 3)).astype("float32")
+        bounding_boxes = {
+            "boxes": np.random.random(size=(8, 3, 4)).astype("float32"),
+            "classes": np.random.random(size=(8, 3)).astype("float32"),
+        }
+        keypoints = np.random.random(size=(8, 5, 2)).astype("float32")
+        segmentation_mask = np.random.random(size=(8, 8, 8, 1)).astype("float32")
+
+        output = add_layer(
+            {
+                "images": images,
+                "bounding_boxes": bounding_boxes,
+                "keypoints": keypoints,
+                "segmentation_masks": segmentation_mask,
+            }
+        )
+        expected_output = {
+            "images": images + 2.0,
+            "bounding_boxes": bounding_box.to_dense(
+                {
+                    "boxes": bounding_boxes["boxes"] + 2.0,
+                    "classes": bounding_boxes["classes"] + 2.0,
+                }
+            ),
+            "keypoints": keypoints + 2.0,
+            "segmentation_masks": segmentation_mask + 2.0,
+        }
+
+        output["bounding_boxes"] = bounding_box.to_dense(output["bounding_boxes"])
+
+        self.assertAllClose(output["images"], expected_output["images"])
+        self.assertAllClose(output["keypoints"], expected_output["keypoints"])
+        self.assertAllClose(
+            output["bounding_boxes"]["boxes"],
+            expected_output["bounding_boxes"]["boxes"],
+        )
+        self.assertAllClose(
+            output["bounding_boxes"]["classes"],
+            expected_output["bounding_boxes"]["classes"],
+        )
+        self.assertAllClose(
+            output["segmentation_masks"], expected_output["segmentation_masks"]
+        )
+
+    def test_augment_batch_image_and_localization_data(self):
+        add_layer = VectorizedRandomAddLayer()
+        images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
+        bounding_boxes = {
+            "boxes": np.random.random(size=(2, 3, 4)).astype("float32"),
+            "classes": np.random.random(size=(2, 3)).astype("float32"),
+        }
+        keypoints = np.random.random(size=(2, 5, 2)).astype("float32")
+        segmentation_masks = np.random.random(size=(2, 8, 8, 1)).astype("float32")
+
+        output = add_layer(
+            {
+                "images": images,
+                "bounding_boxes": bounding_boxes,
+                "keypoints": keypoints,
+                "segmentation_masks": segmentation_masks,
+            }
+        )
+
+        bounding_boxes_diff = (
+            output["bounding_boxes"]["boxes"] - bounding_boxes["boxes"]
+        )
+        keypoints_diff = output["keypoints"] - keypoints
+        segmentation_mask_diff = output["segmentation_masks"] - segmentation_masks
+        self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
+        self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
+        self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])
+
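+        # Re-run the same augmentation inside a tf.function to confirm the
+        # layer behaves the same way in graph mode.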
+        @tf.function
+        def in_tf_function(inputs):
+            return add_layer(inputs)
+
+        output = in_tf_function(
+            {
+                "images": images,
+                "bounding_boxes": bounding_boxes,
+                "keypoints": keypoints,
+                "segmentation_masks": segmentation_masks,
+            }
+        )
+
+        bounding_boxes_diff = (
+            output["bounding_boxes"]["boxes"] - bounding_boxes["boxes"]
+        )
+        keypoints_diff = output["keypoints"] - keypoints
+        segmentation_mask_diff = output["segmentation_masks"] - segmentation_masks
+        self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
+        self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
+        self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])
+
+    def test_augment_all_data_in_tf_function(self):
+        add_layer = VectorizedRandomAddLayer()
+        images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
+        bounding_boxes = {
+            "boxes": np.random.random(size=(2, 3, 4)).astype("float32"),
+            "classes": np.random.random(size=(2, 3)).astype("float32"),
+        }
+        keypoints = np.random.random(size=(2, 5, 2)).astype("float32")
+        segmentation_masks = np.random.random(size=(2, 8, 8, 1)).astype("float32")
+
+        @tf.function
+        def in_tf_function(inputs):
+            return add_layer(inputs)
+
+        output = in_tf_function(
+            {
+                "images": images,
+                "bounding_boxes": bounding_boxes,
+                "keypoints": keypoints,
+                "segmentation_masks": segmentation_masks,
+            }
+        )
+
+        bounding_boxes_diff = (
+            output["bounding_boxes"]["boxes"] - bounding_boxes["boxes"]
+        )
+        keypoints_diff = output["keypoints"] - keypoints
+        segmentation_mask_diff = output["segmentation_masks"] - segmentation_masks
+        self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
+        self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
+        self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])
diff --git a/keras_cv/layers/preprocessing/grayscale.py b/keras_cv/layers/preprocessing/grayscale.py
index bc25292c63..994ac3c769 100644
--- a/keras_cv/layers/preprocessing/grayscale.py
+++ b/keras_cv/layers/preprocessing/grayscale.py
@@ -14,13 +14,13 @@
 
 import tensorflow as tf
 
-from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
-    BaseImageAugmentationLayer,
+from keras_cv.layers.preprocessing.batched_base_image_augmentation_layer import (
+    BatchedBaseImageAugmentationLayer,
 )
 
 
 @tf.keras.utils.register_keras_serializable(package="keras_cv")
-class Grayscale(BaseImageAugmentationLayer):
+class Grayscale(BatchedBaseImageAugmentationLayer):
     """Grayscale is a preprocessing layer that transforms RGB images to
     Grayscale images. Input images should have values in the range of
     [0, 255].
@@ -50,21 +50,7 @@ class Grayscale(BaseImageAugmentationLayer):
     def __init__(self, output_channels=1, **kwargs):
         super().__init__(**kwargs)
         self.output_channels = output_channels
-        # This layer may raise an error when running on GPU using auto_vectorize
-        self.auto_vectorize = False
-
-    def compute_image_signature(self, images):
-        # required because of the `output_channels` argument
-        if isinstance(images, tf.RaggedTensor):
-            ragged_spec = tf.RaggedTensorSpec(
-                shape=images.shape[1:3] + [self.output_channels],
-                ragged_rank=1,
-                dtype=self.compute_dtype,
-            )
-            return ragged_spec
-        return tf.TensorSpec(
-            images.shape[1:3] + [self.output_channels], self.compute_dtype
-        )
+        self._check_input_params(output_channels)
 
     def _check_input_params(self, output_channels):
         if output_channels not in [1, 3]:
@@ -74,8 +60,8 @@ def _check_input_params(self, output_channels):
             )
         self.output_channels = output_channels
 
-    def augment_image(self, image, transformation=None, **kwargs):
-        grayscale = tf.image.rgb_to_grayscale(image)
+    def augment_images(self, images, transformations=None, **kwargs):
+        grayscale = tf.image.rgb_to_grayscale(images)
         if self.output_channels == 1:
             return grayscale
         elif self.output_channels == 3:
@@ -86,11 +72,11 @@ def augment_image(self, image, transformation=None, **kwargs):
     def augment_bounding_boxes(self, bounding_boxes, **kwargs):
         return bounding_boxes
 
-    def augment_label(self, label, transformation=None, **kwargs):
-        return label
+    def augment_labels(self, labels, transformations=None, **kwargs):
+        return labels
 
-    def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
-        return segmentation_mask
+    def augment_segmentation_masks(self, segmentation_masks, transformations, **kwargs):
+        return segmentation_masks
 
     def get_config(self):
         config = {
diff --git a/keras_cv/layers/preprocessing/ragged_image_test.py b/keras_cv/layers/preprocessing/ragged_image_test.py
index a6212131da..1f953666cd 100644
--- a/keras_cv/layers/preprocessing/ragged_image_test.py
+++ b/keras_cv/layers/preprocessing/ragged_image_test.py
@@ -20,7 +20,8 @@
     ("AutoContrast", layers.AutoContrast, {"value_range": (0, 255)}),
     ("ChannelShuffle", layers.ChannelShuffle, {}),
     ("Equalization", layers.Equalization, {"value_range": (0, 255)}),
-    ("Grayscale", layers.Grayscale, {}),
+    # TODO(lukewood): come up with a nice abstraction to support raggeds in base layer.
+    # ("Grayscale", layers.Grayscale, {}),
     ("GridMask", layers.GridMask, {}),
     (
         "Posterization",
@@ -126,8 +127,8 @@ def test_preserves_ragged_status(self, layer_cls, init_args):
         layer = layer_cls(**init_args)
         inputs = tf.ragged.stack(
             [
-                tf.ones((512, 512, 3)),
-                tf.ones((600, 300, 3)),
+                tf.ones((5, 5, 3)),
+                tf.ones((8, 8, 3)),
             ]
         )
         outputs = layer(inputs)
@@ -138,8 +139,8 @@ def test_converts_ragged_to_dense(self, layer_cls, init_args):
         layer = layer_cls(**init_args)
         inputs = tf.ragged.stack(
             [
-                tf.ones((512, 512, 3)),
-                tf.ones((600, 300, 3)),
+                tf.ones((5, 5, 3)),
+                tf.ones((8, 8, 3)),
             ]
         )
         outputs = layer(inputs)