Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random Resized Crop #499

Merged
merged 44 commits into from
Jun 29, 2022
Merged

Random Resized Crop #499

merged 44 commits into from
Jun 29, 2022

Conversation

AdityaKane2001
Copy link
Contributor

@AdityaKane2001 AdityaKane2001 commented Jun 16, 2022

@sayakpaul @LukeWood

Creating this draft PR to initiate discussions regarding RandomResizedCrop layer. Code is copied over from #457. Please let me know if you would prefer to start from scratch.

I will study and make relevant changes to the code shortly according to the comments in #457.

/auto Closes #131

@AdityaKane2001
Copy link
Contributor Author

@sayakpaul @LukeWood

Please feel free to review this PR now.

I went through the implementation and refactored the same. I believe this implementation can be easily done using tf.image.crop_resize as mentioned in the comment in #457. I have jotted down a gist of the same.

ImageProjectiveTransformV3 seems a bit of over the top to me, since the same result can be achieved with tf.image.crop_resize but with better readability. So I have replaced ImageProjectiveTransformV3 with tf.image.crop_resize. Here is a gist using the updated implementation.

Both tf.image.crop_resize and tf.raw_ops.ImageProjectiveTransformV3 are vectorized themselves. So batching and unbatching images and boxes (or transform) arguments is redundant. Is there some way to avoid doing that?

@AdityaKane2001 AdityaKane2001 marked this pull request as ready for review June 18, 2022 18:50
@sayakpaul
Copy link
Contributor

@AdityaKane2001 @LukeWood

If we're using code from #457 let's add the author of #457 as a co-author of the commit?

@AdityaKane2001
Copy link
Contributor Author

@sayakpaul @LukeWood

I don't understand what do we mean by having area_factor > 1.0. How would the image look like in that case? Torchvision simply outputs the entire image if area_factor > 1. Gist here.

My general approach to this PR is to make it consistent with the torchvision implementation. The earlier implementation in #457 was a bit different from it, both in terms of behavior and API signature.

If we're using code from #457 let's add the author of #457 as a co-author of the commit?

I had included code credits for the earlier PR in the initial commits, but I removed it later since I refactored a major chunk of the code. But if we will be going with preprocessing.transform, then I'll include the credits again.

keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop_test.py Outdated Show resolved Hide resolved
keras_cv/layers/serialization_test.py Outdated Show resolved Hide resolved
keras_cv/layers/serialization_test.py Outdated Show resolved Hide resolved
keras_cv/layers/serialization_test.py Outdated Show resolved Hide resolved
@sayakpaul
Copy link
Contributor

Regarding #499 (comment), how far Luke's comment in #499 (comment) impacts this?

I'd prefer using tf.image.crop_and_resize() since the code does read cleaner and crisper with it. Conceptually also, it becomes simpler.

@AdityaKane2001
Copy link
Contributor Author

Regarding #499 (comment), how far Luke's comment in #499 (comment) impacts this?

@LukeWood's comments point towards the direction of ImageProjectiveTransformV3. It is definitely more robust but comes at the cost of reduced readability. Also I'm not sure how will images with area_factor=2 look like, and whether that is required.

I'd prefer using tf.image.crop_and_resize() since the code does read cleaner and crisper with it. Conceptually also, it becomes simpler.

I agree.

@LukeWood
Copy link
Contributor

@sayakpaul @LukeWood

Please feel free to review this PR now.

I went through the implementation and refactored the same. I believe this implementation can be easily done using tf.image.crop_resize as mentioned in the comment in #457. I have jotted down a gist of the same.

ImageProjectiveTransformV3 seems a bit of over the top to me, since the same result can be achieved with tf.image.crop_resize but with better readability. So I have replaced ImageProjectiveTransformV3 with tf.image.crop_resize. Here is a gist using the updated implementation.

Both tf.image.crop_resize and tf.raw_ops.ImageProjectiveTransformV3 are vectorized themselves. So batching and unbatching images and boxes (or transform) arguments is redundant. Is there some way to avoid doing that?

No there is not. But, under the hood vectorized_map will vectorize this for us.

@LukeWood
Copy link
Contributor

Ok @AdityaKane2001 you are right that crop first is smarter. I just did some additional thought work and they are indeed equivalent, but your way is more efficient.

carry on with the crop first API

@AdityaKane2001
Copy link
Contributor Author

@LukeWood

Thanks for the confirmation! In that case, area_factor > 2, fill_mode and fill_value no longer make sense. I think conceptually the PR remains the same, I will do rest of the changes and get back.

/cc @sayakpaul

@sayakpaul
Copy link
Contributor

I see, are these values widely optimal across almost all use cases?

We could say that, yes. At least most of the pre-training schemes (supervised, self-supervised, semi-supervised, vision-text, etc.) I have studied these are the default values.

@AdityaKane2001
Copy link
Contributor Author

@sayakpaul @LukeWood @tanzhenyu

I think I have made all the requested changes. However, if I have left out any of them by mistake, please let me know. If everything's alright, I'll run the formatter so that this can be merged.

Copy link
Contributor

@LukeWood LukeWood left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor changes

that of original image is sampled using this factor. Represents the
lower and upper bounds for the area relative to the original image
of the cropped image before resizing it to `target_size`.
`target_size`. Defaults to (0.08, 1.0).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more I think the more I think we don't want a default for this.

Also, lets rename this to crop_area_factor. We can recommendd 0.08 and 1.0, but a crop area of 0.08 is actually really aggressive for classification and might not be a sensible default there as it is for self supervised learning.

Copy link
Contributor Author

@AdityaKane2001 AdityaKane2001 Jun 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LukeWood In that case do we want to have a less aggressive number there? I had used 0.25 for training the smaller RegNets (200 MFLOP variant) and it worked fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the time I had experimented with other values as well and found 0.25 to be good for that case.

Copy link
Contributor

@LukeWood LukeWood Jun 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want a default for this because I don't think there will be a generic best value for this.

What I want to avoid is users of self supervised pass RandomResizedCrop() and it works fine, then classification users copy the code and get bad results because the default is set for one task (or vice versa!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, makes sense.

@sayakpaul WDYT?

Copy link
Contributor

@sayakpaul sayakpaul Jun 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LukeWood has an interesting PoV here. I honestly haven't ever thought of default values at this depth. Luke seems to have a point.

But let's please do one thing -- write a comprehensive guide (preferably put it on keras.io) around this layer when it's out detailing the thought process and suggest some recommended values for different tasks. The reason why I'm stressing about this is that RRC is one of the most influential layers in the domain so we'd want to ensure our users have a good guide on using it.

WDYT? @AdityaKane2001 @LukeWood

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

I think this guide can be based upon seminal literature in vision. That will give a good idea about good defaults IMO.

lower and upper bounds for the area relative to the original image
of the cropped image before resizing it to `target_size`.
`target_size`. Defaults to (0.08, 1.0).
aspect_ratio_factor: (Optional) A tuple of two floats, a single float or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also don't think we can take a default here because this will be surprising to users. That being said, a user SHOULD use this, so I suggest we provide no default and require this argument.

Copy link
Contributor Author

@AdityaKane2001 AdityaKane2001 Jun 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LukeWood In the case of a single float or integer, what should be the max (or min) value for aspect_ratio_factor? Currently we are doing this:

if isinstance(aspect_ratio_factor, tuple):
  min_aspect_ratio = min(aspect_ratio_factor)
  max_aspect_ratio = max(aspect_ratio_factor)
elif isinstance(aspect_ratio_factor, core.FactorSampler):
  pass
else:
  raise ValueError(
      "Expected `aspect_ratio` to be tuple or FactorSampler. Received "
      f"aspect_ratio_factor={aspect_ratio_factor}."
  )

If we want to add a case for single ints and floats, we'll need to decide on a predefined max or min. Same goes for crop_area_factor. Or do we want to ignore single ints or floats completely?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets just drop integer and float support for aspect ratio factor and make them pass a tuple.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for clarifying!

keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
@sayakpaul
Copy link
Contributor

@LukeWood when you get a moment, could you please assign the "GSOC 2022" label?

keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
output["images"] = self._resize(inputs["images"])
return self._format_output(output, is_dict, use_targets)

def augment_image(self, image, transformation, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide a demo script so we can visualize these results? I'd like to confirm that the layer is doing what we think it is. Please make use of the elephant photo and stack it 9 times like I have done here:

https://keras.io/guides/keras_cv/custom_image_augmentations/

By using the same photo we can perceptually verify the outputs.

keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/random_resized_crop.py Outdated Show resolved Hide resolved
or len(target_size) != 2
or not isinstance(target_size[0], int)
or not isinstance(target_size[1], int)
or isinstance(target_size, int)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure we catch it, we still need checks for crop_area_factor and aspect_ratio_factor to prevent them from being single ints. Thanks!

Can you write a @parameterized.TestCase covering every instance of this error message?

@LukeWood LukeWood merged commit 23c8686 into keras-team:master Jun 29, 2022
@AdityaKane2001
Copy link
Contributor Author

@LukeWood I'll add the checks in a separate PR then.

ianstenbit pushed a commit to ianstenbit/keras-cv that referenced this pull request Aug 6, 2022
* Created random resized crop files

* Added test from keras-team#457

* Used `tf.image.crop_and_resize` instead of `ImageProjectionTransformV3`

* Minor bug

* Formatted

* Reformatted

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Doc changes

* Doc changes

* Made requested changes

* Reformatted and tested

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Made requested changes

* Created random resized crop files

* Added test from keras-team#457

* Used `tf.image.crop_and_resize` instead of `ImageProjectionTransformV3`

* Minor bug

* Formatted

* Reformatted

* Doc changes

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Doc changes

* Made requested changes

* Reformatted and tested

* Luke edits

* remove merge conflict

* Fix broken test cases

* Made changes to rrc

* Minor changes

* Added checks

* Added checks

* Added checks for inputs

* Docstring updates

* Made requested changes

* Minor changes

* Added demo

* Formatted

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Luke Wood <[email protected]>
adhadse pushed a commit to adhadse/keras-cv that referenced this pull request Sep 17, 2022
* Created random resized crop files

* Added test from keras-team#457

* Used `tf.image.crop_and_resize` instead of `ImageProjectionTransformV3`

* Minor bug

* Formatted

* Reformatted

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Doc changes

* Doc changes

* Made requested changes

* Reformatted and tested

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Made requested changes

* Created random resized crop files

* Added test from keras-team#457

* Used `tf.image.crop_and_resize` instead of `ImageProjectionTransformV3`

* Minor bug

* Formatted

* Reformatted

* Doc changes

* Update keras_cv/layers/preprocessing/random_resized_crop.py

Co-authored-by: Sayak Paul <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Doc changes

* Made requested changes

* Reformatted and tested

* Luke edits

* remove merge conflict

* Fix broken test cases

* Made changes to rrc

* Minor changes

* Added checks

* Added checks

* Added checks for inputs

* Docstring updates

* Made requested changes

* Minor changes

* Added demo

* Formatted

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Luke Wood <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement RandomResizedCrop layer
4 participants