Commit 5dce72d

fixes
mattdangerw committed Feb 14, 2024
1 parent f274d1e commit 5dce72d
Showing 3 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions keras_nlp/models/mistral/mistral_transformer_decoder.py
@@ -214,9 +214,10 @@ def _compute_self_attention_mask(
 # Below is a workaround for `ops.triu` for Keras 2.
 # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
 # causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window)
-i = ops.arange(output_length)[:, None] + cache_update_index
-j = ops.arange(input_length)[None, :]
-causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
+i = ops.arange(output_length, dtype="float32")[:, None]
+i = i + ops.cast(cache_update_index, "float32")
+j = ops.arange(input_length, dtype="float32")[None, :]
+causal_mask_upper = i < j + self.sliding_window
 causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper)

 return (
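Note on this hunk: the old code added the raw `cache_update_index` to an int32 `arange`, which can trip a dtype mismatch when the index tensor arrives with a different integer type; the fix computes both index grids in float32 and leaves `causal_mask_upper` as the boolean result of the comparison. Below is a minimal NumPy sketch of the sliding-window band this code builds; the sizes and the definition of `causal_mask_lower` are illustrative assumptions, not part of the commit.

import numpy as np

# Illustrative sizes (not from the commit).
output_length = 4       # queries decoded this step
input_length = 6        # keys visible to attention
cache_update_index = 2  # position offset of the first query
sliding_window = 3      # Mistral's attention window

i = np.arange(output_length, dtype="float32")[:, None]
i = i + np.float32(cache_update_index)
j = np.arange(input_length, dtype="float32")[None, :]
# Assumed standard causal band: query i may attend to keys j <= i.
causal_mask_lower = j <= i
# Sliding-window band: query i may not look back past the window.
causal_mask_upper = i < j + sliding_window
# minimum() on booleans acts as a logical AND, mirroring ops.minimum.
mask = np.minimum(causal_mask_lower, causal_mask_upper)
print(mask.astype(int))  # each row allows at most `sliding_window` keys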
2 changes: 1 addition & 1 deletion keras_nlp/models/t5/t5_transformer_layer.py
@@ -131,7 +131,7 @@ def call(
 shape = ops.shape(hidden_states)
 batch_size, length = shape[0], shape[1]
 causal_mask = compute_causal_mask(batch_size, length, length)
-attention_mask = ops.cast(attention_mask, "int32")
+attention_mask = ops.cast(attention_mask, "bool")
 attention_mask = causal_mask & attention_mask

 x = hidden_states  # Intermediate result.
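Note on this hunk: `compute_causal_mask` produces a boolean mask, so the `&` on the following line is a logical AND; casting the incoming padding mask to "bool" rather than "int32" keeps both operands the same dtype, which the bitwise-and likely requires on some backends. A small NumPy sketch of the combination; the shapes and broadcasting here are illustrative assumptions:

import numpy as np

batch_size, length = 1, 4
# Boolean lower-triangular causal mask, analogous to compute_causal_mask.
causal_mask = np.tril(np.ones((batch_size, length, length), dtype=bool))
# Padding mask marking the last position as padding, cast to bool and
# broadcast across the query axis.
attention_mask = np.array([[1, 1, 1, 0]]).astype(bool)[:, None, :]
combined = causal_mask & attention_mask  # logical AND, stays boolean
print(combined.astype(int)[0])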
2 changes: 1 addition & 1 deletion keras_nlp/samplers/beam_sampler.py
@@ -47,7 +47,7 @@ def start(self, data):
 data = tree.map_structure(self.create_beams, data)
 # Setup the initial beam log-likelihoods.
 log_probs = [[0.0] + [-1e9] * (self.num_beams - 1)]
-log_probs = ops.array(log_probs)
+log_probs = ops.array(log_probs, dtype="float32")
 log_probs = self.flatten_beams(ops.repeat(log_probs, batch_size, 0))
 return {**data, "log_probabilities": log_probs}

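Note on this hunk: `ops.array` on a plain Python list can default to float64 on some backends (NumPy in particular), so pinning dtype="float32" keeps the initial beam log-likelihoods consistent with the sampler's float32 math; that rationale is an inference from the change, not stated in the commit. A sketch of the values this sets up, assuming `flatten_beams` reshapes (batch, num_beams) to (batch * num_beams,):

import numpy as np

num_beams, batch_size = 3, 2
# Only the first beam starts "live"; the rest get -1e9 so they carry
# effectively zero probability until the beams diverge.
log_probs = np.array([[0.0] + [-1e9] * (num_beams - 1)], dtype="float32")
log_probs = np.repeat(log_probs, batch_size, axis=0).reshape(-1)
print(log_probs)  # [ 0.e+00 -1.e+09 -1.e+09  0.e+00 -1.e+09 -1.e+09]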
