Skip to content

Commit

Permalink
[YOLOX] Step 2/? : Setting up YoloX structure, add internal layers an…
Browse files Browse the repository at this point in the history
…d update iou losses (#1296)

* first attempt at introducing YoloX

* formatted and fixed bugs

* cast fix #1

* cast fix #2

* cast fix #3

* cast fix #4

* adding ensure shape for support

* reverting and removing ensure_shape

* fixed another bug

* updated train.py

* updated docs, tests and added support for loss strings

* first attempt at introducing YoloX

* formatted and fixed bugs

* adding ensure shape for support

* reverting and removing ensure_shape

* reformatted by black

* fixed a linting issue

* finally rebased atop the recent changes

* finally rebased atop the new changes

* fixed linting issues

* reverted rebasing issues with iou loss

* fixing rebased errors part 2

* fixed more linting issues

* TPU testing changes

* linting fixes

* updated with implementation details from paper

* updated based on review comments and api changes

* first attempt at introducing YoloX

* updated docs, tests and added support for loss strings

* fixed linting issues

* reverted rebasing issues with iou loss

* review comments

* removed examples

* linting fix

* fixed rebasing error

* updated no_reduction warning

* review comments

* revert version and linting fixes
  • Loading branch information
quantumalaviya authored Apr 11, 2023
1 parent c142683 commit e889477
Show file tree
Hide file tree
Showing 11 changed files with 840 additions and 0 deletions.
7 changes: 7 additions & 0 deletions keras_cv/losses/iou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ def call(self, y_true, y_pred):
f"bounding boxes. Received y_true.shape[-1]={y_true.shape[-1]}."
)

if y_true.shape[-2] < y_pred.shape[-2]:
raise ValueError(
"IoULoss expects number of boxes in y_pred to be equal to the number "
f"of boxes in y_true. Received number of boxes in y_true={y_true.shape[-2]} "
f"and number of boxes in y_pred={y_pred.shape[-2]}."
)

if y_true.shape[-2] != y_pred.shape[-2]:
raise ValueError(
"IoULoss expects number of boxes in y_pred to be equal to the "
Expand Down
13 changes: 13 additions & 0 deletions keras_cv/models/object_detection/yolox/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
97 changes: 97 additions & 0 deletions keras_cv/models/object_detection/yolox/binary_crossentropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings

import tensorflow as tf


class BinaryCrossentropy(tf.keras.losses.Loss):
    """Computes the cross-entropy loss between true labels and predicted labels.

    Use this cross-entropy loss for binary (0 or 1) classification applications.
    This loss is updated for YoloX by offering support for no axis to mean over.

    Args:
        from_logits: Whether to interpret `y_pred` as a tensor of
            [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
            assume that `y_pred` contains probabilities (i.e., values in [0,
            1]).
        label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When >
            0, we compute the loss between the predicted labels and a smoothed
            version of the true labels, where the smoothing squeezes the labels
            towards 0.5. Larger values of `label_smoothing` correspond to
            heavier smoothing.
        axis: the axis along which to mean the element-wise losses. `None`
            (the default) and the string `"no_reduction"` both return the
            element-wise loss without averaging over any axis; passing
            `"no_reduction"` additionally emits a warning because that
            spelling is a temporary API. Any other value is forwarded to
            `tf.reduce_mean` as its `axis` argument.

    Usage:
    ```python
    model.compile(
        loss=keras_cv.models.object_detection.yolox.binary_crossentropy.BinaryCrossentropy(from_logits=True)
        ....
    )
    ```
    """

    def __init__(
        self, from_logits=False, label_smoothing=0.0, axis=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.axis = axis

    def call(self, y_true, y_pred):
        """Returns the binary cross-entropy of `y_pred` against `y_true`.

        `y_true` is cast to the dtype of `y_pred` and optionally smoothed
        towards 0.5 before the loss is computed. The result is averaged over
        `self.axis` unless `self.axis` is `None` or `"no_reduction"`.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        label_smoothing = tf.convert_to_tensor(
            self.label_smoothing, dtype=y_pred.dtype
        )

        def _smooth_labels():
            return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

        y_true = tf.__internal__.smart_cond.smart_cond(
            label_smoothing, _smooth_labels, lambda: y_true
        )

        losses = tf.keras.backend.binary_crossentropy(
            y_true, y_pred, from_logits=self.from_logits
        )

        if self.axis == "no_reduction":
            warnings.warn(
                "`axis='no_reduction'` is a temporary API, and the API contract "
                "will be replaced in the future with a more generic solution "
                "covering all losses."
            )
            # The sentinel string is not a valid axis for `tf.reduce_mean`;
            # it means "hand back the element-wise loss untouched".
            return losses

        if self.axis is None:
            # Keep the default behaviour: no averaging inside this loss.
            return losses

        return tf.reduce_mean(losses, axis=self.axis)

    def get_config(self):
        """Returns the base loss config extended with this loss's arguments."""
        config = super().get_config()
        config.update(
            {
                "from_logits": self.from_logits,
                "label_smoothing": self.label_smoothing,
                "axis": self.axis,
            }
        )
        return config
22 changes: 22 additions & 0 deletions keras_cv/models/object_detection/yolox/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.models.object_detection.yolox.layers.yolox_decoder import (
YoloXPredictionDecoder,
)
from keras_cv.models.object_detection.yolox.layers.yolox_head import YoloXHead
from keras_cv.models.object_detection.yolox.layers.yolox_label_encoder import (
YoloXLabelEncoder,
)
from keras_cv.models.object_detection.yolox.layers.yolox_pafpn import YoloXPAFPN
168 changes: 168 additions & 0 deletions keras_cv/models/object_detection/yolox/layers/yolox_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tensorflow as tf
from tensorflow import keras

import keras_cv.layers as cv_layers
from keras_cv import bounding_box


class YoloXPredictionDecoder(keras.layers.Layer):
    """Decodes the predictions from YoloX head.

    This layer is similar to the decoding code in `YoloX.compute_losses`. This is
    followed by a bounding box suppression layer.

    Arguments:
        bounding_box_format: The format of bounding boxes of input dataset. Refer
            [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
            for more details on supported bounding box formats.
        num_classes: The number of classes to be considered for the classification head.
        suppression_layer: A `keras.layers.Layer` that follows the same API
            signature of the `keras_cv.layers.MultiClassNonMaxSuppression` layer.
            This layer should perform a suppression operation such as Non Max Suppression,
            or Soft Non-Max Suppression. When omitted, a
            `MultiClassNonMaxSuppression` layer is created with
            `confidence_threshold=0.01`, `iou_threshold=0.65` and at most 100
            detections (overall and per class).

    Raises:
        ValueError: if `suppression_layer.bounding_box_format` differs from
            `bounding_box_format`.
    """

    def __init__(
        self, bounding_box_format, num_classes, suppression_layer=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.bounding_box_format = bounding_box_format
        self.num_classes = num_classes

        self.suppression_layer = (
            suppression_layer
            or cv_layers.MultiClassNonMaxSuppression(
                bounding_box_format=bounding_box_format,
                from_logits=False,
                confidence_threshold=0.01,
                iou_threshold=0.65,
                max_detections=100,
                max_detections_per_class=100,
            )
        )
        if (
            self.suppression_layer.bounding_box_format
            != self.bounding_box_format
        ):
            raise ValueError(
                "`suppression_layer` must have the same `bounding_box_format` "
                "as the `YoloXPredictionDecoder()` layer. "
                "Received `YoloXPredictionDecoder.bounding_box_format="
                f"{self.bounding_box_format}`, `suppression_layer={suppression_layer}`."
            )
        # The layer owns no weights, so it is marked as built up front.
        self.built = True

    def call(self, images, predictions):
        """Decodes raw head outputs and runs them through the suppression layer.

        Arguments:
            images: the batch of input images. Its shape is read with the first
                and last axes dropped (presumably batch and channels -- confirm
                against the caller).
            predictions: a sequence of head outputs, one per feature level.
                Each is reshaped to `[batch, -1, 5 + num_classes]`, i.e. axes 1
                and 2 are collapsed into a single box axis.

        Returns:
            Whatever `self.suppression_layer` returns for the decoded boxes and
            per-class scores.
        """
        image_shape = tf.cast(tf.shape(images), dtype=self.compute_dtype)[1:-1]

        batch_size = tf.shape(predictions[0])[0]

        grids = []
        strides = []

        # Static spatial sizes of every level, captured before flattening.
        shapes = [x.shape[1:3] for x in predictions]

        # 5 + self.num_classes is a concatenation of bounding boxes (length=4)
        # + objectness score (length=1) + num_classes
        # this reshape is simply collapsing axes 1 and 2 of x into a single dimension
        predictions = [
            tf.reshape(x, [batch_size, -1, 5 + self.num_classes])
            for x in predictions
        ]
        predictions = tf.cast(
            tf.concat(predictions, axis=1), dtype=self.compute_dtype
        )
        # Kept as an integer: it is only used as a dimension of the
        # `tf.broadcast_to` target shapes, which must be integer-typed.
        num_boxes = tf.shape(predictions)[1]

        for shape_x, shape_y in shapes:
            grid_x, grid_y = tf.meshgrid(tf.range(shape_y), tf.range(shape_x))
            grid = tf.reshape(tf.stack((grid_x, grid_y), 2), (1, -1, 2))
            shape = grid.shape[:2]

            grids.append(tf.cast(grid, self.compute_dtype))
            # Created directly in the compute dtype so the product below does
            # not mix float32 with a lower-precision compute dtype.
            # NOTE(review): the stride is derived from image_shape[0] only --
            # confirm this is intended for non-square inputs.
            strides.append(
                tf.ones((shape[0], shape[1], 1), dtype=self.compute_dtype)
                * image_shape[0]
                / tf.cast(shape_x, self.compute_dtype)
            )

        grids = tf.concat(grids, axis=1)
        strides = tf.concat(strides, axis=1)

        # Every box is repeated once per class so that each (box, class) pair
        # becomes its own candidate for the suppression layer.
        per_class_shape = [batch_size, num_boxes, self.num_classes, 2]

        box_xy = tf.expand_dims(
            (predictions[..., :2] + grids) * strides / image_shape, axis=-2
        )
        box_xy = tf.broadcast_to(box_xy, per_class_shape)
        box_wh = tf.expand_dims(
            tf.exp(predictions[..., 2:4]) * strides / image_shape, axis=-2
        )
        box_wh = tf.broadcast_to(box_wh, per_class_shape)

        box_confidence = tf.math.sigmoid(predictions[..., 4:5])
        box_class_probs = tf.math.sigmoid(predictions[..., 5:])

        # create and broadcast classes for every box before nms
        box_classes = tf.expand_dims(
            tf.range(self.num_classes, dtype=self.compute_dtype), axis=-1
        )
        box_classes = tf.broadcast_to(
            box_classes, [batch_size, num_boxes, self.num_classes, 1]
        )

        box_scores = tf.expand_dims(box_confidence * box_class_probs, axis=-1)

        outputs = tf.concat([box_xy, box_wh, box_classes, box_scores], axis=-1)
        outputs = tf.reshape(outputs, [batch_size, -1, 6])

        outputs = {
            "boxes": outputs[..., :4],
            "classes": outputs[..., 4],
            "confidence": outputs[..., 5],
        }

        # this conversion is rel_center_xywh to rel_xywh
        # small workaround because rel_center_xywh isn't supported yet
        outputs = bounding_box.convert_format(
            outputs,
            source="center_xywh",
            target="xywh",
            images=images,
        )
        outputs = bounding_box.convert_format(
            outputs,
            source="rel_xywh",
            target=self.suppression_layer.bounding_box_format,
            images=images,
        )

        # preparing the predictions for TF NMS op
        class_predictions = tf.cast(outputs["classes"], tf.int32)
        class_predictions = tf.one_hot(class_predictions, self.num_classes)

        scores = (
            tf.expand_dims(outputs["confidence"], axis=-1) * class_predictions
        )

        return self.suppression_layer(outputs["boxes"], scores)
Loading

0 comments on commit e889477

Please sign in to comment.