Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridMask implementation with BaseImageAugmentationLayer #159

Merged
merged 47 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
ba3413d
temp driver
chjort Feb 5, 2022
a743946
vectorize mask computation
Feb 7, 2022
70587d9
vectorize mask computation
Feb 7, 2022
1fe49cd
vectorize mask computation
Feb 8, 2022
e18016c
comment
Feb 9, 2022
8f03787
random rotation and center cropping
Feb 9, 2022
6a0e9a6
finish vectorized gridmask computation
Feb 11, 2022
27f2025
refactor
Feb 11, 2022
92a019c
comment
Feb 11, 2022
15968f4
comment
Feb 11, 2022
2a24c94
comments
Feb 11, 2022
fb443c7
comments
chjort Feb 13, 2022
9edfdf3
Merge branch 'master' into vectorize_gridmask
chjort Feb 14, 2022
abd8048
merge master into branch
chjort Feb 15, 2022
9599c15
initial vectorized layer
chjort Feb 15, 2022
ac8fa28
initial vectorized layer
chjort Feb 15, 2022
60759ce
debug memory usage
chjort Feb 15, 2022
fb6cfcd
fix tests
chjort Feb 15, 2022
79b8387
remove support for single image
chjort Feb 16, 2022
80c080a
set ratio random in demo
chjort Feb 16, 2022
aa55af5
minimize memory by reducing number of simultaneous Logical And operat…
chjort Feb 19, 2022
a1d5f03
individual ratio for each image when ratio="random"
chjort Feb 19, 2022
3e8d0c4
add vectorized argument
chjort Feb 19, 2022
4a033ab
use float32 instead of int32
chjort Feb 22, 2022
de60f96
use float32 instead of int32
chjort Feb 24, 2022
48df666
remove vectorized arg
chjort Feb 24, 2022
a9f8700
minor refactor
chjort Feb 24, 2022
bbf8977
minor refactor
chjort Feb 25, 2022
54416ff
support single image
chjort Feb 25, 2022
783b5c4
refactor coordinates to mask
chjort Feb 26, 2022
25f706b
merge master
chjort Feb 26, 2022
357717b
optimize bounding box mask generation
chjort Feb 27, 2022
ed5161e
add tests
chjort Feb 27, 2022
994242a
minor refactor
chjort Feb 28, 2022
532889a
minor refactor
chjort Feb 28, 2022
f68c4b4
Merge branch 'master' into vectorize_gridmask
chjort Feb 28, 2022
37e6760
formatting
chjort Feb 28, 2022
796cdd2
black
chjort Feb 28, 2022
420d6cb
merge with master
chjort Mar 26, 2022
1489531
WIP. GridMask to BaseImageAugmentationLayer
chjort Mar 26, 2022
c2900cf
WIP. GridMask to BaseImageAugmentationLayer
chjort Mar 26, 2022
19794f1
merge master
chjort Mar 30, 2022
0a50a86
Merge branch 'master' into vectorize_gridmask
chjort Mar 30, 2022
a926115
GridMask to BaseImageAugmentationLayer
chjort Mar 30, 2022
12dbc14
Merge remote-tracking branch 'origin/vectorize_gridmask' into vectori…
chjort Mar 30, 2022
5287f7f
Apply changes from review
chjort Mar 31, 2022
e4debc2
fix test
chjort Mar 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 102 additions & 110 deletions keras_cv/layers/preprocessing/grid_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,23 @@
# limitations under the License.

import tensorflow as tf
from tensorflow.keras import backend
from tensorflow.keras import layers

from keras_cv.utils import fill_utils


def _center_crop(mask, width, height):
masks_shape = tf.shape(mask)
h_diff = masks_shape[0] - height
w_diff = masks_shape[1] - width

h_start = tf.cast(h_diff / 2, tf.int32)
w_start = tf.cast(w_diff / 2, tf.int32)
return tf.image.crop_to_bounding_box(mask, h_start, w_start, height, width)


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class GridMask(layers.Layer):
class GridMask(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
"""GridMask class for grid-mask augmentation.


Expand Down Expand Up @@ -75,8 +86,7 @@ def __init__(
ratio="random",
rotation_factor=0.15,
fill_mode="constant",
fill_value=0,
seed=None,
fill_value=0.0,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -86,8 +96,13 @@ def __init__(
self.fill_mode = fill_mode
self.fill_value = fill_value
self.rotation_factor = rotation_factor
self.seed = seed
self.random_rotate = layers.RandomRotation(factor=rotation_factor, seed=seed)
self.random_rotate = layers.RandomRotation(
factor=rotation_factor,
fill_mode="constant",
fill_value=0.0,
seed=self._random_generator._seed,
)
self.auto_vectorize = False
self._check_parameter_values()

def _check_parameter_values(self):
Expand Down Expand Up @@ -116,135 +131,112 @@ def _check_parameter_values(self):
f'"gaussian_noise", or "random". Got `fill_mode`={fill_mode}'
)

@staticmethod
def _crop(mask, image_height, image_width):
"""crops in middle of mask and image corners."""
mask_width = mask_height = tf.shape(mask)[0]
mask = mask[
(mask_height - image_height) // 2 : (mask_height - image_height) // 2
+ image_height,
(mask_width - image_width) // 2 : (mask_width - image_width) // 2
+ image_width,
]
return mask

@tf.function
def _compute_mask(self, image_height, image_width):
"""mask helper function for initializing grid mask of required size."""
image_height = tf.cast(image_height, dtype=tf.float32)
image_width = tf.cast(image_width, dtype=tf.float32)

mask_width = mask_height = tf.cast(
tf.math.maximum(image_height, image_width) * 2.0, dtype=tf.int32
)

if self.fill_mode == "constant":
mask = tf.fill([mask_height, mask_width], value=-1)
elif self.fill_mode == "gaussian_noise":
mask = tf.cast(tf.random.normal([mask_height, mask_width]), dtype=tf.int32)
else:
raise ValueError(
"Unsupported fill_mode. `fill_mode` should be 'constant' or "
"'gaussian_noise'."
)

gridblock = tf.random.uniform(
shape=[],
minval=int(tf.math.minimum(image_height * 0.5, image_width * 0.3)),
maxval=int(tf.math.maximum(image_height * 0.5, image_width * 0.3)) + 1,
dtype=tf.int32,
seed=self.seed,
)

def get_random_transformation(self, image=None, label=None, bounding_box=None):
if self.ratio == "random":
length = tf.random.uniform(
shape=[], minval=1, maxval=gridblock + 1, dtype=tf.int32, seed=self.seed
ratio = self._random_generator.random_uniform(
shape=(), minval=0.0, maxval=1.0, dtype=tf.float32
)
else:
length = tf.cast(
tf.math.minimum(
tf.math.maximum(
int(tf.cast(gridblock, tf.float32) * self.ratio + 0.5), 1
),
gridblock - 1,
),
tf.int32,
)
ratio = self.ratio

for _ in range(2):
start_x = tf.random.uniform(
shape=[], minval=0, maxval=gridblock + 1, dtype=tf.int32, seed=self.seed
)
# compute grid mask
input_shape = tf.shape(image)
mask = self._compute_grid_mask(input_shape, ratio=ratio)

for i in range(mask_width // gridblock):
start = gridblock * i + start_x
end = tf.math.minimum(start + length, mask_width)
indices = tf.reshape(tf.range(start, end), [end - start, 1])
updates = tf.fill([end - start, mask_width], value=self.fill_value)
mask = tf.tensor_scatter_nd_update(mask, indices, updates)
mask = tf.transpose(mask)
# convert mask to single-channel image
mask = tf.cast(mask, tf.float32)
mask = tf.expand_dims(mask, axis=-1)

return tf.equal(mask, self.fill_value)
# randomly rotate mask
mask = self.random_rotate(mask)

@tf.function
def _grid_mask(self, image):
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
# compute fill
if self.fill_mode == "constant":
fill_value = tf.fill(input_shape, self.fill_value)
else:
# gaussian noise
fill_value = self._random_generator.random_normal(
shape=input_shape, dtype=image.dtype
)

grid = self._compute_mask(image_height, image_width)
grid = self.random_rotate(tf.cast(grid[:, :, tf.newaxis], tf.float32))
return mask, fill_value

mask = tf.reshape(
tf.cast(self._crop(grid, image_height, image_width), dtype=image.dtype),
(image_height, image_width),
)
mask = tf.expand_dims(mask, -1) if image._rank() != mask._rank() else mask
def _compute_grid_mask(self, input_shape, ratio):
height = tf.cast(input_shape[0], tf.float32)
width = tf.cast(input_shape[1], tf.float32)

if self.fill_mode == "constant":
return tf.where(tf.cast(mask, tf.bool), image, self.fill_value)
else:
return mask * image
# mask side length
input_diagonal_len = tf.sqrt(tf.square(width) + tf.square(height))
mask_side_len = tf.math.ceil(input_diagonal_len)

def _augment_images(self, images):
unbatched = images.shape.rank == 3
# grid unit size
unit_size = self._random_generator.random_uniform(
shape=(),
minval=tf.math.minimum(height * 0.5, width * 0.3),
maxval=tf.math.maximum(height * 0.5, width * 0.3) + 1,
dtype=tf.float32,
)
rectangle_side_len = tf.cast((1 - ratio) * unit_size, tf.float32)

# The transform op only accepts rank 4 inputs, so if we have an unbatched
# image, we need to temporarily expand dims to a batch.
if unbatched:
images = tf.expand_dims(images, axis=0)
# sample x and y offset for grid units randomly between 0 and unit_size
delta_x = self._random_generator.random_uniform(
shape=(), minval=0.0, maxval=unit_size, dtype=tf.float32
)
delta_y = self._random_generator.random_uniform(
shape=(), minval=0.0, maxval=unit_size, dtype=tf.float32
)

# TODO: Make the batch operation vectorize.
output = tf.map_fn(lambda image: self._grid_mask(image), images)
# grid size (number of diagonal units in grid)
grid_size = mask_side_len // unit_size + 1
grid_size_range = tf.range(1, grid_size + 1)

# diagonal corner coordinates
unit_size_range = grid_size_range * unit_size
x1 = unit_size_range - delta_x
x0 = x1 - rectangle_side_len
y1 = unit_size_range - delta_y
y0 = y1 - rectangle_side_len

# compute grid coordinates
x0, y0 = tf.meshgrid(x0, y0)
x1, y1 = tf.meshgrid(x1, y1)

# flatten mesh grid
x0 = tf.reshape(x0, [-1])
y0 = tf.reshape(y0, [-1])
x1 = tf.reshape(x1, [-1])
y1 = tf.reshape(y1, [-1])

# convert coordinates to mask
corners = tf.stack([x0, y0, x1, y1], axis=-1)
mask_side_len = tf.cast(mask_side_len, tf.int32)
rectangle_masks = fill_utils.corners_to_mask(
corners, mask_shape=(mask_side_len, mask_side_len)
)
grid_mask = tf.reduce_any(rectangle_masks, axis=0)

if unbatched:
output = tf.squeeze(output, axis=0)
return output
return grid_mask

def call(self, images, training=None):
"""call method for the GridMask layer.
def augment_image(self, image, transformation=None):
mask, fill_value = transformation
input_shape = tf.shape(image)

Args:
images: Tensor representing images with shape
[batch_size, width, height, channels] or [width, height, channels]
of type int or float. Values should be in the range [0, 255].
Returns:
images: augmented images, same shape as input.
"""
# center crop mask
input_height = input_shape[0]
input_width = input_shape[1]
mask = _center_crop(mask, input_width, input_height)

if training is None:
training = backend.learning_phase()
# convert back to boolean mask
mask = tf.cast(mask, tf.bool)

if not training:
return images
return self._augment_images(images)
return tf.where(mask, fill_value, image)

def get_config(self):
config = {
"ratio": self.ratio,
"rotation_factor": self.rotation_factor,
"fill_mode": self.fill_mode,
"fill_value": self.fill_value,
"seed": self.seed,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
11 changes: 4 additions & 7 deletions keras_cv/layers/preprocessing/grid_mask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_gridmask_call_results_one_channel(self):
dtype=tf.float32,
)

fill_value = 0
fill_value = 0.0
layer = GridMask(
ratio=0.3,
rotation_factor=(0.2, 0.3),
Expand All @@ -60,7 +60,7 @@ def test_non_square_image(self):
dtype=tf.float32,
)

fill_value = 100
fill_value = 100.0
layer = GridMask(
ratio=0.6, rotation_factor=0.3, fill_mode="constant", fill_value=fill_value
)
Expand All @@ -78,7 +78,7 @@ def test_in_tf_function(self):
dtype=tf.float32,
)

fill_value = 255
fill_value = 255.0
layer = GridMask(
ratio=0.4, rotation_factor=0.5, fill_mode="constant", fill_value=fill_value
)
Expand All @@ -101,10 +101,7 @@ def test_in_single_image(self):
dtype=tf.float32,
)

layer = GridMask(
ratio="random",
fill_mode="gaussian_noise",
)
layer = GridMask(ratio="random", fill_mode="constant", fill_value=0.0)
xs = layer(xs, training=True)
self.assertTrue(tf.math.reduce_any(xs == 0.0))
self.assertTrue(tf.math.reduce_any(xs == 1.0))
60 changes: 25 additions & 35 deletions keras_cv/utils/fill_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,40 @@
from keras_cv.utils import bounding_box


def rectangle_masks(corners, mask_shape):
"""Computes masks of rectangles
def _axis_mask(starts, ends, mask_len):
# index range of axis
batch_size = tf.shape(starts)[0]
axis_indices = tf.range(mask_len, dtype=starts.dtype)
axis_indices = tf.expand_dims(axis_indices, 0)
axis_indices = tf.tile(axis_indices, [batch_size, 1])

# mask of index bounds
axis_mask = tf.greater_equal(axis_indices, starts) & tf.less(axis_indices, ends)
return axis_mask


def corners_to_mask(bounding_boxes, mask_shape):
"""Converts bounding boxes in corners format to boolean masks

Args:
corners: tensor of rectangle coordinates with shape (batch_size, 4) in
bounding_boxes: tensor of rectangle coordinates with shape (batch_size, 4) in
corners format (x0, y0, x1, y1).
mask_shape: a shape tuple as (width, height) indicating the output
width and height of masks.

Returns:
boolean masks with shape (batch_size, width, height) where True values
indicate positions within rectangle coordinates.
indicate positions within bounding box coordinates.
"""
# add broadcasting axes
corners = corners[..., tf.newaxis, tf.newaxis]

# split coordinates
x0 = corners[:, 0]
y0 = corners[:, 1]
x1 = corners[:, 2]
y1 = corners[:, 3]

# repeat height and width
width, height = mask_shape
x0_rep = tf.repeat(x0, height, axis=1)
y0_rep = tf.repeat(y0, width, axis=2)
x1_rep = tf.repeat(x1, height, axis=1)
y1_rep = tf.repeat(y1, width, axis=2)

# range grid
batch_size = tf.shape(corners)[0]
range_row = tf.range(0, height, dtype=corners.dtype)
range_col = tf.range(0, width, dtype=corners.dtype)
range_row = tf.repeat(range_row[tf.newaxis, :, tf.newaxis], batch_size, 0)
range_col = tf.repeat(range_col[tf.newaxis, tf.newaxis, :], batch_size, 0)

# boolean masks
mask_x0 = tf.less_equal(x0_rep, range_col)
mask_y0 = tf.less_equal(y0_rep, range_row)
mask_x1 = tf.less(range_col, x1_rep)
mask_y1 = tf.less(range_row, y1_rep)

masks = mask_x0 & mask_y0 & mask_x1 & mask_y1
mask_width, mask_height = mask_shape
x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)

w_mask = _axis_mask(x0, x1, mask_width)
h_mask = _axis_mask(y0, y1, mask_height)

w_mask = tf.expand_dims(w_mask, axis=1)
h_mask = tf.expand_dims(h_mask, axis=2)
masks = tf.logical_and(w_mask, h_mask)
return masks


Expand All @@ -85,7 +75,7 @@ def fill_rectangle(images, centers_x, centers_y, widths, heights, fill_values):
corners = bounding_box.xywh_to_corners(xywh)

mask_shape = (images_width, images_height)
is_rectangle = rectangle_masks(corners, mask_shape)
is_rectangle = corners_to_mask(corners, mask_shape)
is_rectangle = tf.expand_dims(is_rectangle, -1)

images = tf.where(is_rectangle, fill_values, images)
Expand Down
Loading