Skip to content

Commit

Permalink
Use standard Keras embedding layers inside EmbeddingFeatures block (#472)
Browse files Browse the repository at this point in the history

* Use standard Keras embedding layers inside `EmbeddingFeatures` block

* Fix failing `EmbeddingFeatures` tests

Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: Gabriel Moreira <[email protected]>
Co-authored-by: rnyak <[email protected]>
  • Loading branch information
4 people authored Jun 22, 2022
1 parent d2bc6fe commit 886cf6d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 33 deletions.
14 changes: 8 additions & 6 deletions merlin/models/tf/blocks/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ class ModelContext(Layer):
(This is created automatically in the model and doesn't need to be created manually.)
"""

def add_embedding_weight(self, name, **kwargs):
    """Create and return a weight named ``"<name>/embedding"`` on this layer.

    ``name`` is coerced to ``str`` before being embedded in the weight name.
    All remaining keyword arguments (shape, initializer, trainable, ...) are
    forwarded unchanged to ``add_weight``.
    """
    return self.add_weight(name=f"{str(name)}/embedding", **kwargs)

def add_variable(self, variable):
    """Expose *variable* as an attribute of this context, keyed by its name.

    Storing it via ``setattr`` makes the variable reachable as
    ``context.<variable.name>`` — presumably so the enclosing Keras layer
    tracks it automatically; confirm against ``ModelContext`` usage.
    Assumes ``variable.name`` is a valid Python attribute name
    (NOTE(review): TF variable names can contain '/' or ':' — verify).
    """
    setattr(self, variable.name, variable)

def add_embedding_table(self, name, embedding_table):
    """Register *embedding_table* under *name* in ``self.embedding_tables``.

    The registry dict is created lazily on first use. The attribute is
    re-assigned after every update, mirroring the original write pattern
    (presumably so Keras attribute tracking observes the change — confirm).
    """
    registry = getattr(self, "embedding_tables", {})
    registry[name] = embedding_table
    setattr(self, "embedding_tables", registry)

def set_dtypes(self, features):
for feature_name in features:
feature = features[feature_name]
Expand Down Expand Up @@ -124,7 +124,9 @@ def get_embedding(self, item):
item = item.value
else:
item = str(item)
return self.named_variables[f"{item}/embedding"]

embedding_tables = getattr(self, "embedding_tables", {})
return embedding_tables[item].embeddings

@property
def named_variables(self) -> Dict[str, tf.Variable]:
Expand Down
21 changes: 12 additions & 9 deletions merlin/models/tf/inputs/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,18 @@ def build(self, input_shapes):
tables[table.name] = table

for name, table in tables.items():
add_fn = (
self.context.add_embedding_weight if hasattr(self, "_context") else self.add_weight
)
self.embedding_tables[name] = add_fn(
self.embedding_tables[name] = tf.keras.layers.Embedding(
table.vocabulary_size,
table.dim,
name=name,
trainable=True,
initializer=table.initializer,
shape=(table.vocabulary_size, table.dim),
embeddings_initializer=table.initializer,
)

self.embedding_tables[name].build(())

if hasattr(self, "_context"):
self._context.add_embedding_table(name, self.embedding_tables[name])

if isinstance(input_shapes, dict):
super().build(input_shapes)
else:
Expand Down Expand Up @@ -246,7 +249,7 @@ def lookup_feature(self, name, val, output_sequence=False):
val = tf.cast(val, "int32")

table: TableConfig = self.feature_config[name].table
table_var = self.embedding_tables[table.name]
table_var = self.embedding_tables[table.name].embeddings
if isinstance(val, tf.SparseTensor):
out = tf.nn.safe_embedding_lookup_sparse(table_var, val, None, combiner=table.combiner)
else:
Expand Down Expand Up @@ -281,7 +284,7 @@ def get_embedding_table(self, table_name: Union[str, Tags], l2_normalization: bo
else:
raise ValueError(f"Could not find a feature associated to the tag {table_name}")

embeddings = self.embedding_tables[table_name]
embeddings = self.embedding_tables[table_name].embeddings
if l2_normalization:
embeddings = tf.linalg.l2_normalize(embeddings, axis=-1)

Expand Down
64 changes: 46 additions & 18 deletions tests/unit/tf/inputs/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_embedding_features_yoochoose(testing_data: Dataset):
assert sorted(list(embeddings.keys())) == sorted(schema.column_names)
assert all(emb.shape[-1] == 512 for emb in embeddings.values())
max_value = list(schema.select_by_name("item_id"))[0].int_domain.max
assert emb_module.embedding_tables["item_id"].shape[0] == max_value + 1
assert emb_module.embedding_tables["item_id"].embeddings.shape[0] == max_value + 1

# These embeddings have not a specific initializer, so they should
# have default truncated normal initialization
Expand Down Expand Up @@ -101,8 +101,8 @@ def test_embedding_features_yoochoose_custom_dims(testing_data: Dataset):

assert len(emb_module.losses) == 0, "There should be no regularization loss by default"

assert emb_module.embedding_tables["item_id"].shape[1] == 100
assert emb_module.embedding_tables["categories"].shape[1] == 64
assert emb_module.embedding_tables["item_id"].embeddings.shape[1] == 100
assert emb_module.embedding_tables["categories"].embeddings.shape[1] == 64

assert embeddings["item_id"].shape[1] == 100
assert embeddings["categories"].shape[1] == 64
Expand Down Expand Up @@ -142,15 +142,23 @@ def test_embedding_features_yoochoose_infer_embedding_sizes(testing_data: Datase

embeddings = emb_module(mm.sample_batch(testing_data, batch_size=100, include_targets=False))

assert emb_module.embedding_tables["user_id"].shape[1] == embeddings["user_id"].shape[1] == 20
assert (
emb_module.embedding_tables["user_country"].shape[1]
emb_module.embedding_tables["user_id"].embeddings.shape[1]
== embeddings["user_id"].shape[1]
== 20
)
assert (
emb_module.embedding_tables["user_country"].embeddings.shape[1]
== embeddings["user_country"].shape[1]
== 9
)
assert emb_module.embedding_tables["item_id"].shape[1] == embeddings["item_id"].shape[1] == 46
assert (
emb_module.embedding_tables["categories"].shape[1]
emb_module.embedding_tables["item_id"].embeddings.shape[1]
== embeddings["item_id"].shape[1]
== 46
)
assert (
emb_module.embedding_tables["categories"].embeddings.shape[1]
== embeddings["categories"].shape[1]
== 13
)
Expand All @@ -170,15 +178,23 @@ def test_embedding_features_yoochoose_infer_embedding_sizes_multiple_8(testing_d

embeddings = emb_module(mm.sample_batch(testing_data, batch_size=100, include_targets=False))

assert emb_module.embedding_tables["user_id"].shape[1] == embeddings["user_id"].shape[1] == 24
assert (
emb_module.embedding_tables["user_country"].shape[1]
emb_module.embedding_tables["user_id"].embeddings.shape[1]
== embeddings["user_id"].shape[1]
== 24
)
assert (
emb_module.embedding_tables["user_country"].embeddings.shape[1]
== embeddings["user_country"].shape[1]
== 16
)
assert emb_module.embedding_tables["item_id"].shape[1] == embeddings["item_id"].shape[1] == 48
assert (
emb_module.embedding_tables["categories"].shape[1]
emb_module.embedding_tables["item_id"].embeddings.shape[1]
== embeddings["item_id"].shape[1]
== 48
)
assert (
emb_module.embedding_tables["categories"].embeddings.shape[1]
== embeddings["categories"].shape[1]
== 16
)
Expand All @@ -198,15 +214,23 @@ def test_embedding_features_yoochoose_partially_infer_embedding_sizes(testing_da

embeddings = emb_module(mm.sample_batch(testing_data, batch_size=100, include_targets=False))

assert emb_module.embedding_tables["user_id"].shape[1] == embeddings["user_id"].shape[1] == 50
assert (
emb_module.embedding_tables["user_country"].shape[1]
emb_module.embedding_tables["user_id"].embeddings.shape[1]
== embeddings["user_id"].shape[1]
== 50
)
assert (
emb_module.embedding_tables["user_country"].embeddings.shape[1]
== embeddings["user_country"].shape[1]
== 100
)
assert emb_module.embedding_tables["item_id"].shape[1] == embeddings["item_id"].shape[1] == 46
assert (
emb_module.embedding_tables["categories"].shape[1]
emb_module.embedding_tables["item_id"].embeddings.shape[1]
== embeddings["item_id"].shape[1]
== 46
)
assert (
emb_module.embedding_tables["categories"].embeddings.shape[1]
== embeddings["categories"].shape[1]
== 13
)
Expand Down Expand Up @@ -272,8 +296,12 @@ def test_embedding_features_yoochoose_pretrained_initializer(testing_data: Datas
# Calling the first batch, so that embedding tables are build
_ = emb_module(mm.sample_batch(testing_data, batch_size=10, include_targets=False))

assert np.allclose(emb_module.embedding_tables["item_id"].numpy(), pretrained_emb_item_ids)
assert np.allclose(emb_module.embedding_tables["categories"].numpy(), pretrained_emb_categories)
assert np.allclose(
emb_module.embedding_tables["item_id"].embeddings.numpy(), pretrained_emb_item_ids
)
assert np.allclose(
emb_module.embedding_tables["categories"].embeddings.numpy(), pretrained_emb_categories
)


def test_embedding_features_exporting_and_loading_pretrained_initializer(testing_data: Dataset):
Expand All @@ -282,7 +310,7 @@ def test_embedding_features_exporting_and_loading_pretrained_initializer(testing

# Calling the first batch, so that embedding tables are build
_ = emb_module(mm.sample_batch(testing_data, batch_size=10, include_targets=False))
item_id_embeddings = emb_module.embedding_tables["item_id"]
item_id_embeddings = emb_module.embedding_tables["item_id"].embeddings

items_embeddings_dataset = emb_module.embedding_table_dataset(Tags.ITEM_ID, gpu=False)
assert np.allclose(
Expand Down

0 comments on commit 886cf6d

Please sign in to comment.