From 3f6e17aef441c6f154f06d4068f1ed466a5d1305 Mon Sep 17 00:00:00 2001
From: Arno Eigenwillig
Date: Fri, 27 Jan 2023 09:17:35 -0800
Subject: [PATCH] TF2.10+ is required for transform_values_after_pooling=True
 in MultiHeadAttentionConv, because it needs EinsumDense (and not the older
 experimental version).

PiperOrigin-RevId: 505134783
---
 .../models/multi_head_attention/layers.py      | 13 ++++++++++++-
 .../models/multi_head_attention/layers_test.py | 15 +++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/tensorflow_gnn/models/multi_head_attention/layers.py b/tensorflow_gnn/models/multi_head_attention/layers.py
index 71d521d9..53d9f1b8 100644
--- a/tensorflow_gnn/models/multi_head_attention/layers.py
+++ b/tensorflow_gnn/models/multi_head_attention/layers.py
@@ -172,7 +172,9 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase):
       Setting this option pools inputs with attention coefficients, then
       applies the transformation. This is mathematically equivalent but can
       be faster or slower to compute, depending on the platform and the
       dataset.
-      IMPORANT: Toggling this option breaks checkpoint compatibility.
+      IMPORTANT: Toggling this option breaks checkpoint compatibility.
+      IMPORTANT: Setting this option requires TensorFlow 2.10 or greater,
+      because it uses `tf.keras.layers.EinsumDense`.
   """
   def __init__(
@@ -300,6 +302,15 @@ def __init__(
       else:
         self._w_sender_edge_to_value = None
     else:
+      # TODO(b/266868417): Remove when TF2.10+ is required by all of TF-GNN.
+      try:
+        _ = tf.keras.layers.EinsumDense
+      except AttributeError as e:
+        raise ValueError(
+            "MultiHeadAttentionConv(transform_values_after_pooling=True) "
+            "requires tf.keras.layers.EinsumDense from "
+            f"TensorFlow 2.10 or newer, got TensorFlow {tf.__version__}"
+        ) from e
       self._w_sender_pooled_to_value = tf.keras.layers.EinsumDense(
           equation="...hv,hvc->...hc",
           output_shape=(num_heads, per_head_channels),
diff --git a/tensorflow_gnn/models/multi_head_attention/layers_test.py b/tensorflow_gnn/models/multi_head_attention/layers_test.py
index f4286c4a..43fe5fd7 100644
--- a/tensorflow_gnn/models/multi_head_attention/layers_test.py
+++ b/tensorflow_gnn/models/multi_head_attention/layers_test.py
@@ -34,9 +34,20 @@ class ReloadModel(int, enum.Enum):
 
 
 class MultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase):
 
+  # TODO(b/266868417): Remove when TF2.10+ is required by all of TF-GNN.
+  def _skip_if_unsupported(self, transform_values_after_pooling=None):
+    """Skips test if TF is too old for a requested option."""
+    if transform_values_after_pooling:
+      if tf.__version__.startswith("2.8.") or tf.__version__.startswith("2.9."):
+        self.skipTest(
+            "MultiHeadAttentionConv(transform_values_after_pooling=True) "
+            f"requires TF 2.10+, got {tf.__version__}")
+
   @parameterized.named_parameters(("", False), ("TransformAfter", True))
   def testBasic(self, transform_values_after_pooling):
     """Tests that a single-headed MHA is correct given predefined weights."""
+    self._skip_if_unsupported(
+        transform_values_after_pooling=transform_values_after_pooling)
     # NOTE: Many following tests use minor variations of the explicit
     # construction of weights and results introduced here.
@@ -503,6 +514,8 @@ def testNoTransformKeys(self):
   @parameterized.named_parameters(("", False), ("TransformAfter", True))
   def testMultihead(self, transform_values_after_pooling):
     """Extends testBasic with multiple attention heads."""
+    self._skip_if_unsupported(
+        transform_values_after_pooling=transform_values_after_pooling)
     # The same test graph as in the testBasic above.
     gt_input = _get_test_bidi_cycle_graph(
         tf.constant([
@@ -595,6 +608,8 @@ def testMultihead(self, transform_values_after_pooling):
       ("RestoredKerasTransformAfter", ReloadModel.KERAS, True))
   def testFullModel(self, reload_model, transform_values_after_pooling):
     """Tests MultiHeadAttentionHomGraphUpdate in a Model with edge input."""
+    self._skip_if_unsupported(
+        transform_values_after_pooling=transform_values_after_pooling)
     # The same example as in the testBasic above, but with extra inputs
     # from edges.
     gt_input = _get_test_bidi_cycle_graph(
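
Reviewer note (not part of the patch): a minimal sketch of what the newly
guarded layer computes, with illustrative shapes (4 heads, value dimension 8,
16 output channels per head) that are not taken from the patch.
`tf.keras.layers.EinsumDense` is the stable name since TF 2.10; older
releases only have `tf.keras.layers.experimental.EinsumDense`.

import tensorflow as tf

num_heads, value_dim, per_head_channels = 4, 8, 16  # illustrative only

# One [value_dim, per_head_channels] kernel per head ("hvc"): transforms the
# values *after* attention pooling, as transform_values_after_pooling=True does.
transform = tf.keras.layers.EinsumDense(
    equation="...hv,hvc->...hc",
    output_shape=(num_heads, per_head_channels))

pooled = tf.random.normal([32, num_heads, value_dim])  # attention-pooled values
print(transform(pooled).shape)  # -> (32, 4, 16)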
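Reviewer note (not part of the patch): since the constructor now raises
ValueError on TF < 2.10, callers can probe for the stable EinsumDense layer
before enabling the option, mirroring the new guard. A hedged sketch, assuming
the usual imports and illustrative constructor arguments:

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import multi_head_attention

# EinsumDense left experimental in TF 2.10, so its presence under
# tf.keras.layers is exactly what the new constructor check tests for.
use_transform_after = hasattr(tf.keras.layers, "EinsumDense")

conv = multi_head_attention.MultiHeadAttentionConv(
    num_heads=4,                 # illustrative values
    per_head_channels=16,
    receiver_tag=tfgnn.TARGET,
    transform_values_after_pooling=use_transform_after)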