diff --git a/tensorflow_gnn/models/multi_head_attention/layers.py b/tensorflow_gnn/models/multi_head_attention/layers.py index 71d521d9..53d9f1b8 100644 --- a/tensorflow_gnn/models/multi_head_attention/layers.py +++ b/tensorflow_gnn/models/multi_head_attention/layers.py @@ -172,7 +172,9 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase): Setting this option pools inputs with attention coefficients, then applies the transformation. This is mathematically equivalent but can be faster or slower to compute, depending on the platform and the dataset. - IMPORANT: Toggling this option breaks checkpoint compatibility. + IMPORTANT: Toggling this option breaks checkpoint compatibility. + IMPORTANT: Setting this option requires TensorFlow 2.10 or greater, + because it uses `tf.keras.layers.EinsumDense`. """ def __init__( @@ -300,6 +302,15 @@ def __init__( else: self._w_sender_edge_to_value = None else: + # TODO(b/266868417): Remove when TF2.10+ is required by all of TF-GNN. + try: + _ = tf.keras.layers.EinsumDense + except AttributeError as e: + raise ValueError( + "MultiHeadAttentionConv(transform_values_after_pooling=True) " + "requires tf.keras.layers.EinsumDense from " + f"TensorFlow 2.10 or newer, got TensorFlow {tf.__version__}" + ) from e self._w_sender_pooled_to_value = tf.keras.layers.EinsumDense( equation="...hv,hvc->...hc", output_shape=(num_heads, per_head_channels), diff --git a/tensorflow_gnn/models/multi_head_attention/layers_test.py b/tensorflow_gnn/models/multi_head_attention/layers_test.py index f4286c4a..43fe5fd7 100644 --- a/tensorflow_gnn/models/multi_head_attention/layers_test.py +++ b/tensorflow_gnn/models/multi_head_attention/layers_test.py @@ -34,9 +34,20 @@ class ReloadModel(int, enum.Enum): class MultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase): + # TODO(b/266868417): Remove when TF2.10+ is required by all of TF-GNN. + def _skip_if_unsupported(self, transform_values_after_pooling=None): + """Skips test if TF is tool old for a requested option.""" + if transform_values_after_pooling: + if tf.__version__.startswith("2.8.") or tf.__version__.startswith("2.9."): + self.skipTest( + "MultiHeadAttentionConv(transform_values_after_pooling=True) " + f"requires TF 2.10+, got {tf.__version__}") + @parameterized.named_parameters(("", False), ("TransformAfter", True)) def testBasic(self, transform_values_after_pooling): """Tests that a single-headed MHA is correct given predefined weights.""" + self._skip_if_unsupported( + transform_values_after_pooling=transform_values_after_pooling) # NOTE: Many following tests use minor variations of the explicit # construction of weights and results introduced here. @@ -503,6 +514,8 @@ def testNoTransformKeys(self): @parameterized.named_parameters(("", False), ("TransformAfter", True)) def testMultihead(self, transform_values_after_pooling): """Extends testBasic with multiple attention heads.""" + self._skip_if_unsupported( + transform_values_after_pooling=transform_values_after_pooling) # The same test graph as in the testBasic above. gt_input = _get_test_bidi_cycle_graph( tf.constant([ @@ -595,6 +608,8 @@ def testMultihead(self, transform_values_after_pooling): ("RestoredKerasTransformAfter", ReloadModel.KERAS, True)) def testFullModel(self, reload_model, transform_values_after_pooling): """Tests MultiHeadAttentionHomGraphUpdate in a Model with edge input.""" + self._skip_if_unsupported( + transform_values_after_pooling=transform_values_after_pooling) # The same example as in the testBasic above, but with extra inputs # from edges. gt_input = _get_test_bidi_cycle_graph(