Merge branch 'master' of github.com:keras-team/keras
fchollet committed Jan 29, 2024
2 parents 4766e80 + 419973e commit f7401d4
117 changes: 117 additions & 0 deletions keras/layers/attention/attention_test.py
@@ -1,6 +1,7 @@
import numpy as np

from keras import layers
from keras import ops
from keras import testing


@@ -205,6 +206,7 @@ def test_attention_compute_mask_does_not_return_none_with_valid_mask(self):
        valid_mask = np.array([True, False, True])
        mask = [valid_mask, np.array([False, True, False])]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        self.assertIsNotNone(
            computed_mask,
            "compute_mask should not return None with a valid mask",
@@ -221,7 +223,122 @@ def test_attention_compute_mask_returns_correct_tensor_with_valid_mask(
        valid_mask = np.array([True, False, True])
        mask = [valid_mask, np.array([False, True, False])]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        self.assertTrue(
            np.array_equal(computed_mask, valid_mask),
            "compute_mask did not return the correct mask tensor",
        )

    def test_attention_compute_mask_returns_correct_tensor_with_all_true_mask(
        self,
    ):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        valid_mask = np.array([True, True, True])
        mask = [valid_mask, np.array([True, True, True])]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        expected_mask = np.array([True, True, True])
        self.assertTrue(
            np.array_equal(computed_mask, expected_mask),
            "compute_mask did not return the correct mask tensor",
        )

    def test_attention_compute_mask_returns_correct_tensor_with_all_false_mask(
        self,
    ):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        valid_mask = np.array([False, False, False])
        mask = [valid_mask, np.array([False, False, False])]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        expected_mask = np.array([False, False, False])
        self.assertTrue(
            np.array_equal(computed_mask, expected_mask),
            "compute_mask did not return the correct mask tensor",
        )

    def test_attention_compute_mask_with_tolerance_1e_3(self):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        valid_mask = np.array([1.0, 0.0, 1.0], dtype=float)
        mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        expected_mask = valid_mask
        self.assertTrue(
            np.allclose(computed_mask, expected_mask, atol=1e-3),
            "Incorrect mask tensor within tolerance 1e-3",
        )

    def test_attention_compute_mask_with_tolerance_1e_5(self):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        valid_mask = np.array([1.0, 0.0, 1.0], dtype=float)
        mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        expected_mask = valid_mask
        self.assertTrue(
            np.allclose(computed_mask, expected_mask, atol=1e-5),
            "Incorrect mask tensor within tolerance 1e-5",
        )

    def test_attention_compute_mask_with_tolerance_1e_7(self):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        valid_mask = np.array([1.0, 0.0, 1.0], dtype=float)
        mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        expected_mask = valid_mask
        self.assertTrue(
            np.allclose(computed_mask, expected_mask, atol=1e-7),
"Incorrect mask tensor within tolerance 1e-7 ",
        )

    def test_attention_compute_mask_with_single_element_masks(self):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        valid_mask = np.array([True])
        mask = [valid_mask, np.array([False])]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        expected_shape = (1,)
        self.assertEqual(computed_mask.shape, expected_shape)

    def test_attention_compute_mask_with_non_boolean_masks(self):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        valid_mask = np.array([1, 0, 1])
        mask = [valid_mask, np.array([0, 1, 0])]
        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
        computed_mask = ops.convert_to_numpy(computed_mask)
        self.assertTrue(np.array_equal(computed_mask, valid_mask))

    def test_attention_compute_mask_with_edge_case_masks(self):
        layer = layers.Attention()
        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]
        edge_case_masks = [
            np.array([True, True, True]),
            np.array([False, False, False]),
            np.array([True, False, True]),
        ]
        for mask in edge_case_masks:
            computed_mask = layer.compute_mask(
                inputs=dummy_inputs, mask=[mask, mask]
            )
            computed_mask = ops.convert_to_numpy(computed_mask)
            self.assertTrue(np.array_equal(computed_mask, mask))

    def test_attention_compute_mask_with_different_input_shapes(self):
        layer = layers.Attention()
        input_shapes = [(2, 3, 4), (3, 2, 5), (4, 1, 6)]
        valid_mask = np.array([True, False, True])
        for shape in input_shapes:
            dummy_inputs = [np.ones(shape), np.ones(shape)]
            mask = [valid_mask, np.array([False, True, False])]
            computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)
            computed_mask = ops.convert_to_numpy(computed_mask)
            self.assertTrue(np.array_equal(computed_mask, valid_mask))
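
Taken together, the new tests pin down one simple contract: layers.Attention.compute_mask passes the query (first) mask through unchanged, and ops.convert_to_numpy makes the returned mask comparable with NumPy on any backend. Below is a minimal standalone sketch of that contract, assuming only the public Keras 3 API already used in this file; the variable names are illustrative, not from the commit.

import numpy as np

from keras import layers
from keras import ops

layer = layers.Attention()
query = np.ones((2, 3, 4))  # (batch, query_steps, dim)
value = np.ones((2, 4, 4))  # (batch, value_steps, dim)
query_mask = np.array([True, False, True])
value_mask = np.array([False, True, False])

# compute_mask propagates the query mask; converting the result to NumPy
# keeps the comparison backend-agnostic, which is what the
# ops.convert_to_numpy calls added in this commit rely on.
mask = layer.compute_mask(inputs=[query, value], mask=[query_mask, value_mask])
mask = ops.convert_to_numpy(mask)
assert np.array_equal(mask, query_mask)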
