Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add benchmarking script, update API_DESIGN.md to reflect the results. #161

Merged
merged 5 commits into from
Mar 7, 2022

Conversation

LukeWood
Copy link
Contributor

@LukeWood LukeWood commented Mar 3, 2022

No description provided.

.github/API_DESIGN.md Outdated Show resolved Hide resolved
"""

images = []
for aug in [VectorizedRandomCutout, VMapRandomCutout, MapFnRandomCutout]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be useful to also have numbers with XLA compilation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#165 lets do in a follow up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(but I do think so)

Copy link
Contributor

@bhack bhack Mar 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this graph I've only temp enabled the XLA for the VMap class (TF 2.8):

immagine

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case for Vectorized and VMap:

immagine

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case for all the 3 classes:
immagine

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to consider the general discussion we had at #146 (comment)

As this one is a special case that is running fine with vectorized_map/fallback_to_while_loop=False.

Also, this simple case doesn't have any specific not JITable ops like e.g. ImageProjectiveTransformV3 (E.g. for rotation) or tf.bincount (see #141 (comment))

@LukeWood LukeWood merged commit 9fd9885 into master Mar 7, 2022
"""
import time

import matplotlib.pyplot as plt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`BaseImageAugmentationLayer` requires you to implement augmentations in an
image-wise basis instead of using a vectorized approach. This design choice
was based made on the results found in the
[vectorization\_strategy\_benchmark.py](../benchmarks/vectorization_strategy_benchmark.py)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How we are going to run this from command-line?

Copy link
Contributor

@bhack bhack Mar 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this was not isolated as a performance test we need to tell to the devs to:

 pip install -e .

And the to pip uninstall keras-cv

plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
plt.show()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not going to work well in containers like in our official keras dev container. Isn't better to default plot on file?

@LukeWood LukeWood deleted the design_guidelines branch March 31, 2022 18:40
ianstenbit pushed a commit to ianstenbit/keras-cv that referenced this pull request Aug 6, 2022
…keras-team#161)

* add benchmarking script, update API_DESIGN.md to reflect the results.

* Format benchmarks

* Fix lint

* add () to call method

* update api design guidelines
adhadse pushed a commit to adhadse/keras-cv that referenced this pull request Sep 17, 2022
…keras-team#161)

* add benchmarking script, update API_DESIGN.md to reflect the results.

* Format benchmarks

* Fix lint

* add () to call method

* update api design guidelines
freedomtan pushed a commit to freedomtan/keras-cv that referenced this pull request Jul 20, 2023
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper

* Add bincount op

* Add CategoryEncoding layer and tests

* Fix formatting

* Fix JAX issues

* Fix JAX bincount

* Formatting and small fix

* Fix nits and docstrings

* Add args to bincount op test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants