Commit 3f6e17a

TF2.10+ is required for transform_values_after_pooling=True
in MultiHeadAttentionConv, because it needs EinsumDense
(and not the older experimental version).

PiperOrigin-RevId: 505134783
arnoegw authored and tensorflower-gardener committed Jan 27, 2023
1 parent 6626331 commit 3f6e17a
Showing 2 changed files with 27 additions and 1 deletion.
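
For context, `tf.keras.layers.EinsumDense` only became a stable symbol in
TensorFlow 2.10 (earlier releases ship it as
`tf.keras.layers.experimental.EinsumDense`), and the option gated by this
commit uses it to transform the already-pooled per-head values. A minimal
sketch of that layer in isolation; the sizes below are illustrative, not
taken from the diff:

```python
import tensorflow as tf  # assumes TF 2.10+

# Same equation and output-shape pattern as the layer built in this commit:
# for each head h, map v pooled value channels to c output channels.
num_heads, per_head_values, per_head_channels = 4, 8, 16
value_layer = tf.keras.layers.EinsumDense(
    equation="...hv,hvc->...hc",
    output_shape=(num_heads, per_head_channels))

pooled = tf.random.normal([32, num_heads, per_head_values])  # [batch, h, v]
transformed = value_layer(pooled)                            # [batch, h, c]
print(transformed.shape)  # (32, 4, 16)
```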

tensorflow_gnn/models/multi_head_attention/layers.py (12 additions, 1 deletion)

@@ -172,7 +172,9 @@ class MultiHeadAttentionConv(tfgnn.keras.layers.AnyToAnyConvolutionBase):
     Setting this option pools inputs with attention coefficients, then applies
     the transformation. This is mathematically equivalent but can be faster
     or slower to compute, depending on the platform and the dataset.
-    IMPORANT: Toggling this option breaks checkpoint compatibility.
+    IMPORTANT: Toggling this option breaks checkpoint compatibility.
+    IMPORTANT: Setting this option requires TensorFlow 2.10 or greater,
+    because it uses `tf.keras.layers.EinsumDense`.
   """

   def __init__(

@@ -300,6 +302,15 @@ def __init__(
       else:
         self._w_sender_edge_to_value = None
     else:
+      # TODO(b/266868417): Remove when TF2.10+ is required by all of TF-GNN.
+      try:
+        _ = tf.keras.layers.EinsumDense
+      except AttributeError as e:
+        raise ValueError(
+            "MultiHeadAttentionConv(transform_values_after_pooling=True) "
+            "requires tf.keras.layers.EinsumDense from "
+            f"TensorFlow 2.10 or newer, got TensorFlow {tf.__version__}"
+        ) from e
       self._w_sender_pooled_to_value = tf.keras.layers.EinsumDense(
           equation="...hv,hvc->...hc",
           output_shape=(num_heads, per_head_channels),
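
The docstring's equivalence claim holds because the value transformation is
linear, so it commutes with the attention-weighted sum over senders:
sum_i a_i (v_i W) = (sum_i a_i v_i) W. A minimal NumPy sketch of that
identity (illustrative sizes, single head, not TF-GNN code):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random(5)              # attention coefficients over 5 senders
V = rng.normal(size=(5, 8))    # one value vector per sender
W = rng.normal(size=(8, 16))   # value transformation for one head

transform_then_pool = a @ (V @ W)  # default: transform every value, then pool
pool_then_transform = (a @ V) @ W  # transform_values_after_pooling=True
np.testing.assert_allclose(transform_then_pool, pool_then_transform)
```

The second form applies W once per receiver instead of once per sender, which
is where the platform- and dataset-dependent speed difference comes from.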

tensorflow_gnn/models/multi_head_attention/layers_test.py (15 additions)

@@ -34,9 +34,20 @@ class ReloadModel(int, enum.Enum):

 class MultiHeadAttentionTest(tf.test.TestCase, parameterized.TestCase):

+  # TODO(b/266868417): Remove when TF2.10+ is required by all of TF-GNN.
+  def _skip_if_unsupported(self, transform_values_after_pooling=None):
+    """Skips the test if TF is too old for a requested option."""
+    if transform_values_after_pooling:
+      if tf.__version__.startswith("2.8.") or tf.__version__.startswith("2.9."):
+        self.skipTest(
+            "MultiHeadAttentionConv(transform_values_after_pooling=True) "
+            f"requires TF 2.10+, got {tf.__version__}")
+
   @parameterized.named_parameters(("", False), ("TransformAfter", True))
   def testBasic(self, transform_values_after_pooling):
     """Tests that a single-headed MHA is correct given predefined weights."""
+    self._skip_if_unsupported(
+        transform_values_after_pooling=transform_values_after_pooling)
     # NOTE: Many following tests use minor variations of the explicit
     # construction of weights and results introduced here.

@@ -503,6 +514,8 @@ def testNoTransformKeys(self):
   @parameterized.named_parameters(("", False), ("TransformAfter", True))
   def testMultihead(self, transform_values_after_pooling):
     """Extends testBasic with multiple attention heads."""
+    self._skip_if_unsupported(
+        transform_values_after_pooling=transform_values_after_pooling)
     # The same test graph as in the testBasic above.
     gt_input = _get_test_bidi_cycle_graph(
         tf.constant([

@@ -595,6 +608,8 @@ def testMultihead(self, transform_values_after_pooling):
         ("RestoredKerasTransformAfter", ReloadModel.KERAS, True))
   def testFullModel(self, reload_model, transform_values_after_pooling):
     """Tests MultiHeadAttentionHomGraphUpdate in a Model with edge input."""
+    self._skip_if_unsupported(
+        transform_values_after_pooling=transform_values_after_pooling)
     # The same example as in the testBasic above, but with extra inputs
     # from edges.
     gt_input = _get_test_bidi_cycle_graph(
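
The new `_skip_if_unsupported` helper string-matches exactly TF 2.8 and 2.9,
the releases where `EinsumDense` is still experimental. A numeric version
comparison would generalize it; a hedged sketch of such a helper
(hypothetical, not part of this commit):

```python
import tensorflow as tf

def _tf_version_at_least(major, minor):
  """Returns True iff the running TensorFlow release is >= major.minor."""
  actual = tuple(int(part) for part in tf.__version__.split(".")[:2])
  return actual >= (major, minor)

# The skip condition above could then read:
# if not _tf_version_at_least(2, 10): self.skipTest(...)
```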
