Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for Returning Attention Scores in TransformerEncoder call #1879

Merged
merged 4 commits into from
Oct 21, 2024

Conversation

anirudhr20
Copy link
Contributor

Summary: This pull request introduces a new feature that adds support for optionally returning attention scores in the TransformerEncoder class. This is controlled by the return_attention_scores flag, which when set to True, returns both the output and the attention scores from the attention mechanism.

Changes Introduced:

  • Updated the call method in the TransformerEncoder to handle the return_attention_scores flag.
  • Refactored the code to ensure that the attention scores are computed and returned when required.
  • Updated the documentation in the call method to reflect the changes.
  • Added unit tests to ensure the correctness of this feature using pytest.

Testing:
Ran the unit tests to verify that the return_attention_scores flag works as expected.
return_attention_scores=True (verifies that attention scores are returned and the shapes are correct).
Related Issue: #1644

Copy link

google-cla bot commented Sep 25, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Member

@SamanehSaadat SamanehSaadat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @anirudhr20! The PR looks great! I just left a couple of nit comments in the test.

outputs, attention_scores = encoder(
inputs, return_attention_scores=True
)
print(attention_scores)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove this print?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes will remove the print statement. Thanks for pointing it out!

print(attention_scores)
assert outputs.shape == inputs.shape
# attention scores shape (batch_size, num_of_attn_heads, seq_length, seq_length)
assert attention_scores.shape == [1, 2, 4, 4]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use self.assertAllEqual instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I have made the changes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the changes!

Copy link
Member

@SamanehSaadat SamanehSaadat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

@SamanehSaadat SamanehSaadat merged commit 8a943da into keras-team:master Oct 21, 2024
7 checks passed
ushareng pushed a commit to ushareng/keras-nlp that referenced this pull request Oct 24, 2024
…ll (keras-team#1879)

* Added: Return attention scores argument to transformer encoder

* Added: docstring for return_attention_scores and added a test to chek the working of the argument

* Fixed: Test case by removing print stmts and using self.assertAllEqual

* Fixed: Linting
ushareng pushed a commit to ushareng/keras-nlp that referenced this pull request Oct 24, 2024
BytePairTokenizer must not split sequences of \n (keras-team#1910)

* fix for loading of special tokens in Llama tokenizer

* fix for Llama tokenizer which can have multiple end tokens

* bug fix

* adding some missing tokens to Llama3 tokenizer

* fixed tests and Llama3Tokenizer init.

* now loading correct eos_token config from Hugging Face checkpoint. Using hack for Keras checkpoint because it does not have this info

* fix for BytePairTokenizer to make Lllama3-instruct work in chat: \n\n sequences are significant in the chat template and must be preserved by the tokenizer

---------

Co-authored-by: Martin Görner <[email protected]>

fix for generation that never stops in Llama3-Instruct variants (keras-team#1904)

* fix for loading of special tokens in Llama tokenizer

* fix for Llama tokenizer which can have multiple end tokens

* bug fix

* adding some missing tokens to Llama3 tokenizer

* fixed tests and Llama3Tokenizer init.

* now loading correct eos_token config from Hugging Face checkpoint. Using hack for Keras checkpoint because it does not have this info

---------

Co-authored-by: Martin Görner <[email protected]>

fix failing JAX GPU test (keras-team#1911)

* fix tests

* fix test

Refactor `MMDiT`, add `ImageToImage` and `Inpaint` for SD3 (keras-team#1909)

* Refactor `MMDiT` and add `ImageToImage`

* Update model version

* Fix minor bugs.

* Add `Inpaint` for SD3.

* Fix warnings of MMDiT.

* Addcomment to Inpaint

* Simplify `MMDiT` implementation and info of `summary()`.

* Refactor `generate()` API of `TextToImage`, `ImageToImage` and `Inpaint`.

Minor bug fix (keras-team#1915)

Change to image_converter.image_size since it is a tuple and it's not a callable function.

[Mix Transformer] Add Presets for MiTB0...MiTB5 (keras-team#1893)

* add presets for mit

* add standin paths

* register presets in __init__.py

* fix op in overlapping patching and embedding, start adding conversion utils

* style

* add padding to MiT patchingandembedding

* update to support other presets

* update conversin script

* fix link for b5

* add cityscapes weights

* update presets

* update presets

* update conversion script to make directories

* use save_preset

* change name of output dir

* add preprocessor flow

* api gen and add preprocessor to mits

* conform to new image classifier style

* format

* resizing image converter -> ImageConverter

* address comments

refactoring

remove default resizing for vision backbones (keras-team#1916)

* remove defailt resizing

* fix GPU test

Update VGG model to be compatible with HF and add conversion scripts (keras-team#1914)

Deeplab presets (keras-team#1918)

* add preset configurations for deeplabv3

* fix uri

* Add training details

update presets to point to the main Keras Kaggle page (keras-team#1921)

* update presets to point to the main keras page

* update mit path

Added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates (keras-team#1912)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates

* un commented the test lines that were commented by mistake

* fixed linter errors

Task models fix (keras-team#1922)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates

* fix for wrongly configured task models LLama, PaliGemma, Mistral and Phi3 + test

* comments

* un commented the test lines that were commented by mistake

* fixed linter errors

adding option strip_prompt to generate() (keras-team#1913)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates

* un commented the test lines that were commented by mistake

* fixed linter errors

* added options strip_prompt to generate()

* fix for tensorflow: the compiled version of generate(strip_prompt=True) now works + code refactoring to make it more understandable

* added test for generate(strip_prompt=True)

* minor edits

Layout map for Llama (keras-team#1923)

* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Lama chat templates

* un commented the test lines that were commented by mistake

* fixed linter errors

* added default layout map for Llama

* minor fixes in tests

Update deeplab_v3_presets.py (keras-team#1924)

Add paths to get SAM weights from (keras-team#1925)

Two fixes for image resizing in preprocessing (keras-team#1927)

1. Properly display when are not resizing the input image in
   `model.summary()`
2. Allow setting the `image_size` directly on a preprocessing layer.

2. is just to allow a more consistent way to set the input shape
across tasks. We now have:

```python
text_classifier = keras_hub.models.TextClassifer.from_preset(
    "bert_base_en",
)
text_classifier.preprocessor.sequence_length = 256

image_classifier = keras_hub.models.TextClassifer.from_preset(
    "bert_base_en",
)
image_classifier.preprocessor.image_size = (256, 256)

multi_modal_lm = keras_hub.models.CausalLM.from_preset(
    "some_preset",
)
multi_modal_lm.preprocessor.sequence_length = 256
multi_modal_lm.preprocessor.image_size = (256, 256)
```

add back default image resizing (keras-team#1926)

Update deeplab_v3_presets.py (keras-team#1928)

* Update deeplab_v3_presets.py

* Update deeplab_v3_presets.py

Update PaliGemma to remove `include_rescaling` arg (keras-team#1917)

* update PaliGemma

* update conversion script

* fix GPU tests

fix path (keras-team#1929)

* fix path

* nit

Fix paligemma checkpoint conversion script (keras-team#1931)

* add back default image resizing

* fix bug in image converter

* fix paligemma checkpoint conversion file

* fix preset name

* remove debug code

* revert unintended changes

update preset path to point to latest version of models (keras-team#1932)

Update sdv3 path (keras-team#1934)

update sam docstring to show correct backbone in docstring (keras-team#1936)

Convert input dict to tensors during train_on_batch (keras-team#1919)

Register VGG presets. (keras-team#1935)

* register vgg preset

* nit

* nit

* nit

Add ResNetVD presets (keras-team#1897)

* Add ResNetVD presets

* Updated Kaggle handles

* Add weight conversion script for ResNet_vd

* Add usage

rebase conflict resolved

conflict resolve

Update sam_presets.py (keras-team#1940)

Update vit_det_backbone.py (keras-team#1941)

fix gpu test (keras-team#1939)

* fix gpu test

* cast input

* update dtype

* change to resnet preset

* remove arg

Added Support for Returning Attention Scores in TransformerEncoder call (keras-team#1879)

* Added: Return attention scores argument to transformer encoder

* Added: docstring for return_attention_scores and added a test to chek the working of the argument

* Fixed: Test case by removing print stmts and using self.assertAllEqual

* Fixed: Linting

Mark preset tests as large (keras-team#1942)

* fix tests

* fix test

* Update preset_utils_test.py

version bump to 0.17.0.dev0 (keras-team#1944)

Update stable_diffusion_3_presets.py (keras-team#1946)

[Semantic Segmentation] - Add SegFormer Architecture, Weight Conversion Script and Presets (keras-team#1883)

* initial commit - tf-based, kcv

* porting to keras_hub structure - removing aliases, presets, etc.

* enable instantiation of segformer backbone with custom MiT backbone

* remove num_classes from backbone

* fix input

* add imports to __init__

* update preset

* update docstrings

* add basic tests

* remove redundant imports

* update docstrings

* remove unused import

* running api_gen.py

* undo refactor of mit

* update docstrings

* add presets for mit

* add standin paths

* add presets for segformer backbone

* register presets in __init__.py

* addressing comments

* addressing comments

* addressing comments

* update most tests

* add remaining tests

* remove copyright

* fix test

* override from_config

* fix op in overlapping patching and embedding, start adding conversion utils

* style

* add padding to MiT patchingandembedding

* update to support other presets

* update conversin script

* fix link for b5

* add cityscapes weights

* update presets

* update presets

* update conversion script to make directories

* use save_preset

* change name of output dir

* add preprocessor flow

* api gen and add preprocessor to mits

* conform to new image classifier style

* format

* resizing image converter -> ImageConverter

* merge mit branch into segformer branch

* add preprocessor and converter

* address comments

* clarify backbone usage

* add conversion script

* numerical equivalence changes

* fix numerical inaccuracies

* update conversion script

* update conversion script

* remove transpose

* add preprocessor to segformer class

* fix preset path

* update test shape

* update presets

* update test shape

* expand docstrings

* add rescaling and normalization to preprocessor

* remove backbone presets, remove copyrights, remove backbone cls from segmenter

* remove copyright and unused import

* apply same transformation to masks as input images

* fix import

* fix shape in tests

Update readme (keras-team#1949)

* Update README.md

* Update README.md

Update llama_backbone.py docstring (keras-team#1950)

Update path (keras-team#1953)

Update preset path for keras.io.

There is no LLaMA2 in keras.io https://keras.io/api/keras_hub/models/llama2

This is the actual link:
https://keras.io/api/keras_hub/models/llama2

For Vicuna it does not have it's own model direcotry, since it is also the part of Llama,, updated the path.

Update SD3 init parameters (replacing `height`, `width` with `image_shape`) (keras-team#1951)

* Replace SD3 `height` and `width` with `image_shape`

* Update URI

* Revert comment

* Update SD3 handle

* Replace `height` and `width` with `image_shape`

* Update docstrings

* Fix CI

Update docstring (keras-team#1954)

AudioConverter is registered as "keras_hub.layers.WhisperAudioConverter" and not as part of models.

 updated Mobilenet backbone to match it with torch implementation

timm script added

checkpoint conversion added

Refactoring
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants