Skip to content

Commit

Permalink
Add DynamicEmbedding to Keras
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 556908566
  • Loading branch information
divyashreepathihalli authored and tensorflower-gardener committed Aug 22, 2023
1 parent 94a1712 commit 8a485e2
Show file tree
Hide file tree
Showing 16 changed files with 1,771 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ py_library(
"//keras/protobuf:projector_config_proto_py_pb2",
"//keras/utils:engine_utils",
"//keras/utils:mode_keys",
"//keras/utils:timed_threads",
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
path: "tensorflow.keras.callbacks.experimental.UpdateEmbeddingCallback"
tf_class {
is_instance: "<class \'keras.callbacks.UpdateEmbeddingCallback\'>"
is_instance: "<class \'keras.utils.timed_threads.TimedThread\'>"
is_instance: "<class \'keras.callbacks.Callback\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dynamic_embedding_layer\', \'interval\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_alive"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "on_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_interval"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "on_predict_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_begin"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_end"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_begin"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_end"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_begin"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_end"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_model"
argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_params"
argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "start"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "stop"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
path: "tensorflow.keras.callbacks.experimental"
tf_module {
member {
name: "UpdateEmbeddingCallback"
mtype: "<type \'type\'>"
}
}
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.callbacks.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,8 @@ tf_module {
name: "TerminateOnNaN"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
path: "tensorflow.keras.callbacks.experimental.UpdateEmbeddingCallback"
tf_class {
is_instance: "<class \'keras.callbacks.UpdateEmbeddingCallback\'>"
is_instance: "<class \'keras.utils.timed_threads.TimedThread\'>"
is_instance: "<class \'keras.callbacks.Callback\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dynamic_embedding_layer\', \'interval\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_alive"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "on_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_begin"
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_epoch_end"
argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_interval"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "on_predict_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_begin"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_predict_end"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_begin"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_test_end"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_batch_begin"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_batch_end"
argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_begin"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "on_train_end"
argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_model"
argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_params"
argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "start"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "stop"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,8 @@ tf_module {
name: "BackupAndRestore"
mtype: "<type \'type\'>"
}
member {
name: "UpdateEmbeddingCallback"
mtype: "<type \'type\'>"
}
}
117 changes: 117 additions & 0 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from keras.utils.data_utils import Sequence
from keras.utils.generic_utils import Progbar
from keras.utils.mode_keys import ModeKeys
from keras.utils.timed_threads import TimedThread

# isort: off
from tensorflow.python.platform import tf_logging as logging
Expand Down Expand Up @@ -3306,3 +3307,119 @@ def __init__(
self.on_train_begin = on_train_begin
if on_train_end is not None:
self.on_train_end = on_train_end


@keras_export("keras.callbacks.experimental.UpdateEmbeddingCallback")
class UpdateEmbeddingCallback(TimedThread, Callback):
    """A callback to update the DynamicEmbedding layer at a specific time
    interval.

    Updating the embedding matrix would mean that the optimizer variables will
    be reset in this callback and this could have potential side effects. This
    means that any existing slot variables associated with the optimizer will
    likely be discarded when the optimizer is rebuilt. This affects optimizers
    that rely on states of optimizer slot variables.

    Example:

    ```python
    # Generate dummy data
    train_data = np.array([
        ['a', 'j', 'c', 'd', 'e'],
        ['a', 'h', 'i', 'j', 'b'],
        ['i', 'h', 'c', 'j', 'e'],
    ])
    train_labels = np.array([0, 1, 2])
    vocab = tf.constant(['a', 'b', 'c', 'd', 'e'])
    eviction_policy = 'LFU'
    # Define the model
    model = tf.keras.models.Sequential([
        DynamicEmbedding(
            input_dim=5,
            output_dim=2,
            input_length=5,
            eviction_policy=eviction_policy,
            initial_vocabulary=vocab,
        ),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])
    # Compile the model
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    # update the vocabulary every 1 second
    update_embedding_callback = UpdateEmbeddingCallback(
        model.layers[0], interval=1
    )
    with update_embedding_callback:
        result = model.fit(
            train_data,
            train_labels,
            epochs=100,
            batch_size=1,
            callbacks=[update_embedding_callback],
        )
    ```
    """

    def __init__(self, dynamic_embedding_layer, interval):
        """Initialize Timed Callback object.

        Args:
          dynamic_embedding_layer: The dynamic embedding layer to be updated.
          interval: the interval, in seconds, to wait between calls to the
            thread function. The thread function here updates the embeddings
            matrix and resets the optimizer states.
        """
        self._epoch = 0
        # Initialize both bases explicitly: TimedThread owns the interval
        # timer, Callback owns model/params plumbing.
        TimedThread.__init__(self, interval)
        Callback.__init__(self)
        self._dynamic_embedding_layer = dynamic_embedding_layer
        # Capture the distribution strategy active at construction time so
        # the background thread runs updates under the same strategy.
        self.strategy = tf.distribute.get_strategy()

    def on_interval(self):
        """Thread function invoked by TimedThread every `interval` seconds.

        Updates the embedding matrix of the tracked DynamicEmbedding layer
        and resets the associated optimizer slot variables, replicated
        across devices when a distribution strategy is in use.
        """
        try:
            critical_section = tf.CriticalSection()

            # Using `tf.CriticalSection` when updating embeddings using timed
            # thread can help ensure thread safety and prevent race conditions
            # in the shared variables.
            def execute_critical_section():
                critical_section.execute(
                    lambda: self._dynamic_embedding_layer.update_embeddings(  # pylint: disable=g-long-lambda
                        self.strategy
                    )
                )

            # update embeddings across all devices if distributed training is
            # used
            self.strategy.run(execute_critical_section)
            # update optimizer variables across all devices if distributed
            # training is used. The bound method is passed directly; wrapping
            # it in a lambda would be unnecessary indirection.
            self.strategy.run(self._reset_optimizer)
        except AttributeError:
            # The model/optimizer may not be fully built yet when the timer
            # fires (e.g. the interval elapsed before training started).
            logging.info(
                "Time interval specified to the UpdateEmbeddingCallback may be"
                " too small, please try increasing the value of `interval`."
            )

    def _reset_optimizer(self):
        """Resetting the optimizer variables.

        Resetting the optimizer variables is necessary after updating the
        variable in the layer. This ensures that the optimizer is working with
        a consistent internal state. This helps to prevent unexpected behavior
        and can lead to more stable and faster training of the model.
        """
        # Only slot variables tied to the dynamic embedding layer are zeroed;
        # optimizer state for all other variables is left untouched.
        for var in self.model.optimizer.variables():
            if "dynamic_embedding" in var.name:
                backend.set_value(var, backend.zeros_like(var))

    def on_epoch_begin(self, epoch, logs=None):
        # Track the current epoch for use by the timed-thread update.
        self._epoch = epoch
1 change: 1 addition & 0 deletions keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from keras.layers.core.tf_op_layer import SlicingOpLambda
from keras.layers.core.tf_op_layer import TFOpLambda


# Locally-connected layers.
from keras.layers.locally_connected.locally_connected1d import (
LocallyConnected1D,
Expand Down
Loading

0 comments on commit 8a485e2

Please sign in to comment.