-
Notifications
You must be signed in to change notification settings - Fork 330
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[YOLOX] Step 2/? : Setting up YoloX structure, add internal layers an…
…d update iou losses (#1296) * first attempt at introducing YoloX * formatted and fixed bugs * cast fix #1 * cast fix #2 * cast fix #3 * cast fix #4 * adding ensure shape for support * reverting and removing ensure_shape * fixed another bug * updated train.py * updated docs, tests and added support for loss strings * first attempt at introducing YoloX * formatted and fixed bugs * adding ensure shape for support * reverting and removing ensure_shape * reformatted by black * fixed a linting issue * finally rebased atop the recent changes * finally rebased atop the new changes * fixed linting issues * reverted rebasing issues with iou loss * fixing rebased errors part 2 * fixed more linting issues * TPU testing changes * linting fixes * updated with implementation details from paper * updated based on review comments and api changes * first attempt at introducing YoloX * updated docs, tests and added support for loss strings * fixed linting issues * reverted rebasing issues with iou loss * review comments * removed examples * linting fix * fixed rebasing error * updated no_reduction warning * review comments * revert version and linting fixes
- Loading branch information
1 parent
c142683
commit e889477
Showing
11 changed files
with
840 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
97 changes: 97 additions & 0 deletions
97
keras_cv/models/object_detection/yolox/binary_crossentropy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import warnings | ||
|
||
import tensorflow as tf | ||
|
||
|
||
class BinaryCrossentropy(tf.keras.losses.Loss): | ||
"""Computes the cross-entropy loss between true labels and predicted labels. | ||
Use this cross-entropy loss for binary (0 or 1) classification applications. | ||
This loss is updated for YoloX by offering support for no axis to mean over. | ||
Args: | ||
from_logits: Whether to interpret `y_pred` as a tensor of | ||
[logit](https://en.wikipedia.org/wiki/Logit) values. By default, we | ||
assume that `y_pred` contains probabilities (i.e., values in [0, | ||
1]). | ||
label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When > | ||
0, we compute the loss between the predicted labels and a smoothed | ||
version of the true labels, where the smoothing squeezes the labels | ||
towards 0.5. Larger values of `label_smoothing` correspond to | ||
heavier smoothing. | ||
axis: the axis along which to mean the ious. Defaults to `no_reduction` which implies | ||
mean across no axes. | ||
Usage: | ||
```python | ||
model.compile( | ||
loss=keras_cv.models.object_detection.yolox.binary_crossentropy.BinaryCrossentropy(from_logits=True) | ||
.... | ||
) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, from_logits=False, label_smoothing=0.0, axis=None, **kwargs | ||
): | ||
super().__init__(**kwargs) | ||
self.from_logits = from_logits | ||
self.label_smoothing = label_smoothing | ||
self.axis = axis | ||
|
||
def call(self, y_true, y_pred): | ||
y_pred = tf.convert_to_tensor(y_pred) | ||
y_true = tf.cast(y_true, y_pred.dtype) | ||
label_smoothing = tf.convert_to_tensor( | ||
self.label_smoothing, dtype=y_pred.dtype | ||
) | ||
|
||
def _smooth_labels(): | ||
return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing | ||
|
||
y_true = tf.__internal__.smart_cond.smart_cond( | ||
label_smoothing, _smooth_labels, lambda: y_true | ||
) | ||
|
||
if self.axis == "no_reduction": | ||
warnings.warn( | ||
"`axis='no_reduction'` is a temporary API, and the API contract " | ||
"will be replaced in the future with a more generic solution " | ||
"covering all losses." | ||
) | ||
return tf.reduce_mean( | ||
tf.keras.backend.binary_crossentropy( | ||
y_true, y_pred, from_logits=self.from_logits | ||
), | ||
axis=self.axis, | ||
) | ||
|
||
return tf.keras.backend.binary_crossentropy( | ||
y_true, y_pred, from_logits=self.from_logits | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"from_logits": self.from_logits, | ||
"label_smoothing": self.label_smoothing, | ||
"axis": self.axis, | ||
} | ||
) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_cv.models.object_detection.yolox.layers.yolox_decoder import ( | ||
YoloXPredictionDecoder, | ||
) | ||
from keras_cv.models.object_detection.yolox.layers.yolox_head import YoloXHead | ||
from keras_cv.models.object_detection.yolox.layers.yolox_label_encoder import ( | ||
YoloXLabelEncoder, | ||
) | ||
from keras_cv.models.object_detection.yolox.layers.yolox_pafpn import YoloXPAFPN |
168 changes: 168 additions & 0 deletions
168
keras_cv/models/object_detection/yolox/layers/yolox_decoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# Copyright 2023 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import tensorflow as tf | ||
from tensorflow import keras | ||
|
||
import keras_cv.layers as cv_layers | ||
from keras_cv import bounding_box | ||
|
||
|
||
class YoloXPredictionDecoder(keras.layers.Layer): | ||
"""Decodes the predictions from YoloX head. | ||
This layer is similar to the decoding code in `YoloX.compute_losses`. This is | ||
followed by a bounding box suppression layer. | ||
Arguments: | ||
bounding_box_format: The format of bounding boxes of input dataset. Refer | ||
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) | ||
for more details on supported bounding box formats. | ||
num_classes: The number of classes to be considered for the classification head. | ||
suppression_layer: A `keras.layers.Layer` that follows the same API | ||
signature of the `keras_cv.layers.MultiClassNonMaxSuppression` layer. | ||
This layer should perform a suppression operation such as Non Max Suppression, | ||
or Soft Non-Max Suppression. | ||
""" | ||
|
||
def __init__( | ||
self, bounding_box_format, num_classes, suppression_layer=None, **kwargs | ||
): | ||
super().__init__(**kwargs) | ||
self.bounding_box_format = bounding_box_format | ||
self.num_classes = num_classes | ||
|
||
self.suppression_layer = ( | ||
suppression_layer | ||
or cv_layers.MultiClassNonMaxSuppression( | ||
bounding_box_format=bounding_box_format, | ||
from_logits=False, | ||
confidence_threshold=0.01, | ||
iou_threshold=0.65, | ||
max_detections=100, | ||
max_detections_per_class=100, | ||
) | ||
) | ||
if ( | ||
self.suppression_layer.bounding_box_format | ||
!= self.bounding_box_format | ||
): | ||
raise ValueError( | ||
"`suppression_layer` must have the same `bounding_box_format` " | ||
"as the `YoloXPredictionDecoder()` layer. " | ||
"Received `YoloXPredictionDecoder.bounding_box_format=" | ||
f"{self.bounding_box_format}`, `suppression_layer={suppression_layer}`." | ||
) | ||
self.built = True | ||
|
||
def call(self, images, predictions): | ||
image_shape = tf.cast(tf.shape(images), dtype=self.compute_dtype)[1:-1] | ||
|
||
batch_size = tf.shape(predictions[0])[0] | ||
|
||
grids = [] | ||
strides = [] | ||
|
||
shapes = [x.shape[1:3] for x in predictions] | ||
|
||
# 5 + self.num_classes is a concatenation of bounding boxes (length=4) | ||
# + objectness score (length=1) + num_classes | ||
# this reshape is simply collapsing axes 1 and 2 of x into a single dimension | ||
predictions = [ | ||
tf.reshape(x, [batch_size, -1, 5 + self.num_classes]) | ||
for x in predictions | ||
] | ||
predictions = tf.cast( | ||
tf.concat(predictions, axis=1), dtype=self.compute_dtype | ||
) | ||
predictions_shape = tf.cast( | ||
tf.shape(predictions), dtype=self.compute_dtype | ||
) | ||
|
||
for i in range(len(shapes)): | ||
shape_x, shape_y = shapes[i] | ||
grid_x, grid_y = tf.meshgrid(tf.range(shape_y), tf.range(shape_x)) | ||
grid = tf.reshape(tf.stack((grid_x, grid_y), 2), (1, -1, 2)) | ||
shape = grid.shape[:2] | ||
|
||
grids.append(tf.cast(grid, self.compute_dtype)) | ||
strides.append( | ||
tf.ones((shape[0], shape[1], 1)) | ||
* image_shape[0] | ||
/ tf.cast(shape_x, self.compute_dtype) | ||
) | ||
|
||
grids = tf.concat(grids, axis=1) | ||
strides = tf.concat(strides, axis=1) | ||
|
||
box_xy = tf.expand_dims( | ||
(predictions[..., :2] + grids) * strides / image_shape, axis=-2 | ||
) | ||
box_xy = tf.broadcast_to( | ||
box_xy, [batch_size, predictions_shape[1], self.num_classes, 2] | ||
) | ||
box_wh = tf.expand_dims( | ||
tf.exp(predictions[..., 2:4]) * strides / image_shape, axis=-2 | ||
) | ||
box_wh = tf.broadcast_to( | ||
box_wh, [batch_size, predictions_shape[1], self.num_classes, 2] | ||
) | ||
|
||
box_confidence = tf.math.sigmoid(predictions[..., 4:5]) | ||
box_class_probs = tf.math.sigmoid(predictions[..., 5:]) | ||
|
||
# create and broadcast classes for every box before nms | ||
box_classes = tf.expand_dims( | ||
tf.range(self.num_classes, dtype=self.compute_dtype), axis=-1 | ||
) | ||
box_classes = tf.broadcast_to( | ||
box_classes, [batch_size, predictions_shape[1], self.num_classes, 1] | ||
) | ||
|
||
box_scores = tf.expand_dims(box_confidence * box_class_probs, axis=-1) | ||
|
||
outputs = tf.concat([box_xy, box_wh, box_classes, box_scores], axis=-1) | ||
outputs = tf.reshape(outputs, [batch_size, -1, 6]) | ||
|
||
outputs = { | ||
"boxes": outputs[..., :4], | ||
"classes": outputs[..., 4], | ||
"confidence": outputs[..., 5], | ||
} | ||
|
||
# this conversion is rel_center_xywh to rel_xywh | ||
# small workaround because rel_center_xywh isn't supported yet | ||
outputs = bounding_box.convert_format( | ||
outputs, | ||
source="center_xywh", | ||
target="xywh", | ||
images=images, | ||
) | ||
outputs = bounding_box.convert_format( | ||
outputs, | ||
source="rel_xywh", | ||
target=self.suppression_layer.bounding_box_format, | ||
images=images, | ||
) | ||
|
||
# preparing the predictions for TF NMS op | ||
class_predictions = tf.cast(outputs["classes"], tf.int32) | ||
class_predictions = tf.one_hot(class_predictions, self.num_classes) | ||
|
||
scores = ( | ||
tf.expand_dims(outputs["confidence"], axis=-1) * class_predictions | ||
) | ||
|
||
return self.suppression_layer(outputs["boxes"], scores) |
Oops, something went wrong.