Added Support for Returning Attention Scores in TransformerEncoder call (keras-team#1879)

* Added: Return attention scores argument to transformer encoder

* Added: docstring for return_attention_scores and a test to check that the argument works

* Fixed: test case by removing print statements and using self.assertAllEqual

* Fixed: Linting
anirudhr20 authored and ushareng committed Oct 24, 2024
1 parent ed035d3 commit 150fae2
Showing 2 changed files with 39 additions and 7 deletions.
35 changes: 28 additions & 7 deletions keras_hub/src/layers/modeling/transformer_encoder.py
@@ -170,7 +170,12 @@ def build(self, inputs_shape):
        self.built = True

    def call(
        self, inputs, padding_mask=None, attention_mask=None, training=None
        self,
        inputs,
        padding_mask=None,
        attention_mask=None,
        training=None,
        return_attention_scores=False,
    ):
        """Forward pass of the TransformerEncoder.

@@ -185,6 +190,7 @@ def call(
                [batch_size, sequence_length, sequence_length].
            training: a boolean indicating whether the layer should behave in
                training mode or in inference mode.
            return_attention_scores: a boolean indicating whether the output
                should be `(attention_output, attention_scores)` if `True` or
                `attention_output` if `False`. Defaults to `False`.

        Returns:
            A Tensor of the same shape as the `inputs`.
@@ -200,12 +206,24 @@
        residual = x
        if self.normalize_first:
            x = self._self_attention_layer_norm(x)
        x = self._self_attention_layer(
            query=x,
            value=x,
            attention_mask=self_attention_mask,
            training=training,
        )

        if return_attention_scores:
            x, attention_scores = self._self_attention_layer(
                query=x,
                value=x,
                attention_mask=self_attention_mask,
                return_attention_scores=return_attention_scores,
                training=training,
            )
            return x, attention_scores
        else:
            x = self._self_attention_layer(
                query=x,
                value=x,
                attention_mask=self_attention_mask,
                training=training,
            )

        x = self._self_attention_dropout(x, training=training)
        x = x + residual
        if not self.normalize_first:

@@ -222,6 +240,9 @@ def call(
        if not self.normalize_first:
            x = self._feedforward_layer_norm(x)

        if return_attention_scores:
            return x, attention_scores

        return x

    def get_config(self):
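For readers skimming the diff, here is a minimal usage sketch of the new argument. It is a hedged illustration, not part of the commit: keras_hub.layers.TransformerEncoder is the library's public export for this layer, and the layer sizes below are arbitrary demo values.

import numpy as np
import keras_hub

# Build a small encoder; intermediate_dim and num_heads are illustrative only.
encoder = keras_hub.layers.TransformerEncoder(intermediate_dim=64, num_heads=4)

# Dummy batch of shape [batch_size, sequence_length, hidden_dim].
inputs = np.random.uniform(size=(2, 10, 16)).astype("float32")

# Default behavior is unchanged: a single tensor comes back.
outputs = encoder(inputs)

# With the new flag, the call returns (attention_output, attention_scores).
outputs, scores = encoder(inputs, return_attention_scores=True)

print(outputs.shape)  # (2, 10, 16)
print(scores.shape)   # (2, 4, 10, 10), one score matrix per attention head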
11 changes: 11 additions & 0 deletions keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -95,3 +95,14 @@ def test_mask_propagation(self):
        inputs._keras_mask = mask
        outputs = encoder(inputs)
        self.assertAllEqual(outputs._keras_mask, mask)

    def test_attention_scores(self):
        encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
        inputs = random.uniform(shape=[1, 4, 6])
        outputs, attention_scores = encoder(
            inputs, return_attention_scores=True
        )
        self.assertAllEqual(outputs.shape, inputs.shape)

        # Attention scores have shape
        # (batch_size, num_heads, sequence_length, sequence_length).
        self.assertAllEqual(attention_scores.shape, [1, 2, 4, 4])
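The expected [1, 2, 4, 4] shape matches what Keras multi-head attention produces: one sequence_length x sequence_length score matrix per head. As a standalone sanity check (using keras.layers.MultiHeadAttention directly here is purely for illustration; the encoder wraps its own attention layer internally), the same dimensions can be reproduced like this:

import numpy as np
from keras import layers

# Same dimensions as the test above: batch=1, seq_len=4, hidden_dim=6, 2 heads.
mha = layers.MultiHeadAttention(num_heads=2, key_dim=3)
x = np.random.uniform(size=(1, 4, 6)).astype("float32")

# return_attention_scores=True makes the call return (output, scores).
out, scores = mha(query=x, value=x, return_attention_scores=True)

print(out.shape)     # (1, 4, 6)
print(scores.shape)  # (1, 2, 4, 4): (batch, num_heads, query_len, key_len)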
