Skip to content

Commit

Permalink
Add DynamicEmbedding to Keras
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 556908566
  • Loading branch information
divyashreepathihalli authored and tensorflower-gardener committed Aug 16, 2023
1 parent 29b1384 commit 9702bdc
Show file tree
Hide file tree
Showing 11 changed files with 1,429 additions and 3 deletions.
1 change: 1 addition & 0 deletions keras/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ py_library(
"//keras/protobuf:projector_config_proto_py_pb2",
"//keras/utils:engine_utils",
"//keras/utils:mode_keys",
"//keras/utils:timed_threads",
],
)

Expand Down
70 changes: 67 additions & 3 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
import sys
import time

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend
from keras.distribute import distributed_file_utils
from keras.distribute import worker_training_state
Expand All @@ -40,6 +37,9 @@
from keras.utils.data_utils import Sequence
from keras.utils.generic_utils import Progbar
from keras.utils.mode_keys import ModeKeys
import numpy as np
import tensorflow.compat.v2 as tf
from keras.utils.timed_threads import TimedThread

# isort: off
from tensorflow.python.platform import tf_logging as logging
Expand Down Expand Up @@ -3306,3 +3306,67 @@ def __init__(
self.on_train_begin = on_train_begin
if on_train_end is not None:
self.on_train_end = on_train_end


@keras_export("keras.callbacks.UpdateEmbeddingCallback")
class UpdateEmbeddingCallback(TimedThread, Callback):
    """A callback to update a `DynamicEmbedding` layer at a given time interval.

    Runs the layer's `update_embeddings` on a background timed thread every
    `interval` seconds. Updating the embedding matrix resets the optimizer
    variables associated with the layer, which can have side effects: any
    existing slot variables tied to those weights are effectively discarded
    when the optimizer state is rebuilt, so optimizers that rely on slot
    variable state are affected.
    """

    def __init__(self, dynamic_embedding_layer, interval):
        """Initializes the timed callback.

        Args:
          dynamic_embedding_layer: The `DynamicEmbedding` layer to be updated.
          interval: The interval, in seconds, to wait between calls to the
            thread function.
        """
        self._epoch = 0
        # Initialize both bases explicitly; cooperative `super()` is not used
        # because the two bases have incompatible `__init__` signatures. Call
        # the imported base classes directly (not the `tf.keras.*` aliases)
        # so we initialize exactly the classes this class inherits from.
        TimedThread.__init__(self, interval)
        Callback.__init__(self)
        self._dynamic_embedding_layer = dynamic_embedding_layer
        # Capture the strategy active at construction time so the timed
        # thread updates variables on the same devices as training.
        self.strategy = tf.distribute.get_strategy()

    def on_interval(self):
        """Timed-thread body: updates embeddings and resets optimizer state."""
        try:
            # `tf.CriticalSection` serializes the embedding update so the
            # timed thread cannot race the training loop on the shared
            # variables.
            critical_section = tf.CriticalSection()

            def execute_critical_section():
                critical_section.execute(
                    lambda: self._dynamic_embedding_layer.update_embeddings(  # pylint: disable=g-long-lambda
                        self.strategy
                    )
                )

            # Update embeddings across all devices if distributed training is
            # used.
            self.strategy.run(execute_critical_section)
            # Update optimizer variables across all devices if distributed
            # training is used. Pass the bound method directly; no lambda is
            # needed since `strategy.run` calls it with no arguments.
            self.strategy.run(self._reset_optimizer)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Best-effort by design: once training is complete the optimizer
            # no longer has iterations, so failures here are logged rather
            # than raised into the timed thread.
            logging.info("Exception raised in UpdateEmbeddingCallback %s", e)

    def _reset_optimizer(self):
        """Zeroes optimizer variables tied to the dynamic embedding.

        Resetting the optimizer variables is necessary after updating the
        variable in the layer. This ensures that the optimizer is working
        with a consistent internal state, preventing unexpected behavior and
        leading to more stable training of the model.
        """
        for var in self.model.optimizer.variables():
            # Only variables belonging to the dynamic embedding layer are
            # reset; all other optimizer state is left untouched.
            if "dynamic_embedding" in var.name:
                backend.set_value(var, backend.zeros_like(var))

    def on_epoch_begin(self, epoch, logs=None):
        """Records the current epoch index for use by the timed thread."""
        self._epoch = epoch
1 change: 1 addition & 0 deletions keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from keras.layers.core.tf_op_layer import SlicingOpLambda
from keras.layers.core.tf_op_layer import TFOpLambda


# Locally-connected layers.
from keras.layers.locally_connected.locally_connected1d import (
LocallyConnected1D,
Expand Down
98 changes: 98 additions & 0 deletions keras/layers/dynamic_embedding/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Description:
# DynamicEmbeddingLayer allows for the continuous updating of the vocabulary and embeddings during
# the training process.

# Placeholder: load unaliased py_library
load("@org_keras//keras:keras.bzl", "tf_py_test")
load("@org_keras//keras:keras.bzl", "distribute_py_test")

package(
default_applicable_licenses = ["//keras:license"],
default_visibility = [
"//visibility:public",
],
licenses = ["notice"],
)

# Library target for the DynamicLookup layer implementation.
py_library(
name = "dynamic_lookup",
srcs = ["dynamic_lookup.py"],
srcs_version = "PY3",
deps = [
"//:expect_tensorflow_installed",
"//keras/layers",
"//third_party/tensorflow/python/util:core",
],
)

# Unit tests for DynamicLookup; sharded to keep individual shards small.
tf_py_test(
name = "dynamic_lookup_test",
size = "small",
srcs = ["dynamic_lookup_test.py"],
python_version = "PY3",
shard_count = 6,
deps = [
":dynamic_lookup",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras",
"//testing/pybase",
],
)

# Library target for the DynamicEmbedding layer; builds on dynamic_lookup.
py_library(
name = "dynamic_embedding",
srcs = ["dynamic_embedding.py"],
deps = [
":dynamic_lookup",
"//:expect_tensorflow_installed",
"//keras/layers",
"//keras/utils",
"//third_party/py/absl/logging",
"//third_party/tensorflow/python/util:core",
],
)

# Unit tests for DynamicEmbedding.
tf_py_test(
name = "dynamic_embedding_test",
srcs = ["dynamic_embedding_test.py"],
deps = [
":dynamic_embedding",
"//:expect_tensorflow_installed",
"//keras:callbacks",
"//keras/layers",
"//keras/models",
"//testing/pybase",
],
)

# Distributed-strategy tests for DynamicEmbedding; disabled on OSS and
# Windows (see tags).
distribute_py_test(
name = "dynamic_embedding_distributed_test",
srcs = ["dynamic_embedding_distributed_test.py"],
tags = [
"no_oss",
"no_windows",
],
deps = [
":dynamic_embedding",
"//:expect_absl_installed",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras",
"//keras:callbacks",
"//keras/testing_infra:test_combinations",
],
)

# Distributed-strategy tests for DynamicLookup.
distribute_py_test(
name = "dynamic_lookup_distributed_test",
srcs = ["dynamic_lookup_distributed_test.py"],
deps = [
":dynamic_lookup",
"//:expect_absl_installed",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras",
"//keras/testing_infra:test_combinations",
],
)
Empty file.
Loading

0 comments on commit 9702bdc

Please sign in to comment.