From 5dce72d8214ef5055c233d99aa3e036109585a24 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Tue, 13 Feb 2024 17:51:30 -0800 Subject: [PATCH] fixes --- keras_nlp/models/mistral/mistral_transformer_decoder.py | 7 ++++--- keras_nlp/models/t5/t5_transformer_layer.py | 2 +- keras_nlp/samplers/beam_sampler.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py index 7c90ab91b9..209cd0fd17 100644 --- a/keras_nlp/models/mistral/mistral_transformer_decoder.py +++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py @@ -214,9 +214,10 @@ def _compute_self_attention_mask( # Below is a workaround for `ops.triu` for Keras 2. # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed. # causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window) - i = ops.arange(output_length)[:, None] + cache_update_index - j = ops.arange(input_length)[None, :] - causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32") + i = ops.arange(output_length, dtype="float32")[:, None] + i = i + ops.cast(cache_update_index, "float32") + j = ops.arange(input_length, dtype="float32")[None, :] + causal_mask_upper = i < j + self.sliding_window causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper) return ( diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index 697af20899..c82e7e77e9 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -131,7 +131,7 @@ def call( shape = ops.shape(hidden_states) batch_size, length = shape[0], shape[1] causal_mask = compute_causal_mask(batch_size, length, length) - attention_mask = ops.cast(attention_mask, "int32") + attention_mask = ops.cast(attention_mask, "bool") attention_mask = causal_mask & attention_mask x = hidden_states # Intermediate result. diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index ed68ef0e4b..92b9e4b8ae 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -47,7 +47,7 @@ def start(self, data): data = tree.map_structure(self.create_beams, data) # Setup the initial beam log-likelihoods. log_probs = [[0.0] + [-1e9] * (self.num_beams - 1)] - log_probs = ops.array(log_probs) + log_probs = ops.array(log_probs, dtype="float32") log_probs = self.flatten_beams(ops.repeat(log_probs, batch_size, 0)) return {**data, "log_probabilities": log_probs}