Random Resized Crop #499
@@ -0,0 +1,184 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

import tensorflow as tf

from keras_cv import core
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
    BaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class RandomResizedCrop(BaseImageAugmentationLayer):
    """Randomly crops a part of an image and resizes it to the provided size.

    This implementation takes an intuitive approach: we crop the image to a
    random height and width, then resize the crop. To do this, we first sample
    a random value for area using `area_factor` and a value for aspect ratio
    using `aspect_ratio_factor`. The fractional crop height and width are then
    derived from these samples as `sqrt(area / aspect_ratio)` and
    `sqrt(area * aspect_ratio)` respectively, each clipped to [0, 1]. We then
    sample offsets for height and width and clip them such that the cropped
    area does not exceed the image boundaries. Finally, we perform the actual
    crop and resize the result to `target_size`.

    Args:
        target_size: A tuple of two integers used as the target size to
            ultimately crop images to.
        area_factor: (Optional) A tuple of two floats, a single float or a
            `keras_cv.FactorSampler`. The ratio of the area of the cropped
            part to that of the original image is sampled using this factor.
            Represents the lower and upper bounds for the area of the cropped
            image, relative to the original image, before resizing it to
            `target_size`. Defaults to (0.08, 1.0).
        aspect_ratio_factor: (Optional) A tuple of two floats, a single float
            or a `keras_cv.FactorSampler`. Aspect ratio means the ratio of
            width to height of the cropped image. In the context of this
            layer, the sampled aspect ratio represents a value to distort the
            aspect ratio by. Represents the lower and upper bound for the
            aspect ratio of the cropped image before resizing it to
            `target_size`. Defaults to (3./4., 4./3.).
        interpolation: (Optional) A string specifying the sampling method for
            resizing. Defaults to "bilinear".
        seed: (Optional) Used to create a random seed. Defaults to None.
    """

Review thread on `aspect_ratio_factor`:

I also don't think we can take a default here, because this will be surprising to users. That being said, a user SHOULD use this, so I suggest we provide no default and require this argument.

@LukeWood In the case of a single float or integer, what should be the max (or min) value for

    if isinstance(aspect_ratio_factor, tuple):
        min_aspect_ratio = min(aspect_ratio_factor)
        max_aspect_ratio = max(aspect_ratio_factor)
    elif isinstance(aspect_ratio_factor, core.FactorSampler):
        pass
    else:
        raise ValueError(
            "Expected `aspect_ratio` to be tuple or FactorSampler. Received "
            f"aspect_ratio_factor={aspect_ratio_factor}."
        )

If we want to add a case for single ints and floats, we'll need to decide on a predefined max or min. Same goes for ...

Let's just drop integer and float support for aspect ratio factor and make them pass a tuple.

Awesome, thanks for clarifying!
    def __init__(
        self,
        target_size,
        area_factor=(0.08, 1.0),
        aspect_ratio_factor=(3. / 4., 4. / 3.),
        interpolation="bilinear",
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)

        self.target_size = target_size

        aspect_ratio_factor = aspect_ratio_factor or (3. / 4., 4. / 3.)
        if isinstance(aspect_ratio_factor, tuple):
            min_aspect_ratio = min(aspect_ratio_factor)
            max_aspect_ratio = max(aspect_ratio_factor)
        elif isinstance(aspect_ratio_factor, core.FactorSampler):
            # `parse_factor` returns FactorSampler instances unchanged, so the
            # bounds below are only consulted in the tuple case. Defining them
            # here avoids a NameError on this branch.
            min_aspect_ratio = 3. / 4.
            max_aspect_ratio = 4. / 3.
        else:
            raise ValueError(
                "Expected `aspect_ratio` to be tuple or FactorSampler. Received "
                f"RandomResizedCrop(aspect_ratio_factor={aspect_ratio_factor})."
            )
        self.aspect_ratio_factor = preprocessing.parse_factor(
            aspect_ratio_factor,
            min_value=min_aspect_ratio,
            max_value=max_aspect_ratio,
            param_name="aspect_ratio_factor",
            seed=seed,
        )
        self.area_factor = preprocessing.parse_factor(
            area_factor,
            max_value=1.0,
            param_name="area_factor",
            seed=seed,
        )

        self.interpolation = interpolation
        self.seed = seed

        if area_factor == 0.0 and aspect_ratio_factor == 0.0:
            warnings.warn(
                "RandomResizedCrop received both `area_factor=0.0` and "
                "`aspect_ratio_factor=0.0`. As a result, the layer will "
                "perform no augmentation."
            )
    def get_random_transformation(
        self, image=None, label=None, bounding_box=None, **kwargs
    ):
        area_factor = self.area_factor()
        aspect_ratio = self.aspect_ratio_factor()

        # Clip to [0, 1] to avoid unwanted/unintuitive effects.
        new_height = tf.clip_by_value(
            tf.sqrt(area_factor / aspect_ratio), 0.0, 1.0
        )
        new_width = tf.clip_by_value(
            tf.sqrt(area_factor * aspect_ratio), 0.0, 1.0
        )

        height_offset = self._random_generator.random_uniform(
            (),
            minval=tf.minimum(0.0, 1.0 - new_height),
            maxval=tf.maximum(0.0, 1.0 - new_height),
            dtype=tf.float32,
        )

        width_offset = self._random_generator.random_uniform(
            (),
            minval=tf.minimum(0.0, 1.0 - new_width),
            maxval=tf.maximum(0.0, 1.0 - new_width),
            dtype=tf.float32,
        )

        y1 = height_offset
        y2 = height_offset + new_height
        x1 = width_offset
        x2 = width_offset + new_width

        return [[y1, x1, y2, x2]]
    def call(self, inputs, training=True):
        if training:
            return super().call(inputs, training)
        else:
            inputs = self._ensure_inputs_are_compute_dtype(inputs)
            inputs, is_dict, use_targets = self._format_inputs(inputs)
            output = inputs
            # self._resize() returns valid results for both batched and
            # unbatched inputs.
            output["images"] = self._resize(inputs["images"])
            return self._format_output(output, is_dict, use_targets)
Review comment on `augment_image`:

Can you provide a demo script so we can visualize these results? I'd like to confirm that the layer is doing what we think it is. Please make use of the elephant photo and stack it 9 times like I have done here: https://keras.io/guides/keras_cv/custom_image_augmentations/ -- by using the same photo we can perceptually verify the outputs.

    def augment_image(self, image, transformation, **kwargs):
        image = tf.expand_dims(image, axis=0)

        boxes = transformation

        # See bit.ly/tf_crop_resize for more details.
        augmented_image = tf.image.crop_and_resize(
            image,  # image shape: [B, H, W, C]
            boxes,  # boxes: (1, 4) in this case; represents the area
            # to be cropped from the original image
            [0],  # box_indices: maps boxes to images along the batch axis;
            # [0] since there is only one image
            self.target_size,  # output size
        )

        return tf.squeeze(augmented_image, axis=0)
    def _resize(self, image):
        outputs = tf.keras.preprocessing.image.smart_resize(
            image, self.target_size
        )
        # smart_resize will always output float32, so we need to re-cast.
        return tf.cast(outputs, self.compute_dtype)

    def get_config(self):
        config = super().get_config()
        config.update({
            "target_size": self.target_size,
            "area_factor": self.area_factor,
            "aspect_ratio_factor": self.aspect_ratio_factor,
            "interpolation": self.interpolation,
            "seed": self.seed,
        })
        return config
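In response to the review request above, a minimal demo sketch of the kind asked for, following the 3x3-grid pattern from the keras.io custom image augmentations guide. "elephant.jpg" is a placeholder path for the elephant photo used in that guide, not a file shipped with KerasCV.

    # Hedged demo sketch: visualize RandomResizedCrop on one photo stacked 9 times.
    import matplotlib.pyplot as plt
    import tensorflow as tf

    from keras_cv.layers import preprocessing

    image = tf.io.read_file("elephant.jpg")  # placeholder: the keras.io elephant photo
    image = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)

    # Stack the same photo 9 times so the random crops can be compared visually.
    images = tf.stack([image] * 9)

    layer = preprocessing.RandomResizedCrop(
        target_size=(224, 224),
        area_factor=(0.08, 1.0),
        aspect_ratio_factor=(3. / 4., 4. / 3.),
    )
    outputs = layer(images, training=True)

    plt.figure(figsize=(8, 8))
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(outputs[i].numpy().astype("uint8"))
        plt.axis("off")
    plt.show()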
@@ -0,0 +1,62 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from keras_cv.layers import preprocessing


class RandomResizedCropTest(tf.test.TestCase):
    height, width = 300, 300
    batch_size = 4
    target_size = (224, 224)
    def test_train_augments_image(self):
        # Checks that the original and augmented images are different.
        input_image_shape = (self.batch_size, self.height, self.width, 3)
        image = tf.random.uniform(shape=input_image_shape)

        layer = preprocessing.RandomResizedCrop(
            target_size=self.target_size, area_factor=(0.08, 1.0)
        )
        output = layer(image, training=True)

        input_image_resized = tf.image.resize(image, self.target_size)

        self.assertNotAllClose(output, input_image_resized)

    def test_grayscale(self):
        input_image_shape = (self.batch_size, self.height, self.width, 1)
        image = tf.random.uniform(shape=input_image_shape)

        layer = preprocessing.RandomResizedCrop(
            target_size=self.target_size, area_factor=(0.8, 1.0)
        )
        output = layer(image, training=True)

        input_image_resized = tf.image.resize(image, self.target_size)

        self.assertAllEqual(output.shape, (4, 224, 224, 1))
        self.assertNotAllClose(output, input_image_resized)

    def test_preserves_image(self):
        image_shape = (self.batch_size, self.height, self.width, 3)
        image = tf.random.uniform(shape=image_shape)

        layer = preprocessing.RandomResizedCrop(
            target_size=self.target_size, area_factor=(0.8, 1.0)
        )

        input_resized = tf.image.resize(image, self.target_size)
        output = layer(image, training=False)

        self.assertAllClose(output, input_resized)
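Since the layer is registered as Keras-serializable and defines `get_config`, a config round-trip check might look like the following hedged sketch. It is not part of this PR's tests, and it assumes the constructor accepts back the `FactorSampler` instances that `get_config` returns.

    def test_config_round_trip(self):
        # Hedged sketch, not in the PR: verifies that get_config() output can
        # rebuild an equivalent layer via the default Layer.from_config.
        layer = preprocessing.RandomResizedCrop(target_size=self.target_size)
        config = layer.get_config()
        restored = preprocessing.RandomResizedCrop.from_config(config)
        self.assertEqual(restored.target_size, self.target_size)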
Review discussion on `area_factor` defaults:

The more I think about it, the more I think we don't want a default for this. Also, let's rename this to `crop_area_factor`. We can recommend 0.08 and 1.0, but a crop area of 0.08 is actually really aggressive for classification, and might not be as sensible a default there as it is for self-supervised learning.

@LukeWood In that case, do we want a less aggressive number there? I had used 0.25 for training the smaller RegNets (200 MFLOP variant) and it worked fine.

At the time I had experimented with other values as well and found 0.25 to be good for that case.

I don't think we want a default for this because I don't think there will be a generic best value. What I want to avoid is self-supervised users passing RandomResizedCrop() and having it work fine, then classification users copying the code and getting bad results because the default is tuned for one task (or vice versa!).

Got it, makes sense. @sayakpaul WDYT?

@LukeWood has an interesting PoV here. I honestly haven't ever thought about default values at this depth; Luke seems to have a point. But let's please do one thing -- write a comprehensive guide (preferably on keras.io) around this layer when it's out, detailing the thought process and suggesting recommended values for different tasks. The reason I'm stressing this is that RRC is one of the most influential layers in the domain, so we want to ensure our users have a good guide on using it. WDYT? @AdityaKane2001 @LukeWood

Sounds good. I think this guide can be based upon seminal literature in vision. That will give a good idea about good defaults IMO.
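To make the trade-off discussed above concrete, here is a hedged sketch of per-task settings using only the values mentioned in this thread. These are illustrations, not vetted defaults, and the parameter is still named `area_factor` in this PR (`crop_area_factor` is the proposed rename).

    # Hedged sketch using values from the discussion above; not official defaults.
    from keras_cv.layers import preprocessing

    # Self-supervised learning: aggressive crops, per the 0.08 lower bound above.
    ssl_crop = preprocessing.RandomResizedCrop(
        target_size=(224, 224),
        area_factor=(0.08, 1.0),  # proposed rename: crop_area_factor
        aspect_ratio_factor=(3. / 4., 4. / 3.),
    )

    # Supervised classification: gentler crops; 0.25 worked for the smaller
    # RegNets mentioned in the thread.
    clf_crop = preprocessing.RandomResizedCrop(
        target_size=(224, 224),
        area_factor=(0.25, 1.0),
        aspect_ratio_factor=(3. / 4., 4. / 3.),
    )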