From e8894779dfe8881030510463d304b0ae7c14f533 Mon Sep 17 00:00:00 2001 From: Sarvagya Malaviya <45961148+quantumalaviya@users.noreply.github.com> Date: Tue, 11 Apr 2023 23:24:57 +0530 Subject: [PATCH] [YOLOX] Step 2/? : Setting up YoloX structure, add internal layers and update iou losses (#1296) * first attempt at introducing YoloX * formatted and fixed bugs * cast fix #1 * cast fix #2 * cast fix #3 * cast fix #4 * adding ensure shape for support * reverting and removing ensure_shape * fixed another bug * updated train.py * updated docs, tests and added support for loss strings * first attempt at introducing YoloX * formatted and fixed bugs * adding ensure shape for support * reverting and removing ensure_shape * reformatted by black * fixed a linting issue * finally rebased atop the recent changes * finally rebased atop the new changes * fixed linting issues * reverted rebasing issues with iou loss * fixing rebased errors part 2 * fixed more linting issues * TPU testing changes * linting fixes * updated with implementation details from paper * updated based on review comments and api changes * first attempt at introducing YoloX * updated docs, tests and added support for loss strings * fixed linting issues * reverted rebasing issues with iou loss * review comments * removed examples * linting fix * fixed rebasing error * updated no_reduction warning * review comments * revert version and linting fixes --- keras_cv/losses/iou_loss.py | 7 + .../models/object_detection/yolox/__init__.py | 13 ++ .../yolox/binary_crossentropy.py | 97 ++++++++++ .../object_detection/yolox/layers/__init__.py | 22 +++ .../yolox/layers/yolox_decoder.py | 168 ++++++++++++++++++ .../yolox/layers/yolox_head.py | 155 ++++++++++++++++ .../yolox/layers/yolox_head_test.py | 55 ++++++ .../yolox/layers/yolox_label_encoder.py | 49 +++++ .../yolox/layers/yolox_label_encoder_test.py | 83 +++++++++ .../yolox/layers/yolox_pafpn.py | 139 +++++++++++++++ .../yolox/layers/yolox_pafpn_test.py | 52 ++++++ 11 files changed, 840 insertions(+) create mode 100644 keras_cv/models/object_detection/yolox/__init__.py create mode 100644 keras_cv/models/object_detection/yolox/binary_crossentropy.py create mode 100644 keras_cv/models/object_detection/yolox/layers/__init__.py create mode 100644 keras_cv/models/object_detection/yolox/layers/yolox_decoder.py create mode 100644 keras_cv/models/object_detection/yolox/layers/yolox_head.py create mode 100644 keras_cv/models/object_detection/yolox/layers/yolox_head_test.py create mode 100644 keras_cv/models/object_detection/yolox/layers/yolox_label_encoder.py create mode 100644 keras_cv/models/object_detection/yolox/layers/yolox_label_encoder_test.py create mode 100644 keras_cv/models/object_detection/yolox/layers/yolox_pafpn.py create mode 100644 keras_cv/models/object_detection/yolox/layers/yolox_pafpn_test.py diff --git a/keras_cv/losses/iou_loss.py b/keras_cv/losses/iou_loss.py index 36f2642a64..4e6a568ad9 100644 --- a/keras_cv/losses/iou_loss.py +++ b/keras_cv/losses/iou_loss.py @@ -96,6 +96,13 @@ def call(self, y_true, y_pred): f"bounding boxes. Received y_true.shape[-1]={y_true.shape[-1]}." ) + if y_true.shape[-2] < y_pred.shape[-2]: + raise ValueError( + "IoULoss expects number of boxes in y_pred to be equal to the number " + f"of boxes in y_true. Received number of boxes in y_true={y_true.shape[-2]} " + f"and number of boxes in y_pred={y_pred.shape[-2]}." + ) + if y_true.shape[-2] != y_pred.shape[-2]: raise ValueError( "IoULoss expects number of boxes in y_pred to be equal to the " diff --git a/keras_cv/models/object_detection/yolox/__init__.py b/keras_cv/models/object_detection/yolox/__init__.py new file mode 100644 index 0000000000..3992ffb59a --- /dev/null +++ b/keras_cv/models/object_detection/yolox/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_cv/models/object_detection/yolox/binary_crossentropy.py b/keras_cv/models/object_detection/yolox/binary_crossentropy.py new file mode 100644 index 0000000000..a2594c65a5 --- /dev/null +++ b/keras_cv/models/object_detection/yolox/binary_crossentropy.py @@ -0,0 +1,97 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import warnings + +import tensorflow as tf + + +class BinaryCrossentropy(tf.keras.losses.Loss): + """Computes the cross-entropy loss between true labels and predicted labels. + + Use this cross-entropy loss for binary (0 or 1) classification applications. + This loss is updated for YoloX by offering support for no axis to mean over. + + Args: + from_logits: Whether to interpret `y_pred` as a tensor of + [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we + assume that `y_pred` contains probabilities (i.e., values in [0, + 1]). + label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When > + 0, we compute the loss between the predicted labels and a smoothed + version of the true labels, where the smoothing squeezes the labels + towards 0.5. Larger values of `label_smoothing` correspond to + heavier smoothing. + axis: the axis along which to mean the ious. Defaults to `no_reduction` which implies + mean across no axes. + + Usage: + ```python + model.compile( + loss=keras_cv.models.object_detection.yolox.binary_crossentropy.BinaryCrossentropy(from_logits=True) + .... + ) + ``` + """ + + def __init__( + self, from_logits=False, label_smoothing=0.0, axis=None, **kwargs + ): + super().__init__(**kwargs) + self.from_logits = from_logits + self.label_smoothing = label_smoothing + self.axis = axis + + def call(self, y_true, y_pred): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + label_smoothing = tf.convert_to_tensor( + self.label_smoothing, dtype=y_pred.dtype + ) + + def _smooth_labels(): + return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing + + y_true = tf.__internal__.smart_cond.smart_cond( + label_smoothing, _smooth_labels, lambda: y_true + ) + + if self.axis == "no_reduction": + warnings.warn( + "`axis='no_reduction'` is a temporary API, and the API contract " + "will be replaced in the future with a more generic solution " + "covering all losses." + ) + return tf.reduce_mean( + tf.keras.backend.binary_crossentropy( + y_true, y_pred, from_logits=self.from_logits + ), + axis=self.axis, + ) + + return tf.keras.backend.binary_crossentropy( + y_true, y_pred, from_logits=self.from_logits + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "from_logits": self.from_logits, + "label_smoothing": self.label_smoothing, + "axis": self.axis, + } + ) + return config diff --git a/keras_cv/models/object_detection/yolox/layers/__init__.py b/keras_cv/models/object_detection/yolox/layers/__init__.py new file mode 100644 index 0000000000..107862db7b --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.models.object_detection.yolox.layers.yolox_decoder import ( + YoloXPredictionDecoder, +) +from keras_cv.models.object_detection.yolox.layers.yolox_head import YoloXHead +from keras_cv.models.object_detection.yolox.layers.yolox_label_encoder import ( + YoloXLabelEncoder, +) +from keras_cv.models.object_detection.yolox.layers.yolox_pafpn import YoloXPAFPN diff --git a/keras_cv/models/object_detection/yolox/layers/yolox_decoder.py b/keras_cv/models/object_detection/yolox/layers/yolox_decoder.py new file mode 100644 index 0000000000..c8f5d75ea0 --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/yolox_decoder.py @@ -0,0 +1,168 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tensorflow as tf +from tensorflow import keras + +import keras_cv.layers as cv_layers +from keras_cv import bounding_box + + +class YoloXPredictionDecoder(keras.layers.Layer): + """Decodes the predictions from YoloX head. + + This layer is similar to the decoding code in `YoloX.compute_losses`. This is + followed by a bounding box suppression layer. + + Arguments: + bounding_box_format: The format of bounding boxes of input dataset. Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + num_classes: The number of classes to be considered for the classification head. + suppression_layer: A `keras.layers.Layer` that follows the same API + signature of the `keras_cv.layers.MultiClassNonMaxSuppression` layer. + This layer should perform a suppression operation such as Non Max Suppression, + or Soft Non-Max Suppression. + """ + + def __init__( + self, bounding_box_format, num_classes, suppression_layer=None, **kwargs + ): + super().__init__(**kwargs) + self.bounding_box_format = bounding_box_format + self.num_classes = num_classes + + self.suppression_layer = ( + suppression_layer + or cv_layers.MultiClassNonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + confidence_threshold=0.01, + iou_threshold=0.65, + max_detections=100, + max_detections_per_class=100, + ) + ) + if ( + self.suppression_layer.bounding_box_format + != self.bounding_box_format + ): + raise ValueError( + "`suppression_layer` must have the same `bounding_box_format` " + "as the `YoloXPredictionDecoder()` layer. " + "Received `YoloXPredictionDecoder.bounding_box_format=" + f"{self.bounding_box_format}`, `suppression_layer={suppression_layer}`." + ) + self.built = True + + def call(self, images, predictions): + image_shape = tf.cast(tf.shape(images), dtype=self.compute_dtype)[1:-1] + + batch_size = tf.shape(predictions[0])[0] + + grids = [] + strides = [] + + shapes = [x.shape[1:3] for x in predictions] + + # 5 + self.num_classes is a concatenation of bounding boxes (length=4) + # + objectness score (length=1) + num_classes + # this reshape is simply collapsing axes 1 and 2 of x into a single dimension + predictions = [ + tf.reshape(x, [batch_size, -1, 5 + self.num_classes]) + for x in predictions + ] + predictions = tf.cast( + tf.concat(predictions, axis=1), dtype=self.compute_dtype + ) + predictions_shape = tf.cast( + tf.shape(predictions), dtype=self.compute_dtype + ) + + for i in range(len(shapes)): + shape_x, shape_y = shapes[i] + grid_x, grid_y = tf.meshgrid(tf.range(shape_y), tf.range(shape_x)) + grid = tf.reshape(tf.stack((grid_x, grid_y), 2), (1, -1, 2)) + shape = grid.shape[:2] + + grids.append(tf.cast(grid, self.compute_dtype)) + strides.append( + tf.ones((shape[0], shape[1], 1)) + * image_shape[0] + / tf.cast(shape_x, self.compute_dtype) + ) + + grids = tf.concat(grids, axis=1) + strides = tf.concat(strides, axis=1) + + box_xy = tf.expand_dims( + (predictions[..., :2] + grids) * strides / image_shape, axis=-2 + ) + box_xy = tf.broadcast_to( + box_xy, [batch_size, predictions_shape[1], self.num_classes, 2] + ) + box_wh = tf.expand_dims( + tf.exp(predictions[..., 2:4]) * strides / image_shape, axis=-2 + ) + box_wh = tf.broadcast_to( + box_wh, [batch_size, predictions_shape[1], self.num_classes, 2] + ) + + box_confidence = tf.math.sigmoid(predictions[..., 4:5]) + box_class_probs = tf.math.sigmoid(predictions[..., 5:]) + + # create and broadcast classes for every box before nms + box_classes = tf.expand_dims( + tf.range(self.num_classes, dtype=self.compute_dtype), axis=-1 + ) + box_classes = tf.broadcast_to( + box_classes, [batch_size, predictions_shape[1], self.num_classes, 1] + ) + + box_scores = tf.expand_dims(box_confidence * box_class_probs, axis=-1) + + outputs = tf.concat([box_xy, box_wh, box_classes, box_scores], axis=-1) + outputs = tf.reshape(outputs, [batch_size, -1, 6]) + + outputs = { + "boxes": outputs[..., :4], + "classes": outputs[..., 4], + "confidence": outputs[..., 5], + } + + # this conversion is rel_center_xywh to rel_xywh + # small workaround because rel_center_xywh isn't supported yet + outputs = bounding_box.convert_format( + outputs, + source="center_xywh", + target="xywh", + images=images, + ) + outputs = bounding_box.convert_format( + outputs, + source="rel_xywh", + target=self.suppression_layer.bounding_box_format, + images=images, + ) + + # preparing the predictions for TF NMS op + class_predictions = tf.cast(outputs["classes"], tf.int32) + class_predictions = tf.one_hot(class_predictions, self.num_classes) + + scores = ( + tf.expand_dims(outputs["confidence"], axis=-1) * class_predictions + ) + + return self.suppression_layer(outputs["boxes"], scores) diff --git a/keras_cv/models/object_detection/yolox/layers/yolox_head.py b/keras_cv/models/object_detection/yolox/layers/yolox_head.py new file mode 100644 index 0000000000..69b78bdfe8 --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/yolox_head.py @@ -0,0 +1,155 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tensorflow as tf +from tensorflow import keras + +from keras_cv.models.__internal__.darknet_utils import DarknetConvBlock +from keras_cv.models.__internal__.darknet_utils import DarknetConvBlockDepthwise + + +class YoloXHead(keras.layers.Layer): + """The YoloX prediction head. + + Arguments: + num_classes: The number of classes to be considered for the classification head. + bias_initializer: Bias Initializer for the final convolution layer for the + classification and regression heads. Defaults to None. + width_multiplier: A float value used to calculate the base width of the model + this changes based on the detection model being used. Defaults to 1.0. + num_level: the number of levels in the FPN output. Defaults to 3. + activation: the activation applied after the BatchNorm layer. One of "silu", + "relu" or "leaky_relu". Defaults to "silu". + use_depthwise: a boolean value used to decide whether a depthwise conv block + should be used over a regular darknet block. Defaults to False + """ + + def __init__( + self, + num_classes, + bias_initializer=None, + width_multiplier=1.0, + num_level=3, + activation="silu", + use_depthwise=False, + **kwargs, + ): + super().__init__(**kwargs) + self.stems = [] + + self.classification_convs = [] + self.regression_convs = [] + + self.classification_preds = [] + self.regression_preds = [] + self.objectness_preds = [] + + ConvBlock = ( + DarknetConvBlockDepthwise if use_depthwise else DarknetConvBlock + ) + + for _ in range(num_level): + self.stems.append( + DarknetConvBlock( + filters=int(256 * width_multiplier), + kernel_size=1, + strides=1, + activation=activation, + ) + ) + + self.classification_convs.append( + keras.Sequential( + [ + ConvBlock( + filters=int(256 * width_multiplier), + kernel_size=3, + strides=1, + activation=activation, + ), + ConvBlock( + filters=int(256 * width_multiplier), + kernel_size=3, + strides=1, + activation=activation, + ), + ] + ) + ) + + self.regression_convs.append( + keras.Sequential( + [ + ConvBlock( + filters=int(256 * width_multiplier), + kernel_size=3, + strides=1, + activation=activation, + ), + ConvBlock( + filters=int(256 * width_multiplier), + kernel_size=3, + strides=1, + activation=activation, + ), + ] + ) + ) + + self.classification_preds.append( + keras.layers.Conv2D( + filters=num_classes, + kernel_size=1, + strides=1, + padding="same", + bias_initializer=bias_initializer, + ) + ) + self.regression_preds.append( + keras.layers.Conv2D( + filters=4, + kernel_size=1, + strides=1, + padding="same", + bias_initializer=bias_initializer, + ) + ) + self.objectness_preds.append( + keras.layers.Conv2D( + filters=1, + kernel_size=1, + strides=1, + padding="same", + ) + ) + + def call(self, inputs, training=False): + outputs = [] + + for i, p_i in enumerate(inputs): + stem = self.stems[i](p_i) + + classes = self.classification_convs[i](stem) + classes = self.classification_preds[i](classes) + + boxes_feat = self.regression_convs[i](stem) + boxes = self.regression_preds[i](boxes_feat) + objectness = self.objectness_preds[i](boxes_feat) + + output = tf.keras.layers.Concatenate(axis=-1)( + [boxes, objectness, classes] + ) + outputs.append(output) + return outputs diff --git a/keras_cv/models/object_detection/yolox/layers/yolox_head_test.py b/keras_cv/models/object_detection/yolox/layers/yolox_head_test.py new file mode 100644 index 0000000000..dfd3886093 --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/yolox_head_test.py @@ -0,0 +1,55 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tensorflow as tf + +from keras_cv.models.object_detection.yolox.layers import YoloXHead + + +class YoloXHeadTest(tf.test.TestCase): + def test_num_parameters(self): + input1 = tf.keras.Input((80, 80, 256)) + input2 = tf.keras.Input((40, 40, 512)) + input3 = tf.keras.Input((20, 20, 1024)) + + output = YoloXHead(20)([input1, input2, input3]) + + model = tf.keras.models.Model( + inputs=[input1, input2, input3], outputs=output + ) + + keras_params = sum( + [tf.keras.backend.count_params(p) for p in model.trainable_weights] + ) + # taken from original implementation + original_params = 7563595 + + self.assertEqual(keras_params, original_params) + + def test_output_type_and_shape(self): + inputs = [ + tf.random.uniform((3, 80, 80, 256)), + tf.random.uniform((3, 40, 40, 512)), + tf.random.uniform((3, 20, 20, 1024)), + ] + + output = YoloXHead(20)(inputs) + + self.assertEqual(type(output), list) + self.assertEqual(len(output), 3) + + self.assertEqual(output[0].shape, [3, 80, 80, 25]) + self.assertEqual(output[1].shape, [3, 40, 40, 25]) + self.assertEqual(output[2].shape, [3, 20, 20, 25]) diff --git a/keras_cv/models/object_detection/yolox/layers/yolox_label_encoder.py b/keras_cv/models/object_detection/yolox/layers/yolox_label_encoder.py new file mode 100644 index 0000000000..ef4b9353e5 --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/yolox_label_encoder.py @@ -0,0 +1,49 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tensorflow as tf +from tensorflow.keras import layers + +from keras_cv import bounding_box + + +class YoloXLabelEncoder(layers.Layer): + """Transforms the raw labels into targets for training.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, images, box_labels): + """Creates box and classification targets for a batch""" + if isinstance(images, tf.RaggedTensor): + raise ValueError( + "`YoloXLabelEncoder`'s `call()` method does not " + "support RaggedTensor inputs for the `images` argument. Received " + f"`type(images)={type(images)}`." + ) + + if box_labels["classes"].get_shape().rank != 2: + raise ValueError( + "`YoloXLabelEncoder`'s `call()` method expects a label encoded " + "`box_labels['classes']` argument of shape `(batch_size, num_boxes)`. " + f"`Received box_labels['classes'].shape={box_labels['classes'].shape}`." + ) + + box_labels = bounding_box.to_dense(box_labels) + box_labels["classes"] = box_labels["classes"][..., tf.newaxis] + + encoded_box_targets = box_labels["boxes"] + class_targets = box_labels["classes"] + return encoded_box_targets, class_targets diff --git a/keras_cv/models/object_detection/yolox/layers/yolox_label_encoder_test.py b/keras_cv/models/object_detection/yolox/layers/yolox_label_encoder_test.py new file mode 100644 index 0000000000..6bcd2dcf1e --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/yolox_label_encoder_test.py @@ -0,0 +1,83 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +from keras_cv.models.object_detection.yolox.layers import YoloXLabelEncoder + + +class YoloXLabelEncoderTest(tf.test.TestCase): + def test_ragged_images_exception(self): + img1 = tf.random.uniform((10, 11, 3)) + img2 = tf.random.uniform((9, 14, 3)) + img3 = tf.random.uniform((7, 12, 3)) + + images = tf.ragged.stack([img1, img2, img3]) + box_labels = {} + box_labels["bounding_boxes"] = tf.random.uniform((3, 4, 4)) + box_labels["classes"] = tf.random.uniform( + (3, 4), maxval=20, dtype=tf.int32 + ) + layer = YoloXLabelEncoder() + + with self.assertRaisesRegexp( + ValueError, + "method does not support RaggedTensor inputs for the `images` argument.", + ): + layer(images, box_labels) + + def test_ragged_labels(self): + images = tf.random.uniform((3, 12, 12, 3)) + + box_labels = {} + + box1 = tf.random.uniform((11, 4)) + class1 = tf.random.uniform([11], maxval=20, dtype=tf.int32) + box2 = tf.random.uniform((14, 4)) + class2 = tf.random.uniform([14], maxval=20, dtype=tf.int32) + box3 = tf.random.uniform((12, 4)) + class3 = tf.random.uniform([12], maxval=20, dtype=tf.int32) + + box_labels["boxes"] = tf.ragged.stack([box1, box2, box3]) + box_labels["classes"] = tf.ragged.stack([class1, class2, class3]) + + layer = YoloXLabelEncoder() + + encoded_boxes, _ = layer(images, box_labels) + self.assertEqual(encoded_boxes.shape, (3, 14, 4)) + + def test_one_hot_classes_exception(self): + images = tf.random.uniform((3, 12, 12, 3)) + + box_labels = {} + + box1 = tf.random.uniform((11, 4)) + class1 = tf.random.uniform([11], maxval=20, dtype=tf.int32) + class1 = tf.one_hot(class1, 20) + + box2 = tf.random.uniform((14, 4)) + class2 = tf.random.uniform([14], maxval=20, dtype=tf.int32) + class2 = tf.one_hot(class2, 20) + + box3 = tf.random.uniform((12, 4)) + class3 = tf.random.uniform([12], maxval=20, dtype=tf.int32) + class3 = tf.one_hot(class3, 20) + + box_labels["boxes"] = tf.ragged.stack([box1, box2, box3]) + box_labels["classes"] = tf.ragged.stack([class1, class2, class3]) + + layer = YoloXLabelEncoder() + + with self.assertRaises(ValueError): + layer(images, box_labels) diff --git a/keras_cv/models/object_detection/yolox/layers/yolox_pafpn.py b/keras_cv/models/object_detection/yolox/layers/yolox_pafpn.py new file mode 100644 index 0000000000..342dc0ad0c --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/yolox_pafpn.py @@ -0,0 +1,139 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tensorflow import keras + +from keras_cv.models.__internal__.darknet_utils import CrossStagePartial +from keras_cv.models.__internal__.darknet_utils import DarknetConvBlock +from keras_cv.models.__internal__.darknet_utils import DarknetConvBlockDepthwise + + +class YoloXPAFPN(keras.layers.Layer): + """The YoloX PAFPN. + + YoloX PAFPN is an FPN layer used in YoloX models. The YoloX PAFPN is based on the + feature pyramid module used in Path Aggregation networks (PANet). + + Arguments: + depth_multiplier: A float value used to calculate the base depth of the model + this changes based on the detection model being used. Defaults to 1.0. + width_multiplier: A float value used to calculate the base width of the model + this changes based on the detection model being used. Defaults to 1.0. + in_channels: A list representing the number of filters in the FPN output. + The length of the list will be same as the number of outputs. Defaults to + (256, 512, 1024). + use_depthwise: a boolean value used to decide whether a depthwise conv block + should be used over a regular darknet block. Defaults to False. + activation: the activation applied after the BatchNorm layer. One of "silu", + "relu" or "leaky_relu". Defaults to "silu". + """ + + def __init__( + self, + depth_multiplier=1.0, + width_multiplier=1.0, + in_channels=(256, 512, 1024), + use_depthwise=False, + activation="silu", + **kwargs + ): + super().__init__(**kwargs) + self.in_channels = in_channels + + ConvBlock = ( + DarknetConvBlockDepthwise if use_depthwise else DarknetConvBlock + ) + + self.lateral_conv0 = DarknetConvBlock( + filters=int(in_channels[1] * width_multiplier), + kernel_size=1, + strides=1, + activation=activation, + ) + self.C3_p4 = CrossStagePartial( + filters=int(in_channels[1] * width_multiplier), + num_bottlenecks=round(3 * depth_multiplier), + residual=False, + use_depthwise=use_depthwise, + activation=activation, + ) + + self.reduce_conv1 = DarknetConvBlock( + filters=int(in_channels[0] * width_multiplier), + kernel_size=1, + strides=1, + activation=activation, + ) + self.C3_p3 = CrossStagePartial( + filters=int(in_channels[0] * width_multiplier), + num_bottlenecks=round(3 * depth_multiplier), + residual=False, + use_depthwise=use_depthwise, + activation=activation, + ) + + self.bu_conv2 = ConvBlock( + filters=int(in_channels[0] * width_multiplier), + kernel_size=3, + strides=2, + activation=activation, + ) + self.C3_n3 = CrossStagePartial( + filters=int(in_channels[1] * width_multiplier), + num_bottlenecks=round(3 * depth_multiplier), + residual=False, + use_depthwise=use_depthwise, + activation=activation, + ) + + self.bu_conv1 = ConvBlock( + filters=int(in_channels[1] * width_multiplier), + kernel_size=3, + strides=2, + activation=activation, + ) + self.C3_n4 = CrossStagePartial( + filters=int(in_channels[2] * width_multiplier), + num_bottlenecks=round(3 * depth_multiplier), + residual=False, + use_depthwise=use_depthwise, + activation=activation, + ) + + self.concat = keras.layers.Concatenate(axis=-1) + self.upsample_2x = keras.layers.UpSampling2D(2) + + def call(self, inputs, training=False): + c3_output, c4_output, c5_output = inputs[3], inputs[4], inputs[5] + + fpn_out0 = self.lateral_conv0(c5_output) + f_out0 = self.upsample_2x(fpn_out0) + f_out0 = self.concat([f_out0, c4_output]) + f_out0 = self.C3_p4(f_out0) + + fpn_out1 = self.reduce_conv1(f_out0) + f_out1 = self.upsample_2x(fpn_out1) + f_out1 = self.concat([f_out1, c3_output]) + pan_out2 = self.C3_p3(f_out1) + + p_out1 = self.bu_conv2(pan_out2) + p_out1 = self.concat([p_out1, fpn_out1]) + pan_out1 = self.C3_n3(p_out1) + + p_out0 = self.bu_conv1(pan_out1) + p_out0 = self.concat([p_out0, fpn_out0]) + pan_out0 = self.C3_n4(p_out0) + + return pan_out2, pan_out1, pan_out0 diff --git a/keras_cv/models/object_detection/yolox/layers/yolox_pafpn_test.py b/keras_cv/models/object_detection/yolox/layers/yolox_pafpn_test.py new file mode 100644 index 0000000000..fda1a0dc96 --- /dev/null +++ b/keras_cv/models/object_detection/yolox/layers/yolox_pafpn_test.py @@ -0,0 +1,52 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tensorflow as tf + +from keras_cv.models.object_detection.yolox.layers import YoloXPAFPN + + +class YoloXLabelEncoderTest(tf.test.TestCase): + def test_num_parameters(self): + input1 = tf.keras.Input((80, 80, 256)) + input2 = tf.keras.Input((40, 40, 512)) + input3 = tf.keras.Input((20, 20, 1024)) + + output = YoloXPAFPN()({3: input1, 4: input2, 5: input3}) + + model = tf.keras.models.Model( + inputs=[input1, input2, input3], outputs=output + ) + + keras_params = sum( + [tf.keras.backend.count_params(p) for p in model.trainable_weights] + ) + # taken from original implementation + original_params = 19523072 + + self.assertEqual(keras_params, original_params) + + def test_output_shape(self): + inputs = { + 3: tf.random.uniform((3, 80, 80, 256)), + 4: tf.random.uniform((3, 40, 40, 512)), + 5: tf.random.uniform((3, 20, 20, 1024)), + } + + output1, output2, output3 = YoloXPAFPN()(inputs) + + self.assertEqual(output1.shape, [3, 80, 80, 256]) + self.assertEqual(output2.shape, [3, 40, 40, 512]) + self.assertEqual(output3.shape, [3, 20, 20, 1024])