diff --git a/keras_cv/models/object_detection/yolox/yolox.py b/keras_cv/models/object_detection/yolox/yolox.py index 34a4308cad..78e9a9e669 100644 --- a/keras_cv/models/object_detection/yolox/yolox.py +++ b/keras_cv/models/object_detection/yolox/yolox.py @@ -643,7 +643,7 @@ def loop_across_batch_2(b, matching_matrix): ) fg_mask_inboxes = tf.reduce_sum(matching_matrix, 0) > 0.0 - num_fg = tf.reduce_sum(fg_mask_inboxes) + num_fg = tf.reduce_sum(tf.cast(fg_mask_inboxes, tf.float32)) fg_mask_indices = tf.reshape(tf.where(fg_mask), [-1]) fg_mask_inboxes_indices = tf.reshape(tf.where(fg_mask_inboxes), [-1, 1])