From ba4df3d1cdb893c290d134e4a17b205a06147a6c Mon Sep 17 00:00:00 2001 From: Divya S Date: Mon, 14 Aug 2023 14:14:20 -0700 Subject: [PATCH] Add DynamicEmbedding to Keras PiperOrigin-RevId: 556908566 --- keras/BUILD | 1 + ...callbacks.-update-embedding-callback.pbtxt | 99 +++++ .../v1/tensorflow.keras.callbacks.pbtxt | 4 + ...callbacks.-update-embedding-callback.pbtxt | 99 +++++ .../v2/tensorflow.keras.callbacks.pbtxt | 4 + keras/callbacks.py | 117 ++++++ keras/layers/__init__.py | 1 + keras/layers/experimental/BUILD | 101 +++++ keras/layers/experimental/__init__.py | 0 .../layers/experimental/dynamic_embedding.py | 253 ++++++++++++ .../dynamic_embedding_distributed_test.py | 82 ++++ .../experimental/dynamic_embedding_test.py | 304 ++++++++++++++ keras/layers/experimental/dynamic_lookup.py | 372 ++++++++++++++++++ .../dynamic_lookup_distributed_test.py | 67 ++++ .../experimental/dynamic_lookup_test.py | 253 ++++++++++++ 15 files changed, 1757 insertions(+) create mode 100644 keras/api/golden/v1/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt create mode 100644 keras/api/golden/v2/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt create mode 100644 keras/layers/experimental/BUILD create mode 100644 keras/layers/experimental/__init__.py create mode 100644 keras/layers/experimental/dynamic_embedding.py create mode 100644 keras/layers/experimental/dynamic_embedding_distributed_test.py create mode 100644 keras/layers/experimental/dynamic_embedding_test.py create mode 100644 keras/layers/experimental/dynamic_lookup.py create mode 100644 keras/layers/experimental/dynamic_lookup_distributed_test.py create mode 100644 keras/layers/experimental/dynamic_lookup_test.py diff --git a/keras/BUILD b/keras/BUILD index d31fcbc2b0e3..e9a787993c53 100644 --- a/keras/BUILD +++ b/keras/BUILD @@ -143,6 +143,7 @@ py_library( "//keras/protobuf:projector_config_proto_py_pb2", "//keras/utils:engine_utils", "//keras/utils:mode_keys", + "//keras/utils:timed_threads", ], ) diff --git a/keras/api/golden/v1/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt b/keras/api/golden/v1/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt new file mode 100644 index 000000000000..11e6da447004 --- /dev/null +++ b/keras/api/golden/v1/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt @@ -0,0 +1,99 @@ +path: "tensorflow.keras.callbacks.UpdateEmbeddingCallback" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dynamic_embedding_layer\', \'interval\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_alive" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_interval" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_predict_batch_begin" + argspec: "args=[\'self\', 
\'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_predict_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_predict_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_predict_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/keras/api/golden/v1/tensorflow.keras.callbacks.pbtxt b/keras/api/golden/v1/tensorflow.keras.callbacks.pbtxt index 1d92b38192a5..e393201db1fc 100644 --- a/keras/api/golden/v1/tensorflow.keras.callbacks.pbtxt +++ b/keras/api/golden/v1/tensorflow.keras.callbacks.pbtxt @@ -60,4 +60,8 @@ tf_module { name: "TerminateOnNaN" mtype: "" } + member { + name: "UpdateEmbeddingCallback" + mtype: "" + } } diff --git a/keras/api/golden/v2/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt b/keras/api/golden/v2/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt new file mode 100644 index 000000000000..11e6da447004 --- /dev/null +++ b/keras/api/golden/v2/tensorflow.keras.callbacks.-update-embedding-callback.pbtxt @@ -0,0 +1,99 @@ +path: "tensorflow.keras.callbacks.UpdateEmbeddingCallback" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dynamic_embedding_layer\', \'interval\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_alive" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', 
\'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_interval" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_predict_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_predict_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_predict_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_predict_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_test_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/keras/api/golden/v2/tensorflow.keras.callbacks.pbtxt b/keras/api/golden/v2/tensorflow.keras.callbacks.pbtxt index 6b162ce1e347..68202fa7c461 100644 --- a/keras/api/golden/v2/tensorflow.keras.callbacks.pbtxt +++ b/keras/api/golden/v2/tensorflow.keras.callbacks.pbtxt @@ -64,6 +64,10 @@ tf_module { name: "TerminateOnNaN" mtype: "" } + member { + name: "UpdateEmbeddingCallback" + mtype: "" + } member { name: "experimental" mtype: "" diff --git a/keras/callbacks.py b/keras/callbacks.py index bc5a3080512a..a2a79e05171c 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -40,6 +40,7 @@ from keras.utils.data_utils import Sequence from keras.utils.generic_utils import Progbar from keras.utils.mode_keys import ModeKeys +from keras.utils.timed_threads import TimedThread # isort: off from tensorflow.python.platform import 
tf_logging as logging @@ -3306,3 +3307,119 @@ def __init__( self.on_train_begin = on_train_begin if on_train_end is not None: self.on_train_end = on_train_end + + +@keras_export("keras.callbacks.experimental.UpdateEmbeddingCallback") +class UpdateEmbeddingCallback(TimedThread, Callback): + """A callback to update the DynamicEmbedding layer at specific time + interval. + + Updating the embedding matrix would mean that the optimizer variables will + be reset in this callback and this could have potential side effects. This + means that any existing slot variables associated with the optimizer will + likely be discarded when the optimizer is rebuilt. This affects optimizers + that rely on states of optimizer slot variables. + + Example: + ``` + # Generate dummy data + train_data = np.array([ + ['a', 'j', 'c', 'd', 'e'], + ['a', 'h', 'i', 'j', 'b'], + ['i', 'h', 'c', 'j', 'e'], + ]) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(['a', 'b', 'c', 'd', 'e']) + eviction_policy = 'LFU' + # Define the model + model = tf.keras.models.Sequential([ + DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + ), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(3, activation='softmax'), + ]) + + # Compile the model + model.compile( + optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy'], + ) + # update the vocabulary every 1 second + update_embedding_callback = UpdateEmbeddingCallback( + model.layers[0], interval=1 + ) + with update_embedding_callback: + result = model.fit( + train_data, + train_labels, + epochs=100, + batch_size=1, + callbacks=[update_embedding_callback], + ) + ``` + """ + + def __init__(self, dynamic_embedding_layer, interval): + """Initialize Timed Callback object. + + Args: + dynamic_embedding_layer: The dynamic embedding + layer to be updated. + interval: the interval, in seconds, to wait between calls to the + thread function. The thread function here updates the embeddings + matrix and resets the optimizer states. + """ + self._epoch = 0 + TimedThread.__init__(self, interval) + Callback.__init__(self) + self._dynamic_embedding_layer = dynamic_embedding_layer + self.strategy = tf.distribute.get_strategy() + + def on_interval(self): + try: + critical_section = tf.CriticalSection() + + # Using `tf.CriticalSection` when updating embeddings using timed + # thread can help ensure thread safety and prevent race conditions + # in the shared variables. + def execute_critical_section(): + critical_section.execute( + lambda: self._dynamic_embedding_layer.update_embeddings( # pylint: disable=g-long-lambda + self.strategy + ) + ) + + # update embeddings across all devices if distributed training is + # used + self.strategy.run(execute_critical_section) + # update optimizer variables across all devices if distributed + # training is used. + self.strategy.run( + lambda: self._reset_optimizer() + ) # pylint: disable=unnecessary-lambda + except AttributeError: + logging.info( + "Time interval specified to the UpdateEmbeddingCallback may be" + " too small, please try increasing the value of `interval`." + ) + + def _reset_optimizer(self): + """Resetting the optimizer variables. + + Resetting the optimizer variables is necessary after updating the + variable in the layer. This ensures that the optimizer is working with a + consistent internal state. This helps to prevent unexpected behavior and + can lead to more stable and faster training of the model. 
+ """ + for var in self.model.optimizer.variables(): + if "dynamic_embedding" in var.name: + backend.set_value(var, backend.zeros_like(var)) + + def on_epoch_begin(self, epoch, logs=None): + self._epoch = epoch diff --git a/keras/layers/__init__.py b/keras/layers/__init__.py index 6812e92aa4ec..a920d99cf669 100644 --- a/keras/layers/__init__.py +++ b/keras/layers/__init__.py @@ -73,6 +73,7 @@ from keras.layers.core.tf_op_layer import SlicingOpLambda from keras.layers.core.tf_op_layer import TFOpLambda + # Locally-connected layers. from keras.layers.locally_connected.locally_connected1d import ( LocallyConnected1D, diff --git a/keras/layers/experimental/BUILD b/keras/layers/experimental/BUILD new file mode 100644 index 000000000000..f3830107cefe --- /dev/null +++ b/keras/layers/experimental/BUILD @@ -0,0 +1,101 @@ +# Description: +# DynamicEmbeddingLayer allows for the continuous updating of the vocabulary and embeddings during +# the training process. + +# Placeholder: load unaliased py_library +load("@org_keras//keras:keras.bzl", "tf_py_test") +load("@org_keras//keras:keras.bzl", "distribute_py_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//keras:license"], + default_visibility = [ + "//keras:friends", + ], + licenses = ["notice"], +) + +py_library( + name = "dynamic_lookup", + srcs = ["dynamic_lookup.py"], + srcs_version = "PY3", + deps = [ + "//:expect_tensorflow_installed", + "//keras/layers", + "//third_party/tensorflow/python/util:core", + ], +) + +tf_py_test( + name = "dynamic_lookup_test", + size = "small", + srcs = ["dynamic_lookup_test.py"], + python_version = "PY3", + shard_count = 6, + deps = [ + ":dynamic_lookup", + "//:expect_numpy_installed", + "//:expect_tensorflow_installed", + "//keras", + "//keras/testing_infra:test_combinations", + "//testing/pybase", + ], +) + +py_library( + name = "dynamic_embedding", + srcs = ["dynamic_embedding.py"], + deps = [ + ":dynamic_lookup", + "//:expect_tensorflow_installed", + "//keras", + "//keras/layers", + "//keras/utils", + "//third_party/py/absl/logging", + "//third_party/tensorflow/python/util:core", + ], +) + +tf_py_test( + name = "dynamic_embedding_test", + srcs = ["dynamic_embedding_test.py"], + deps = [ + ":dynamic_embedding", + "//:expect_tensorflow_installed", + "//keras:callbacks", + "//keras/layers", + "//keras/models", + "//keras/testing_infra:test_combinations", + "//testing/pybase", + ], +) + +distribute_py_test( + name = "dynamic_embedding_distributed_test", + srcs = ["dynamic_embedding_distributed_test.py"], + tags = [ + "no_oss", + "no_windows", + "nomultivm", + ], + deps = [ + ":dynamic_embedding", + "//:expect_absl_installed", + "//:expect_numpy_installed", + "//:expect_tensorflow_installed", + "//keras", + "//keras:callbacks", + ], +) + +distribute_py_test( + name = "dynamic_lookup_distributed_test", + srcs = ["dynamic_lookup_distributed_test.py"], + tags = ["nomultivm"], + deps = [ + ":dynamic_lookup", + "//:expect_absl_installed", + "//:expect_numpy_installed", + "//:expect_tensorflow_installed", + "//keras", + ], +) diff --git a/keras/layers/experimental/__init__.py b/keras/layers/experimental/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/layers/experimental/dynamic_embedding.py b/keras/layers/experimental/dynamic_embedding.py new file mode 100644 index 000000000000..6813e36c2c78 --- /dev/null +++ b/keras/layers/experimental/dynamic_embedding.py @@ -0,0 +1,253 @@ +"""A layer that updates its vocab and embedding matrix during training.""" + +import 
tensorflow as tf +from absl import logging +from tensorflow.python.util.tf_export import keras_export + +import keras +from keras.layers import Layer +from keras.layers.experimental import dynamic_lookup +from keras.utils import warmstart_embedding_matrix + + +@keras_export("keras.layers.experimental.DynamicEmbedding") +class DynamicEmbedding(Layer): + """A layer that updates its vocab and embedding matrix during training. + + DynamicEmbedding allows for the continuous updating of the vocabulary + and embeddings during the training process. In traditional methods, the + vocabulary and mapping to the embedding vector are set at the beginning of + the training process and remain fixed throughout the training process. + However, in many real-world scenarios, the vocabulary and mapping to the + embeddings need to be updated to reflect the changing nature of the data. + + For instance, in natural language processing tasks, the vocabulary of + words in a corpus may change over time, and it's important to update the + embeddings to reflect these changes. Similarly, in recommendation systems, + the items in the vocabulary may change over time. + + A layer that supports dynamic embeddings addresses this issue by allowing + for the continuous updating of the vocabulary and embeddings during the + training process. and it also updates the embedding matrix to reflect the + new vocabulary. + + This layer maintains a hash table to track the most up-to-date vocabulary + based on the inputs received by the layer and the eviction policy. When + this layer is used with an `UpdateEmbeddingCallback`, which is a + time-based callback, the vocabulary lookup tensor is updated at the time + interval set in the `UpdateEmbeddingCallback` based on the most up-to-date + vocabulary hash table maintained by the layer. If this layer is not used + in conjunction with `UpdateEmbeddingCallback` the behavior of the layer + would be same as `keras.layers.Embedding`. + + Args: + input_dim: Size of the vocabulary in the input data. Expects an integer. + output_dim: The size of the embedding space. Expects an integer. + initial_vocabulary: The vocabulary to initialize the layer with. If a 1D + tensor is provided, the vocabulary will be initialized with that tensor. + If a `tf.DType` object is provided, a random tensor of that dtype and of + length `input_dim` will be generated as the initial vocabulary. + Supported `tf.DType` values include `tf.int32`, `tf.int64` and + `tf.string`. + eviction_policy: The eviction policy for the vocabulary. Available options + are "LFU" (Least Frequently Used) and *more to come*. Defaults to "LFU". + Expects a string. + input_length: Length of input sequences, when it is constant. This + argument is required if you are going to connect `Flatten` then `Dense` + layers upstream (without it, the shape of the dense outputs cannot be + computed).Expects an integer. + embedding_initializer: Initializer for embedding vectors for new input + vocab tokens to be added to the updated embedding matrix (see + keras.initializers). Defaults to "uniform". + num_oov_indices: Number of out of vocabulary token to use. Currently + supports 1. Expects an integer. + **kwargs: Additional keyword arguments for the parent class. + + Attributes: + embedding_layer: Embedding layer of DynamicEmbedding layer. + dynamic_lookup_layer: DynamicLookup layer of DynamicEmbedding layer. 
+ embedding_initializer: Initializer for embedding vectors for new input + vocab tokens to be added to the updated embedding matrix (see + keras.initializers). + num_oov_indices: Number of out of vocabulary token to use. + + Example: + ``` + # Generate dummy data + train_data = np.array([ + ['a', 'j', 'c', 'd', 'e'], + ['a', 'h', 'i', 'j', 'b'], + ['i', 'h', 'c', 'j', 'e'], + ]) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(['a', 'b', 'c', 'd', 'e']) + eviction_policy = 'LFU' + # Define the model + model = tf.keras.models.Sequential([ + DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + ), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(3, activation='softmax'), + ]) + + # Compile the model + model.compile( + optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy'], + ) + # update the vocabulary every 1 second + update_embedding_callback = UpdateEmbeddingCallback( + model.layers[0], interval=1 + ) + with update_embedding_callback: + result = model.fit( + train_data, + train_labels, + epochs=100, + batch_size=1, + callbacks=[update_embedding_callback], + ) + ``` + """ + + def __init__( + self, + input_dim, + output_dim, + initial_vocabulary, + eviction_policy="LFU", + input_length=None, + embedding_initializer="uniform", + num_oov_indices=1, + **kwargs, + ): + """Initialize DynamicEmbedding layer.""" + + super().__init__(**kwargs) + # assuming one oov bucket for now + self.embedding_layer = keras.layers.Embedding( + input_dim=input_dim + num_oov_indices, + output_dim=output_dim, + embeddings_initializer=embedding_initializer, + input_length=input_length, + **kwargs, + ) + self.dynamic_lookup_layer = dynamic_lookup.DynamicLookup( + vocabulary_size=input_dim, + eviction_policy=eviction_policy, + initial_vocabulary=initial_vocabulary, + **kwargs, + ) + self.embedding_initializer = embedding_initializer + self.num_oov_indices = num_oov_indices + + def build(self, input_shape=None): + self.embedding_layer.build(input_shape) + self.dynamic_lookup_layer.build(input_shape) + + def call(self, inputs, learn_vocab=True): + # get vocab to index mapped for dynamic_lookup_layer + output = self.dynamic_lookup_layer(inputs, learn_vocab=learn_vocab) + # pass the indices as inputs to embedding_layer + return self.embedding_layer(output) + + def get_config(self): + config = super().get_config() + config.update( + { + "input_dim": self.embedding_layer.input_dim, + "output_dim": self.embedding_layer.output_dim, + "input_length": self.embedding_layer.input_length, + "eviction_policy": self.dynamic_lookup_layer.eviction_policy, + "initial_vocabulary": ( + self.dynamic_lookup_layer.initial_vocabulary.numpy().tolist() + ), + "embedding_initializer": self.embedding_initializer, + "num_oov_indices": self.num_oov_indices, + } + ) + return config + + def get_vocabulary(self): + return self.dynamic_lookup_layer.get_vocabulary() + + def save_assets(self, dir_path): + initial_vocabulary = ( + self.dynamic_lookup_layer.initial_vocabulary.numpy().tolist() + ) + initial_vocabulary_filepath = tf.io.gfile.join( + dir_path, "initial_vocabulary.txt" + ) + with open(initial_vocabulary_filepath, "w") as f: + f.write("\n".join([str(w) for w in initial_vocabulary])) + + def update_embeddings(self, strategy): + """Update embedding matrix of dynamic embedding layer.""" + try: + if isinstance(strategy, tf.distribute.ParameterServerStrategy): + # if using PSS agrregate values + keys_list = ( + 
self.dynamic_lookup_layer.vocabulary_table_keys.read_all() + ) + values_list = ( + self.dynamic_lookup_layer.vocabulary_table_values.read_all() + ) + keys, values = self.aggregate_lookup_table( + keys_list, values_list + ) + else: + # if using on device strategy, just read values + keys, values = ( + self.dynamic_lookup_layer.vocabulary_table_keys, + self.dynamic_lookup_layer.vocabulary_table_values, + ) + old_vocab = self.dynamic_lookup_layer.vocabulary + new_vocab = self.get_top_vocabulary( + keys, + values, + self.dynamic_lookup_layer.vocabulary_size, + ) + # remap and update the embedding matrix + embedding_matrix = self.embedding_layer.embeddings + oov_token = tf.fill([self.num_oov_indices], "UNK") + updated_new_vocab = tf.concat([new_vocab, oov_token], axis=0) + embedding_matrix = warmstart_embedding_matrix( + base_vocabulary=list(old_vocab.numpy()), + new_vocabulary=updated_new_vocab, + base_embeddings=embedding_matrix, + new_embeddings_initializer=self.embedding_initializer, + ) + self.dynamic_lookup_layer.vocabulary.assign(new_vocab) + self.embedding_layer.embeddings.assign(embedding_matrix) + except AttributeError: + logging.info( + "Time interval specified by the UpdateEmbeddingCallback may be" + " too small, please try increasing the value of `interval`." + ) + + def aggregate_lookup_table(self, keys_list, values_list): + # Flatten the keys and values matrices + keys_1d = tf.reshape(keys_list, [-1]) + values_1d = tf.reshape(values_list, [-1]) + # Get unique keys and their corresponding summed values + unique_keys, idx, _ = tf.unique_with_counts(keys_1d) + summed_values = tf.math.unsorted_segment_sum( + values_1d, idx, tf.shape(unique_keys)[0] + ) + return unique_keys, summed_values + + def get_top_vocabulary(self, keys, values, k): + """Get Top vocabulary keys and values.""" + values_len = tf.shape(keys)[0] + if values_len > k: + _, indices = tf.nn.top_k(values, k=k) + else: + _, indices = tf.nn.top_k(values, k=values_len) + top_k_vocab = tf.gather(keys, indices) + return top_k_vocab diff --git a/keras/layers/experimental/dynamic_embedding_distributed_test.py b/keras/layers/experimental/dynamic_embedding_distributed_test.py new file mode 100644 index 000000000000..4fb530082d08 --- /dev/null +++ b/keras/layers/experimental/dynamic_embedding_distributed_test.py @@ -0,0 +1,82 @@ +"""Test DynamicEmbedding with Parameter server strategy.""" + +import numpy as np +import tensorflow.compat.v2 as tf +from absl.testing import parameterized + +import keras +from keras.callbacks import UpdateEmbeddingCallback +from keras.layers.experimental import dynamic_embedding + +ds_combinations = tf.__internal__.distribute.combinations + + +class DistributedDynamicEmbeddingTest(tf.test.TestCase, parameterized.TestCase): + @ds_combinations.generate( + tf.__internal__.test.combinations.combine( + strategy=[ + ds_combinations.parameter_server_strategy_3worker_2ps_cpu + ], + mode="eager", + ) + ) + def test_dynamic_embedding_with_pss(self, strategy): + # Generate dummy data + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + eviction_policy = "LFU" + with strategy.scope(): + # Define the model + model = keras.models.Sequential( + [ + dynamic_embedding.DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + ), + keras.layers.Flatten(), + keras.layers.Dense(3, 
activation="softmax"), + ] + ) + update_embedding_callback = UpdateEmbeddingCallback( + model.layers[0], + interval=1, + ) + with update_embedding_callback: + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + result = model.fit( + train_data, + train_labels, + epochs=100, + batch_size=1, + steps_per_epoch=2, + callbacks=[update_embedding_callback], + ) + # Assert model trains + self.assertEqual(result.history["loss"][0] > 0, True) + self.assertTrue( + tf.reduce_all( + tf.not_equal( + model.layers[0].dynamic_lookup_layer.vocabulary, + vocab, + ) + ) + ) + + +if __name__ == "__main__": + tf.__internal__.distribute.multi_process_runner.test_main() diff --git a/keras/layers/experimental/dynamic_embedding_test.py b/keras/layers/experimental/dynamic_embedding_test.py new file mode 100644 index 000000000000..eacafb3b21ad --- /dev/null +++ b/keras/layers/experimental/dynamic_embedding_test.py @@ -0,0 +1,304 @@ +"""Test DynamicEmbeddingLayer.""" + +import numpy as np +import tensorflow as tf + +from keras import layers +from keras import models +from keras.callbacks import UpdateEmbeddingCallback +from keras.layers.experimental import dynamic_embedding +from keras.testing_infra import test_combinations + + +class DynamicEmbeddingTest(test_combinations.TestCase): + def test_dynamic_embedding_layer(self): + input_ = np.array([["a", "j", "c", "d", "e"]]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + eviction_policy = "LFU" + # Define the layer + layer = dynamic_embedding.DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + ) + output = layer(input_) + self.assertTrue( + tf.reduce_all(tf.equal(tf.shape(output), tf.constant([1, 5, 2]))) + ) + self.assertTrue((layer.built)) + self.assertTrue((layer.dynamic_lookup_layer.built)) + self.assertTrue((layer.embedding_layer.built)) + + def test_model_save_load(self): + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + eviction_policy = "LFU" + # Define the model + model = models.Sequential( + [ + dynamic_embedding.DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + name="dynamic_embedding", + ), + layers.Flatten(), + layers.Dense(3, activation="softmax"), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + model.fit( + train_data, + train_labels, + epochs=10, + batch_size=1, + ) + # Save the model to a temporary file + filepath = self.create_tempdir() + model.save(filepath) + # Load the model from the temporary file + reloaded_model = models.load_model(filepath) + self.assertTrue( + tf.reduce_all( + tf.equal( + model.get_layer( + "dynamic_embedding" + ).dynamic_lookup_layer.vocabulary.numpy(), + reloaded_model.get_layer( + "dynamic_embedding" + ).dynamic_lookup_layer.vocabulary.numpy(), + ) + ) + ) + + def test_dynamic_embedding_layer_with_callback(self): + # Generate dummy data + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + eviction_policy = "LFU" + # Define the model + model = models.Sequential( + [ + 
dynamic_embedding.DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + ), + layers.Flatten(), + layers.Dense(3, activation="softmax"), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + update_embedding_callback = UpdateEmbeddingCallback( + model.layers[0], + interval=2, + ) + with update_embedding_callback: + result = model.fit( + train_data, + train_labels, + epochs=100, + batch_size=1, + callbacks=[update_embedding_callback], + ) + # Assert model trains + self.assertEqual(result.history["loss"][0] > 0, True) + # assert vocab is updated in DynamicLookup + self.assertTrue( + tf.reduce_all( + tf.not_equal( + model.layers[0].dynamic_lookup_layer.vocabulary, vocab + ) + ) + ) + # assert embedding matrix size + self.assertTrue( + tf.reduce_all( + tf.equal( + tf.shape(model.layers[0].embedding_layer.embeddings), + tf.constant([6, 2]), + ) + ) + ) + + def test_embedding_matrix_update(self): + # Generate dummy data + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + eviction_policy = "LFU" + # Define the model + model = models.Sequential( + [ + dynamic_embedding.DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + ), + layers.Flatten(), + layers.Dense(3, activation="softmax"), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + # freeze training of all layers + for layer in model.layers: + layer.trainable = False + # define update_embedding_callback to update embedding matrix and + # vocabulary + update_embedding_callback = UpdateEmbeddingCallback( + model.layers[0], + interval=5, + ) + embedding_matrix_before = model.layers[0].embedding_layer.get_weights() + with update_embedding_callback: + model.fit( + train_data, + train_labels, + epochs=100, + batch_size=1, + callbacks=[update_embedding_callback], + ) + # assert the UpdateEmbeddingCallback did modify the embedding matrix + self.assertNotEqual( + model.layers[0].embedding_layer.get_weights(), + embedding_matrix_before, + ) + + def test_get_vocabulary(self): + # Generate dummy data + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + eviction_policy = "LFU" + # Define the model + model = models.Sequential( + [ + dynamic_embedding.DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=vocab, + ), + layers.Flatten(), + layers.Dense(3, activation="softmax"), + ] + ) + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + model.fit( + train_data, + train_labels, + epochs=100, + batch_size=1, + ) + vocabulary_output = model.layers[0].get_vocabulary() + self.assertTrue( + tf.reduce_all( + tf.equal( + vocabulary_output, + vocab, + ) + ) + ) + + def test_default_initial_vocabulary(self): + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + eviction_policy = "LFU" + # Define the 
model + model = models.Sequential( + [ + dynamic_embedding.DynamicEmbedding( + input_dim=5, + output_dim=2, + input_length=5, + eviction_policy=eviction_policy, + initial_vocabulary=tf.string, + name="dynamic_embedding", + ), + layers.Flatten(), + layers.Dense(3, activation="softmax"), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + model.fit( + train_data, + train_labels, + epochs=10, + batch_size=1, + ) + vocabulary_output = model.layers[0].get_vocabulary() + self.assertEqual(vocabulary_output.dtype, tf.string) + self.assertEqual(tf.shape(vocabulary_output)[0], 5) + + +if __name__ == "__main__": + tf.test.main() diff --git a/keras/layers/experimental/dynamic_lookup.py b/keras/layers/experimental/dynamic_lookup.py new file mode 100644 index 000000000000..83a19feb934f --- /dev/null +++ b/keras/layers/experimental/dynamic_lookup.py @@ -0,0 +1,372 @@ +"""Builds a vocabulary from inputs to the layer.""" + +import random +import string + +import tensorflow as tf +from tensorflow.python.util.tf_export import keras_export + +from keras.layers import Layer + + +@keras_export("keras.layers.experimental.DynamicLookup") +class DynamicLookup(Layer): + """A layer that builds a vocabulary from inputs. + + This layer maintains a vocabulary that is continuously updated based on the + inputs passed in every forward pass. The frequency of the input is tracked + and used to maintain a vocabulary. The very last index will be treated as + the index. If `vocabulary_size=10`, OOV index will be 9. + + Args: + vocabulary_size: Integer value representing size of the vocabulary to + build. + initial_vocabulary: The vocabulary to initialize the layer with. If a 1D + tensor is provided, the vocabulary will be initialized with that tensor. + If a `tf.DType` object is provided, a random tensor of that dtype and of + length `vocabulary_size` will be generated as the initial vocabulary. + Supported `tf.DType` values include `tf.int32`, `tf.int64` and + `tf.string`. + eviction_policy: The eviction policy for the vocabulary. Available options + are string values like "LFU" (Least Frequently Used) and *more to come*. + If not specified, the default eviction policy is "LFU". Expects a + string. + **kwargs: Arguments for super class. + + Attributes: get_vocabulary(): Returns a tensor representing the current + vocabulary of the layer. If you want to look up the vocabulary keys given + a set of indices, you can simply use `tf.gather(vocabulary, indices)`. 
+ + Example: + Here is an example to demonstrate how to use the DynamicLookup layer + ``` + vocabulary_size = 3 + eviction_policy = "LFU" + vocab = tf.constant(["apple", "banana", "cherry"]) + layer = DynamicLookup( + vocabulary_size, + vocab, + eviction_policy=eviction_policy, + ) + inputs = tf.constant([ + ["apple", "banana"], + ]) + + outputs = layer(inputs) + tf.print(outputs) + + # you get the following output + [[0 1]] + + # get top k vocab + top_k_vocab = layer.get_top_vocabulary(2) + tf.print(top_k_vocab) + + # you get the following output + ["apple", "banana"] + ``` + If you want to checkpoint the vocabulary or vocabulary frequency, see + the following example + + ``` + checkpoint = + tf.train.Checkpoint(vocabulary=self.vocabulary) + checkpoint.write(filepath) + ``` + """ + + def __init__( + self, + vocabulary_size, + initial_vocabulary, + eviction_policy="LFU", + **kwargs, + ): + """Initializes the DynamicLookup layer.""" + + super().__init__(**kwargs) + self.vocabulary_size = vocabulary_size + self.eviction_policy = eviction_policy + if tf.is_tensor(initial_vocabulary): + self.initial_vocabulary = initial_vocabulary + elif initial_vocabulary in ( + tf.string, + tf.int32, + tf.int64, + ): + self.initial_vocabulary = ( + DynamicLookup._generate_random_initial_vocab( + vocabulary_size, initial_vocabulary + ) + ) + else: + raise ValueError( + "Either specify the initial vocabulary or provide a " + "valid dtype. The dtype argument must be one of the " + "following: tf.string, tf.int32, tf.int64." + ) + # maintain a 20% bigger hash table + self.internal_table_size = tf.cast( + tf.floor( + tf.multiply( + tf.cast(self.vocabulary_size, dtype=tf.float32), 1.2 + ) + ), + dtype=tf.int32, + ) + self.vocabulary_dtype = self.initial_vocabulary.dtype + if self.eviction_policy == "LFU": + self.vocabulary_table_keys = tf.Variable( + initial_value=pad_tensor( + self.initial_vocabulary, self.internal_table_size + ), + shape=tf.TensorShape(self.internal_table_size), + dtype=self.vocabulary_dtype, + trainable=False, + name="vocabulary_table_keys", + per_worker_variable=True, + ) + self.vocabulary_table_values = tf.Variable( + initial_value=pad_tensor( + tf.zeros_like(self.initial_vocabulary, dtype=tf.int32), + self.internal_table_size, + ), + shape=tf.TensorShape(self.internal_table_size), + dtype=tf.int32, + trainable=False, + name="vocabulary_table_values", + per_worker_variable=True, + ) + else: + raise ValueError( + "{} eviction policy is currently unsupported by DynamicLookup" + " layer." + " It currently only supports `LFU`".format(self.eviction_policy) + ) + + # TODO(b/268243335): add more eviction policy + # TODO(b/268243996): provide multiple OOV + + def build(self, input_shape=None): + self.vocabulary = self.add_weight( + shape=self.initial_vocabulary.shape, + dtype=self.vocabulary_dtype, + initializer=tf.constant_initializer( + self.initial_vocabulary.numpy() + ), + trainable=False, + name="vocabulary", + ) + super().build(input_shape) + + def call(self, inputs, learn_vocab=True): + """Learn vocabulary from inputs and perform vocabulary lookup. + + Args: + inputs: Input tensor, or dict/list/tuple of input tensors. + learn_vocab: A boolean value that specifies whether the vocabulary + should be learned from the layer inputs or not. Defaults to True. + + Returns: + A tensor or list/tuple of tensors. 
+ """ + flattened_inputs = tf.reshape(inputs, [-1]) + # get unique values from inputs + unique, _ = tf.unique(flattened_inputs) + unique = tf.cast(unique, dtype=self.vocabulary_dtype) + # learn vocab form inputs + if learn_vocab and self.eviction_policy == "LFU": + self.update_internal_vocabulary(unique) + + # lookup for inputs in self.vocabulary + top_k_vocab = self.vocabulary + lookup_values = tf.expand_dims(flattened_inputs, axis=-1) + condition = tf.reduce_any( + tf.equal(top_k_vocab, tf.expand_dims(lookup_values, -1)), axis=-1 + ) + # the very last index will be the OOV index + indices = tf.where( + condition, + tf.argmax( + tf.equal(top_k_vocab, tf.expand_dims(lookup_values, -1)), + axis=-1, + ), + self.vocabulary_size, + ) + # reshape output to the same shape as input + out = tf.reshape(tf.squeeze(indices), tf.shape(inputs)) + return out + + def update_internal_vocabulary(self, unique): + # get new keys + unpadded_keys = remove_padding(self.vocabulary_table_keys) + unpadded_values = remove_padding(self.vocabulary_table_values) + table_expanded = tf.expand_dims(unpadded_keys, axis=0) + unique_expanded = tf.expand_dims(unique, axis=0) + new_keys = tf.sets.difference( + unique_expanded, table_expanded, aminusb=True + ) + number_of_new_keys = tf.shape(new_keys.values)[0] + # get number of keys to be removed from vocab_frequency + number_of_keys_to_remove = ( + tf.shape(unpadded_keys)[0] + - self.internal_table_size + + number_of_new_keys + ) + number_of_keys_to_remove = tf.cast(number_of_keys_to_remove, tf.int32) + number_of_keys_to_remove = tf.maximum(number_of_keys_to_remove, 0) + # remove old keys + updated_keys, updated_values = self._remove_old_keys( + unpadded_keys, + unpadded_values, + number_of_keys_to_remove, + ) + # add new keys + self._add_new_keys( + updated_keys, + updated_values, + unique, + new_keys, + ) + return unique + + def _remove_old_keys(self, unpadded_keys, unpadded_values, n): + """remove old keys.""" + updated_keys, updated_values = None, None + if self.eviction_policy == "LFU": + # LFU eviction + # negate the values of counts to find the lower n keys to remove + negative_count = tf.math.negative(unpadded_values) + # get index of lower n counts + _, lower_n_index = tf.nn.top_k(negative_count, k=n) + # gather keys that needs to be removed + keys_to_remove = tf.gather(unpadded_keys, lower_n_index) + # get masks for keys not present in inputs + mask = tf.reduce_all( + unpadded_keys[:, tf.newaxis] != keys_to_remove, axis=1 + ) + # updated keys and values with least frequent keys removed + updated_keys = tf.boolean_mask( + unpadded_keys, + mask, + ) + updated_values = tf.boolean_mask( + unpadded_values, + mask, + ) + return updated_keys, updated_values + + def _add_new_keys(self, updated_keys, updated_values, unique, new_keys): + """Add new keys and update internal vocabulary table.""" + if self.eviction_policy == "LFU": + # increment values of old keys when present in current inputs + matches = tf.where( + tf.equal(tf.expand_dims(updated_keys, axis=1), unique) + )[:, 0] + updates = tf.ones_like(matches, dtype=tf.int32) + matches = tf.expand_dims(matches, axis=-1) + values_2 = tf.tensor_scatter_nd_add( + updated_values, matches, updates + ) + # add new keys and corresponding values = 1 + values_difference = tf.ones_like(new_keys.values, dtype=tf.int32) + # concatenate old keys and new keys and pad + updated_keys = pad_tensor( + tf.concat([updated_keys, new_keys.values], axis=0), + self.internal_table_size, + ) + self.vocabulary_table_keys.assign(updated_keys) + # 
concatenate updated old values and new values and pad + updated_values = pad_tensor( + tf.concat([values_2, values_difference], axis=0), + self.internal_table_size, + ) + self.vocabulary_table_values.assign(updated_values) + return unique + + def get_top_vocabulary(self, k): + """Get top k vocabulary keys.""" + top_k_vocab = None + if self.eviction_policy == "LFU": + values_len = tf.shape(self.vocabulary_table_keys)[0] + if values_len > k: + _, indices = tf.nn.top_k(self.vocabulary_table_values, k=k) + else: + _, indices = tf.nn.top_k( + self.vocabulary_table_values, k=values_len + ) + top_k_vocab = tf.gather(self.vocabulary_table_keys, indices) + return top_k_vocab + + def get_vocabulary(self): + return self.vocabulary + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "initial_vocabulary": self.initial_vocabulary.numpy().tolist(), + "eviction_policy": self.eviction_policy, + } + ) + return config + + def save_assets(self, dir_path): + vocabulary = self.vocabulary.numpy().tolist() + vocabulary_filepath = tf.io.gfile.join(dir_path, "vocabulary.txt") + with open(vocabulary_filepath, "w") as f: + f.write("\n".join([str(w) for w in vocabulary])) + + def _generate_random_initial_vocab( + vocabulary_size, dtype + ): # pylint: disable=no-self-argument + if dtype == tf.string: + chars = string.ascii_letters + random_strings = [ + "".join([random.choice(chars) for _ in range(10)]) + for _ in range(vocabulary_size) + ] + random_vocab = tf.constant(random_strings, dtype=tf.string) + elif dtype == tf.int32: + random_vocab = tf.random.uniform( + shape=[vocabulary_size], + minval=0, + maxval=vocabulary_size, + dtype=tf.int32, + ) + elif dtype == tf.int64: + random_vocab = tf.random.uniform( + shape=[vocabulary_size], + minval=0, + maxval=vocabulary_size, + dtype=tf.int64, + ) + else: + raise ValueError( + "Supported dtype for initial vocabulary include `tf.int32`," + " `tf.int64`, or `tf.string`. 
But got dtype = {}".format(dtype) + ) + + return random_vocab + + +def pad_tensor(tensor, n): + """Pad a tensor to a fixed length.""" + if tensor.dtype == tf.string: + padding = "unk" + else: + padding = -1 + pad_length = tf.maximum(n - tf.shape(tensor)[0], 0) + padded_tensor = tf.pad(tensor, [[0, pad_length]], constant_values=padding) + return padded_tensor[:n] + + +def remove_padding(tensor): + """Remove padding from a tensor.""" + if tensor.dtype == tf.string: + padding = "unk" + else: + padding = -1 + mask = tf.reshape(tensor != padding, shape=[-1]) + return tf.boolean_mask(tensor, mask) diff --git a/keras/layers/experimental/dynamic_lookup_distributed_test.py b/keras/layers/experimental/dynamic_lookup_distributed_test.py new file mode 100644 index 000000000000..13e7256ea1de --- /dev/null +++ b/keras/layers/experimental/dynamic_lookup_distributed_test.py @@ -0,0 +1,67 @@ +"""Test DynamicEmbedding with Parameter server strategy.""" + +import numpy as np +import tensorflow.compat.v2 as tf +from absl.testing import parameterized + +import keras +from keras.layers.experimental import dynamic_lookup + +ds_combinations = tf.__internal__.distribute.combinations + + +class DistributedDynamiclookupTest(tf.test.TestCase, parameterized.TestCase): + @ds_combinations.generate( + tf.__internal__.test.combinations.combine( + strategy=[ + ds_combinations.parameter_server_strategy_3worker_2ps_cpu + ], + mode="eager", + ) + ) + def test_dynamic_lookup_with_pss(self, strategy): + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + vocabulary_size = 5 + eviction_policy = "LFU" + with strategy.scope(): + # Define the model + model = keras.models.Sequential( + [ + dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=vocab, + eviction_policy=eviction_policy, + name="dynamic_lookup", + ), + keras.layers.Flatten(), + keras.layers.Dense(3, activation="softmax"), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + result = model.fit( + train_data, + train_labels, + epochs=10, + batch_size=1, + steps_per_epoch=1, + ) + # Assert model trains + self.assertEqual(result.history["loss"][0] > 0, True) + + +if __name__ == "__main__": + tf.__internal__.distribute.multi_process_runner.test_main() diff --git a/keras/layers/experimental/dynamic_lookup_test.py b/keras/layers/experimental/dynamic_lookup_test.py new file mode 100644 index 000000000000..f49626918eb9 --- /dev/null +++ b/keras/layers/experimental/dynamic_lookup_test.py @@ -0,0 +1,253 @@ +"""Test for dynamic_lookup layer.""" + +import os +import shutil +import tempfile + +import numpy as np +import tensorflow as tf + +import keras +from keras.layers.experimental import dynamic_lookup +from keras.testing_infra import test_combinations + + +class DynamicLookupTest(test_combinations.TestCase): + def test_dynamic_lookup_layer(self): + vocabulary_size = 5 + eviction_policy = "LFU" + vocab = tf.constant(["apple", "banana", "cherry", "grape", "juice"]) + # vocab_frequency({apple:0, banana:0, cherry:0, grape:0, juice:0}) + # hash table size is 1.2Xvocab size. 
in this case 5x1.2 = 6 + layer = dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=vocab, + eviction_policy=eviction_policy, + ) + + input_1 = tf.constant(["apple", "banana", "cherry"]) + layer(input_1) + # vocab_frequency({apple:1, banana:1, cherry:1, grape:0, juice:0}) + input_2 = tf.constant(["apple", "banana", "mango"]) + layer(input_2) + # vocab_frequency({apple:2, banana:2, cherry:1, grape:0, juice:0, mango: + # 1}) + input_3 = tf.constant(["fig", "date", "date"]) + layer(input_3) + # vocab_frequency({apple:2, banana:2, cherry:1, fig:1, date:1, mango:1}) + input_4 = tf.constant(["banana", "jackfruit", "honeydew"]) + layer(input_4) + # vocab_frequency({apple:2, banana:3, jackfruit:1, fig:1, date:1, + # honeydew:1}) + input_5 = tf.constant(["banana", "apple", "jackfruit"]) + # vocab_frequency({apple:3, banana:4, jackfruit:2, fig:1, date:1, + # honeydew:1}) + outputs = layer(input_5) + expected_output = tf.constant([1, 0, 5], dtype=tf.int64) + # verify if look up values are accurate + self.assertTrue(tf.reduce_all(tf.equal(outputs, expected_output))) + # Check the shape of the output + self.assertEqual(outputs.shape, input_4.shape) + + # Check that the top-k vocab is correctly updated + top_k_vocab = layer.get_top_vocabulary(3) + expected_top_k_vocab = tf.constant( + ["banana", "apple", "jackfruit"], + dtype=tf.string, + ) + self.assertTrue( + tf.reduce_all(tf.equal(top_k_vocab, expected_top_k_vocab)) + ) + + def test_layer_with_model(self): + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + vocabulary_size = 5 + eviction_policy = "LFU" + + # Define the model + model = keras.models.Sequential( + [ + dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=vocab, + eviction_policy=eviction_policy, + ), + keras.layers.Flatten(), + keras.layers.Dense(3, activation="softmax"), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + result = model.fit( + train_data, + train_labels, + epochs=10, + batch_size=1, + ) + # Assert model trains + self.assertEqual(result.history["loss"][0] > 0, True) + + def test_model_save_load(self): + train_data = np.array( + [ + ["a", "j", "c", "d", "e"], + ["a", "h", "i", "j", "b"], + ["i", "h", "c", "j", "e"], + ] + ) + train_labels = np.array([0, 1, 2]) + vocab = tf.constant(["a", "b", "c", "d", "e"]) + vocabulary_size = 5 + eviction_policy = "LFU" + + # Define the model + model = keras.models.Sequential( + [ + dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=vocab, + eviction_policy=eviction_policy, + name="dynamic_lookup", + ), + keras.layers.Flatten(), + keras.layers.Dense(3, activation="softmax"), + ] + ) + + # Compile the model + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + model.fit( + train_data, + train_labels, + epochs=10, + batch_size=1, + ) + # Save the model to a temporary file + filepath = os.path.join(tempfile.gettempdir(), "tempdir") + model.save(filepath) + reloaded_model = keras.models.load_model(filepath) + self.assertTrue( + tf.reduce_all( + tf.equal( + model.get_layer("dynamic_lookup").vocabulary.numpy(), + reloaded_model.get_layer( + "dynamic_lookup" + ).vocabulary.numpy(), + ) + ) + ) + shutil.rmtree(filepath) + + def test_dynamic_lookup_layer_learn_vocab_arg(self): + 
vocabulary_size = 5 + eviction_policy = "LFU" + vocab = tf.constant(["apple", "banana", "cherry", "grape", "juice"]) + # vocab_frequency({apple:0, banana:0, cherry:0, grape:0, juice:0}) + # hash table size is 1.2Xvocab size. in this case 5x1.2 = 6 + layer = dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=vocab, + eviction_policy=eviction_policy, + ) + + input_1 = tf.constant(["apple", "banana", "cherry"]) + layer(input_1, learn_vocab=False) + input_2 = tf.constant(["apple", "banana", "mango"]) + layer(input_2, learn_vocab=False) + input_3 = tf.constant(["fig", "date", "date"]) + layer(input_3, learn_vocab=False) + input_4 = tf.constant(["banana", "jackfruit", "honeydew"]) + layer(input_4, learn_vocab=False) + input_5 = tf.constant(["banana", "apple", "jackfruit"]) + layer(input_5, learn_vocab=False) + # Check that the top-k vocab is not updated + top_k_vocab = layer.get_top_vocabulary(5) + expected_top_k_vocab = tf.constant( + ["apple", "banana", "cherry", "grape", "juice"], + dtype=tf.string, + ) + self.assertTrue( + tf.reduce_all(tf.equal(top_k_vocab, expected_top_k_vocab)) + ) + + def test_get_vocabulary(self): + vocabulary_size = 5 + eviction_policy = "LFU" + vocab = tf.constant(["apple", "banana", "cherry", "grape", "juice"]) + layer = dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=vocab, + eviction_policy=eviction_policy, + ) + input_1 = tf.constant(["apple", "banana", "cherry"]) + layer(input_1, learn_vocab=False) + vocabulary_output = layer.get_vocabulary() + self.assertTrue(tf.reduce_all(tf.equal(vocabulary_output, vocab))) + + def test_default_vocab(self): + # test default initial vocabulary tf.string + vocabulary_size = 5 + eviction_policy = "LFU" + layer1 = dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=tf.string, + eviction_policy=eviction_policy, + ) + input_1 = tf.constant(["apple", "banana", "cherry"]) + layer1(input_1, learn_vocab=False) + vocabulary_output = layer1.get_vocabulary() + self.assertEqual(vocabulary_output.dtype, tf.string) + self.assertEqual(tf.shape(vocabulary_output)[0], vocabulary_size) + + # test default initial vocabulary tf.int32 + layer2 = dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=tf.int32, + eviction_policy=eviction_policy, + ) + input_2 = tf.constant([1, 2, 3], dtype=tf.int32) + layer2(input_2, learn_vocab=False) + vocabulary_output = layer2.get_vocabulary() + self.assertEqual(vocabulary_output.dtype, tf.int32) + self.assertEqual(tf.shape(vocabulary_output)[0], vocabulary_size) + + # test default initial vocabulary tf.int64 + layer3 = dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=tf.int64, + eviction_policy=eviction_policy, + ) + input_3 = tf.constant([1, 2, 3], dtype=tf.int64) + layer3(input_3, learn_vocab=False) + vocabulary_output = layer3.get_vocabulary() + self.assertEqual(vocabulary_output.dtype, tf.int64) + self.assertEqual(tf.shape(vocabulary_output)[0], vocabulary_size) + + # test value error when default initial vocabulary is tf.float32 + with self.assertRaises(ValueError): + layer3 = dynamic_lookup.DynamicLookup( + vocabulary_size, + initial_vocabulary=tf.float32, + eviction_policy=eviction_policy, + ) + + +if __name__ == "__main__": + tf.test.main()
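
A minimal standalone sketch (not part of the patch) of the LFU bookkeeping that the new DynamicLookup layer performs: count token frequencies across batches, keep the top-k most frequent tokens as the vocabulary, and send everything else to a reserved OOV index. All names below (observe, top_k_vocabulary, lookup, oov_index) are illustrative and do not appear in the patch; the real layer additionally admits unseen tokens into its internal hash table and evicts the least frequently used ones, which this sketch omits.

```
import tensorflow as tf

# Illustrative names only; none of these identifiers come from the patch.
vocab_size = 3
oov_index = vocab_size  # the embedding table would be sized vocab_size + num_oov_indices

# Frequency table accumulated across batches (the LFU bookkeeping).
keys = tf.constant(["apple", "banana", "cherry", "grape"])
counts = tf.Variable(tf.zeros([4], dtype=tf.int32), trainable=False)


def observe(batch):
    # Increment the count of every tracked token that appears in this batch.
    flat = tf.reshape(batch, [-1])
    hits = tf.cast(tf.equal(keys[:, tf.newaxis], flat), tf.int32)  # [keys, tokens]
    counts.assign_add(tf.reduce_sum(hits, axis=1))


def top_k_vocabulary(k):
    # Keep the k most frequently seen tokens (analogous to get_top_vocabulary).
    _, idx = tf.nn.top_k(counts, k=k)
    return tf.gather(keys, idx)


def lookup(batch, vocabulary):
    # Map tokens to their index in `vocabulary`; unseen tokens get the OOV index.
    flat = tf.reshape(batch, [-1])
    match = tf.equal(vocabulary[tf.newaxis, :], flat[:, tf.newaxis])  # [tokens, vocab]
    in_vocab = tf.reduce_any(match, axis=1)
    idx = tf.where(
        in_vocab,
        tf.argmax(tf.cast(match, tf.int32), axis=1),
        tf.cast(oov_index, tf.int64),
    )
    return tf.reshape(idx, tf.shape(batch))


observe(tf.constant([["apple", "banana"], ["apple", "cherry"]]))
vocab = top_k_vocabulary(vocab_size)                      # ["apple", "banana", "cherry"]
print(lookup(tf.constant([["apple", "grape"]]), vocab))   # [[0 3]] -- "grape" maps to OOV
```

The layer in the patch keeps this state in fixed-size, padded key/count tensors (see pad_tensor and remove_padding) rather than growing Python-side structures, which is what lets the tables be created as per-worker variables and aggregated with aggregate_lookup_table under ParameterServerStrategy.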