Skip to content

Commit

Permalink
Implement transform_bounding_boxes for random_flip (#20468)
Browse files Browse the repository at this point in the history
* Implement transform_bounding_boxes for random_flip

* fix test case for torch env

* Add channel first test cases also

* Add condition for channel_first
  • Loading branch information
shashaka authored Nov 8, 2024
1 parent 5bf4ac7 commit 8409e18
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 2 deletions.
86 changes: 84 additions & 2 deletions keras/src/layers/preprocessing/image_preprocessing/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
clip_to_image_size,
)
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
convert_format,
)
from keras.src.random.seed_generator import SeedGenerator
from keras.src.utils import backend_utils

HORIZONTAL = "horizontal"
VERTICAL = "vertical"
Expand Down Expand Up @@ -77,7 +84,7 @@ def get_random_transformation(self, data, training=True, seed=None):
flips = self.backend.numpy.less_equal(
self.backend.random.uniform(shape=flips_shape, seed=seed), 0.5
)
return {"flips": flips}
return {"flips": flips, "input_shape": shape}

def transform_images(self, images, transformation, training=True):
images = self.backend.cast(images, self.compute_dtype)
Expand All @@ -94,7 +101,82 @@ def transform_bounding_boxes(
transformation,
training=True,
):
raise NotImplementedError
def _flip_boxes_horizontal(boxes):
x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
outputs = self.backend.numpy.concatenate(
[1 - x3, x2, 1 - x1, x4], axis=-1
)
return outputs

def _flip_boxes_vertical(boxes):
x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)
outputs = self.backend.numpy.concatenate(
[x1, 1 - x4, x3, 1 - x2], axis=-1
)
return outputs

def _transform_xyxy(boxes, box_flips):
if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

bboxes = boxes["boxes"]
if self.mode in {HORIZONTAL, HORIZONTAL_AND_VERTICAL}:
bboxes = self.backend.numpy.where(
box_flips,
_flip_boxes_horizontal(bboxes),
bboxes,
)
if self.mode in {VERTICAL, HORIZONTAL_AND_VERTICAL}:
bboxes = self.backend.numpy.where(
box_flips,
_flip_boxes_vertical(bboxes),
bboxes,
)

self.backend.reset()

return bboxes

flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1)

if self.data_format == "channels_first":
height_axis = -2
width_axis = -1
else:
height_axis = -3
width_axis = -2

input_height, input_width = (
transformation["input_shape"][height_axis],
transformation["input_shape"][width_axis],
)

bounding_boxes = convert_format(
bounding_boxes,
source=self.bounding_box_format,
target="rel_xyxy",
height=input_height,
width=input_width,
)

bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips)

bounding_boxes = clip_to_image_size(
bounding_boxes=bounding_boxes,
height=input_height,
width=input_width,
format="xyxy",
)

bounding_boxes = convert_format(
bounding_boxes,
source="rel_xyxy",
target=self.bounding_box_format,
height=input_height,
width=input_width,
)

return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
Expand Down
115 changes: 115 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,118 @@ def test_tf_data_compatibility(self):
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
output = next(iter(ds)).numpy()
self.assertAllClose(output, expected_output)

@parameterized.named_parameters(
(
"with_horizontal",
"horizontal",
[[4, 1, 6, 3], [0, 4, 2, 6]],
),
(
"with_vertical",
"vertical",
[[2, 7, 4, 9], [6, 4, 8, 6]],
),
(
"with_horizontal_and_vertical",
"horizontal_and_vertical",
[[4, 7, 6, 9], [0, 4, 2, 6]],
),
)
def test_random_flip_bounding_boxes(self, mode, expected_boxes):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
image_shape = (10, 8, 3)
else:
image_shape = (3, 10, 8)
input_image = np.random.random(image_shape)
bounding_boxes = {
"boxes": np.array(
[
[2, 1, 4, 3],
[6, 4, 8, 6],
]
),
"labels": np.array([[1, 2]]),
}
input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
random_flip_layer = layers.RandomFlip(
mode,
data_format=data_format,
seed=42,
bounding_box_format="xyxy",
)

transformation = {
"flips": np.asarray([[True]]),
"input_shape": input_image.shape,
}
output = random_flip_layer.transform_bounding_boxes(
input_data["bounding_boxes"],
transformation=transformation,
training=True,
)

self.assertAllClose(output["boxes"], expected_boxes)

@parameterized.named_parameters(
(
"with_horizontal",
"horizontal",
[[4, 1, 6, 3], [0, 4, 2, 6]],
),
(
"with_vertical",
"vertical",
[[2, 7, 4, 9], [6, 4, 8, 6]],
),
(
"with_horizontal_and_vertical",
"horizontal_and_vertical",
[[4, 7, 6, 9], [0, 4, 2, 6]],
),
)
def test_random_flip_tf_data_bounding_boxes(self, mode, expected_boxes):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
image_shape = (1, 10, 8, 3)
else:
image_shape = (1, 3, 10, 8)
input_image = np.random.random(image_shape)
bounding_boxes = {
"boxes": np.array(
[
[
[2, 1, 4, 3],
[6, 4, 8, 6],
]
]
),
"labels": np.array([[1, 2]]),
}

input_data = {"images": input_image, "bounding_boxes": bounding_boxes}

ds = tf_data.Dataset.from_tensor_slices(input_data)
random_flip_layer = layers.RandomFlip(
mode,
data_format=data_format,
seed=42,
bounding_box_format="xyxy",
)

transformation = {
"flips": np.asarray([[True]]),
"input_shape": input_image.shape,
}
ds = ds.map(
lambda x: random_flip_layer.transform_bounding_boxes(
x["bounding_boxes"],
transformation=transformation,
training=True,
)
)

output = next(iter(ds))
expected_boxes = np.array(expected_boxes)
self.assertAllClose(output["boxes"], expected_boxes)

0 comments on commit 8409e18

Please sign in to comment.