diff --git a/keras_hub/src/layers/modeling/transformer_encoder.py b/keras_hub/src/layers/modeling/transformer_encoder.py
index 8d3fb0f950..5ed121e457 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder.py
@@ -170,7 +170,12 @@ def build(self, inputs_shape):
         self.built = True
 
     def call(
-        self, inputs, padding_mask=None, attention_mask=None, training=None
+        self,
+        inputs,
+        padding_mask=None,
+        attention_mask=None,
+        training=None,
+        return_attention_scores=False,
     ):
         """Forward pass of the TransformerEncoder.
 
@@ -185,6 +190,9 @@ def call(
                 [batch_size, sequence_length, sequence_length].
             training: a boolean indicating whether the layer should behave in
                 training mode or in inference mode.
+            return_attention_scores: a boolean indicating whether the output
+                should be `(attention_output, attention_scores)` if `True` or
+                `attention_output` if `False`. Defaults to `False`.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -200,12 +208,23 @@ def call(
         residual = x
         if self.normalize_first:
             x = self._self_attention_layer_norm(x)
-        x = self._self_attention_layer(
-            query=x,
-            value=x,
-            attention_mask=self_attention_mask,
-            training=training,
-        )
+
+        if return_attention_scores:
+            x, attention_scores = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                return_attention_scores=return_attention_scores,
+                training=training,
+            )
+        else:
+            x = self._self_attention_layer(
+                query=x,
+                value=x,
+                attention_mask=self_attention_mask,
+                training=training,
+            )
+
         x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
@@ -222,6 +241,9 @@ def call(
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
 
+        if return_attention_scores:
+            return x, attention_scores
+
         return x
 
     def get_config(self):
diff --git a/keras_hub/src/layers/modeling/transformer_encoder_test.py b/keras_hub/src/layers/modeling/transformer_encoder_test.py
index c4763d3763..0f12a0920b 100644
--- a/keras_hub/src/layers/modeling/transformer_encoder_test.py
+++ b/keras_hub/src/layers/modeling/transformer_encoder_test.py
@@ -95,3 +95,14 @@ def test_mask_propagation(self):
         inputs._keras_mask = mask
         outputs = encoder(inputs)
         self.assertAllEqual(outputs._keras_mask, mask)
+
+    def test_attention_scores(self):
+        encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
+        inputs = random.uniform(shape=[1, 4, 6])
+        outputs, attention_scores = encoder(
+            inputs, return_attention_scores=True
+        )
+        self.assertAllEqual(outputs.shape, inputs.shape)
+
+        # Attention scores shape: (batch_size, num_heads, seq_len, seq_len).
+        self.assertAllEqual(attention_scores.shape, [1, 2, 4, 4])
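
For reviewers, a minimal usage sketch of the new flag, assuming this diff is applied and `keras_hub` is installed on Keras 3 (the layer configuration and shapes simply mirror the test above):

import keras
from keras_hub.layers import TransformerEncoder

# Encoder matching the test configuration above.
encoder = TransformerEncoder(intermediate_dim=4, num_heads=2)
inputs = keras.random.uniform(shape=(1, 4, 6))

# With return_attention_scores=True the layer returns a tuple instead of a
# single tensor.
outputs, attention_scores = encoder(inputs, return_attention_scores=True)
print(outputs.shape)           # (1, 4, 6), same shape as the inputs
print(attention_scores.shape)  # (1, 2, 4, 4): (batch, heads, seq_len, seq_len)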