Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for Returning Attention Scores in TransformerEncoder call #1879

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions keras_hub/src/layers/modeling/transformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,12 @@ def build(self, inputs_shape):
self.built = True

def call(
self, inputs, padding_mask=None, attention_mask=None, training=None
self,
inputs,
padding_mask=None,
attention_mask=None,
training=None,
return_attention_scores=False,
):
"""Forward pass of the TransformerEncoder.

Expand All @@ -199,6 +204,7 @@ def call(
[batch_size, sequence_length, sequence_length].
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.
return_attention_scores: a boolean indicating whether the output should be `(attention_output, attention_scores)` if `True` or `attention_output` if `False`. Defaults to `False`.

Returns:
A Tensor of the same shape as the `inputs`.
Expand All @@ -214,12 +220,24 @@ def call(
residual = x
if self.normalize_first:
x = self._self_attention_layer_norm(x)
x = self._self_attention_layer(
query=x,
value=x,
attention_mask=self_attention_mask,
training=training,
)

if return_attention_scores:
x, attention_scores = self._self_attention_layer(
query=x,
value=x,
attention_mask=self_attention_mask,
return_attention_scores=return_attention_scores,
training=training,
)
return x, attention_scores
else:
x = self._self_attention_layer(
query=x,
value=x,
attention_mask=self_attention_mask,
training=training,
)

x = self._self_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
Expand All @@ -236,6 +254,9 @@ def call(
if not self.normalize_first:
x = self._feedforward_layer_norm(x)

if return_attention_scores:
return x, attention_scores

return x

def get_config(self):
Expand Down
11 changes: 11 additions & 0 deletions keras_hub/src/layers/modeling/transformer_encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,14 @@ def test_mask_propagation(self):
inputs._keras_mask = mask
outputs = encoder(inputs)
self.assertAllEqual(outputs._keras_mask, mask)

def test_attention_scores(self):
encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
inputs = random.uniform(shape=[1, 4, 6])
outputs, attention_scores = encoder(
inputs, return_attention_scores=True
)
print(attention_scores)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove this print?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes will remove the print statement. Thanks for pointing it out!

assert outputs.shape == inputs.shape
# attention scores shape (batch_size, num_of_attn_heads, seq_length, seq_length)
assert attention_scores.shape == [1, 2, 4, 4]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use self.assertAllEqual instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I have made the changes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the changes!

Loading