Skip to content

Commit

Permalink
Let TF-GNN choose between keras or tf_keras consistently with TF 2.15:
Browse files Browse the repository at this point in the history
both provide Keras 2.15, but it matters which one is used,
because they have separate class hierarchies and global registries.

Along the way, refactor the nested case distinctions of tf_internal.py
into a clear list of supported older TF/Keras versions.

PiperOrigin-RevId: 604619976
  • Loading branch information
arnoegw committed Feb 6, 2024
1 parent b7a9027 commit 022fc47
Showing 1 changed file with 69 additions and 46 deletions.
115 changes: 69 additions & 46 deletions tensorflow_gnn/graph/tf_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
TODO(b/188399175): Use the public ExtensionType API instead.
"""

import os

##
## Part 1: TensorFlow symbols
##

# The following imports work in all supported versions of TF.
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top,g-bad-import-order
from tensorflow.python.framework import composite_tensor
Expand All @@ -32,39 +38,6 @@
except ImportError:
type_spec_registry = None # Not available before TF 2.12.

# NOTE: See ../__init__.py for an up-front check of supported Keras versions.

try:
try:
# Get Keras v2 from the separate tf_keras package.
# In OSS, it exists for TF2.14+. It may become required for TF2.16+.
from tf_keras.src.engine import keras_tensor # pytype: disable=import-error
from tf_keras.src.layers import core as core_layers # pytype: disable=import-error
import tf_keras.src.backend as keras_backend # pytype: disable=import-error
except ImportError:
# Get Keras v2 from the keras package.
# In OSS, this is possible for TF2.15 and older.
import keras # pytype: disable=import-error
if not keras.__version__.startswith('2.'):
raise ImportError(
'tensorflow_gnn requires tf_keras to be installed or keras version <'
f' 3. Instead got keras version {keras.__version__}.'
) from None # A Keras version mismatch is different to lacking tf_keras.
import keras # pytype: disable=import-error
if hasattr(keras, 'src'): # As of TF/Keras 2.13.
from keras.src.engine import keras_tensor # pytype: disable=import-error
from keras.src.layers import core as core_layers # pytype: disable=import-error
import keras.src.backend as keras_backend # pytype: disable=import-error
else:
from keras.engine import keras_tensor # pytype: disable=import-error
from keras.layers import core as core_layers # pytype: disable=import-error
import keras.backend as keras_backend # pytype: disable=import-error
except ImportError:
# Internal
keras_tensor = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access
core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access
keras_backend = tf._keras_internal.backend # pylint: disable=protected-access

CompositeTensor = composite_tensor.CompositeTensor
BatchableTypeSpec = type_spec.BatchableTypeSpec
type_spec_register = (
Expand All @@ -79,28 +52,78 @@
type_spec_registry.lookup if type_spec_registry is not None
else type_spec.lookup)

try:
# These types are semi-public as of TF/Keras 2.13.
# Whenever possible, get them the official way.
OpDispatcher = tf.__internal__.dispatch.OpDispatcher


##
## Part 2: Keras symbols, compatible with `tf.keras.*`
##

# pytype: disable=import-error

if tf.__version__.startswith("2.12."):
# tf.keras is keras 2.12, which does not yet have the `src` subdirectory.
from keras import backend as keras_backend
from keras.engine import keras_tensor as kt
from keras.layers import core as core_layers
# In 2.12, these symbols are not exposed yet under tf.keras.__internal__.
KerasTensor = kt.KerasTensor
RaggedKerasTensor = kt.RaggedKerasTensor

elif tf.__version__.startswith("2.13.") or tf.__version__.startswith("2.14."):
KerasTensor = tf.keras.__internal__.KerasTensor
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
except AttributeError:
KerasTensor = keras_tensor.KerasTensor
RaggedKerasTensor = keras_tensor.RaggedKerasTensor
# These KerasTensor helpers are still private in TF/Keras 2.13.
register_keras_tensor_specialization = (
keras_tensor.register_keras_tensor_specialization)
delegate_property = core_layers._delegate_property # pylint: disable=protected-access
delegate_method = core_layers._delegate_method # pylint: disable=protected-access
# tf.keras is keras.
# For TF 2.14, there also exists a tf_keras package, but TF does not use it.
from keras.src import backend as keras_backend
from keras.src.engine import keras_tensor as kt
from keras.src.layers import core as core_layers

OpDispatcher = tf.__internal__.dispatch.OpDispatcher
elif tf.__version__.startswith("2.15."):
KerasTensor = tf.keras.__internal__.KerasTensor
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
# OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15
# BUT THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES
# so it is essential that we pick the right one by replicating the logic from
# https://github.com/tensorflow/tensorflow/blob/r2.15/tensorflow/python/util/lazy_loader.py#L96
if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
from tf_keras.src import backend as keras_backend
from tf_keras.src.layers import core as core_layers
from tf_keras.src.engine import keras_tensor as kt
else:
from keras.src import backend as keras_backend
from keras.src.layers import core as core_layers
from keras.src.engine import keras_tensor as kt

elif hasattr(tf, "_keras_internal"): # Special case: internal.
KerasTensor = tf.keras.__internal__.KerasTensor
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
kt = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access
core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access
keras_backend = tf._keras_internal.backend # pylint: disable=protected-access

else: # TF2.16 and onwards.
# ../__init__.py has already checked that tf.keras has version 2, not 3,
# which implies that tf.keras is tf_keras, and we do not second-guess
# the selection logic.
KerasTensor = tf.keras.__internal__.KerasTensor
RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor
from tf_keras.src import backend as keras_backend
from tf_keras.src.layers import core as core_layers
from tf_keras.src.engine import keras_tensor as kt

# pytype: enable=import-error

register_keras_tensor_specialization = kt.register_keras_tensor_specialization
delegate_property = core_layers._delegate_property # pylint: disable=protected-access
delegate_method = core_layers._delegate_method # pylint: disable=protected-access
unique_keras_object_name = keras_backend.unique_object_name

# Delete imports, in their order above.
del composite_tensor
del type_spec
del tf
del type_spec_registry
del keras_tensor
del keras_backend
del core_layers
del kt

0 comments on commit 022fc47

Please sign in to comment.