diff --git a/merlin/models/tf/blocks/core/base.py b/merlin/models/tf/blocks/core/base.py index 6a2dd9067f..82fa64f2e3 100644 --- a/merlin/models/tf/blocks/core/base.py +++ b/merlin/models/tf/blocks/core/base.py @@ -89,14 +89,14 @@ class ModelContext(Layer): (This is created automatically in the model and doesn't need to be created manually.) """ - def add_embedding_weight(self, name, **kwargs): - table = self.add_weight(name=f"{str(name)}/embedding", **kwargs) - - return table - def add_variable(self, variable): setattr(self, variable.name, variable) + def add_embedding_table(self, name, embedding_table): + embedding_tables = getattr(self, "embedding_tables", {}) + embedding_tables[name] = embedding_table + setattr(self, "embedding_tables", embedding_tables) + def set_dtypes(self, features): for feature_name in features: feature = features[feature_name] @@ -124,7 +124,9 @@ def get_embedding(self, item): item = item.value else: item = str(item) - return self.named_variables[f"{item}/embedding"] + + embedding_tables = getattr(self, "embedding_tables", {}) + return embedding_tables[item].embeddings @property def named_variables(self) -> Dict[str, tf.Variable]: diff --git a/merlin/models/tf/inputs/embedding.py b/merlin/models/tf/inputs/embedding.py index 9896bdd7d4..417011b2dc 100644 --- a/merlin/models/tf/inputs/embedding.py +++ b/merlin/models/tf/inputs/embedding.py @@ -208,15 +208,18 @@ def build(self, input_shapes): tables[table.name] = table for name, table in tables.items(): - add_fn = ( - self.context.add_embedding_weight if hasattr(self, "_context") else self.add_weight - ) - self.embedding_tables[name] = add_fn( + self.embedding_tables[name] = tf.keras.layers.Embedding( + table.vocabulary_size, + table.dim, name=name, - trainable=True, - initializer=table.initializer, - shape=(table.vocabulary_size, table.dim), + embeddings_initializer=table.initializer, ) + + self.embedding_tables[name].build(()) + + if hasattr(self, "_context"): + self._context.add_embedding_table(name, self.embedding_tables[name]) + if isinstance(input_shapes, dict): super().build(input_shapes) else: @@ -246,7 +249,7 @@ def lookup_feature(self, name, val, output_sequence=False): val = tf.cast(val, "int32") table: TableConfig = self.feature_config[name].table - table_var = self.embedding_tables[table.name] + table_var = self.embedding_tables[table.name].embeddings if isinstance(val, tf.SparseTensor): out = tf.nn.safe_embedding_lookup_sparse(table_var, val, None, combiner=table.combiner) else: @@ -281,7 +284,7 @@ def get_embedding_table(self, table_name: Union[str, Tags], l2_normalization: bo else: raise ValueError(f"Could not find a feature associated to the tag {table_name}") - embeddings = self.embedding_tables[table_name] + embeddings = self.embedding_tables[table_name].embeddings if l2_normalization: embeddings = tf.linalg.l2_normalize(embeddings, axis=-1) diff --git a/tests/unit/tf/inputs/test_embedding.py b/tests/unit/tf/inputs/test_embedding.py index 11218d1ed9..3ee5224148 100644 --- a/tests/unit/tf/inputs/test_embedding.py +++ b/tests/unit/tf/inputs/test_embedding.py @@ -48,7 +48,7 @@ def test_embedding_features_yoochoose(testing_data: Dataset): assert sorted(list(embeddings.keys())) == sorted(schema.column_names) assert all(emb.shape[-1] == 512 for emb in embeddings.values()) max_value = list(schema.select_by_name("item_id"))[0].int_domain.max - assert emb_module.embedding_tables["item_id"].shape[0] == max_value + 1 + assert emb_module.embedding_tables["item_id"].embeddings.shape[0] == max_value + 1 # These embeddings have not a specific initializer, so they should # have default truncated normal initialization @@ -101,8 +101,8 @@ def test_embedding_features_yoochoose_custom_dims(testing_data: Dataset): assert len(emb_module.losses) == 0, "There should be no regularization loss by default" - assert emb_module.embedding_tables["item_id"].shape[1] == 100 - assert emb_module.embedding_tables["categories"].shape[1] == 64 + assert emb_module.embedding_tables["item_id"].embeddings.shape[1] == 100 + assert emb_module.embedding_tables["categories"].embeddings.shape[1] == 64 assert embeddings["item_id"].shape[1] == 100 assert embeddings["categories"].shape[1] == 64 @@ -142,15 +142,23 @@ def test_embedding_features_yoochoose_infer_embedding_sizes(testing_data: Datase embeddings = emb_module(mm.sample_batch(testing_data, batch_size=100, include_targets=False)) - assert emb_module.embedding_tables["user_id"].shape[1] == embeddings["user_id"].shape[1] == 20 assert ( - emb_module.embedding_tables["user_country"].shape[1] + emb_module.embedding_tables["user_id"].embeddings.shape[1] + == embeddings["user_id"].shape[1] + == 20 + ) + assert ( + emb_module.embedding_tables["user_country"].embeddings.shape[1] == embeddings["user_country"].shape[1] == 9 ) - assert emb_module.embedding_tables["item_id"].shape[1] == embeddings["item_id"].shape[1] == 46 assert ( - emb_module.embedding_tables["categories"].shape[1] + emb_module.embedding_tables["item_id"].embeddings.shape[1] + == embeddings["item_id"].shape[1] + == 46 + ) + assert ( + emb_module.embedding_tables["categories"].embeddings.shape[1] == embeddings["categories"].shape[1] == 13 ) @@ -170,15 +178,23 @@ def test_embedding_features_yoochoose_infer_embedding_sizes_multiple_8(testing_d embeddings = emb_module(mm.sample_batch(testing_data, batch_size=100, include_targets=False)) - assert emb_module.embedding_tables["user_id"].shape[1] == embeddings["user_id"].shape[1] == 24 assert ( - emb_module.embedding_tables["user_country"].shape[1] + emb_module.embedding_tables["user_id"].embeddings.shape[1] + == embeddings["user_id"].shape[1] + == 24 + ) + assert ( + emb_module.embedding_tables["user_country"].embeddings.shape[1] == embeddings["user_country"].shape[1] == 16 ) - assert emb_module.embedding_tables["item_id"].shape[1] == embeddings["item_id"].shape[1] == 48 assert ( - emb_module.embedding_tables["categories"].shape[1] + emb_module.embedding_tables["item_id"].embeddings.shape[1] + == embeddings["item_id"].shape[1] + == 48 + ) + assert ( + emb_module.embedding_tables["categories"].embeddings.shape[1] == embeddings["categories"].shape[1] == 16 ) @@ -198,15 +214,23 @@ def test_embedding_features_yoochoose_partially_infer_embedding_sizes(testing_da embeddings = emb_module(mm.sample_batch(testing_data, batch_size=100, include_targets=False)) - assert emb_module.embedding_tables["user_id"].shape[1] == embeddings["user_id"].shape[1] == 50 assert ( - emb_module.embedding_tables["user_country"].shape[1] + emb_module.embedding_tables["user_id"].embeddings.shape[1] + == embeddings["user_id"].shape[1] + == 50 + ) + assert ( + emb_module.embedding_tables["user_country"].embeddings.shape[1] == embeddings["user_country"].shape[1] == 100 ) - assert emb_module.embedding_tables["item_id"].shape[1] == embeddings["item_id"].shape[1] == 46 assert ( - emb_module.embedding_tables["categories"].shape[1] + emb_module.embedding_tables["item_id"].embeddings.shape[1] + == embeddings["item_id"].shape[1] + == 46 + ) + assert ( + emb_module.embedding_tables["categories"].embeddings.shape[1] == embeddings["categories"].shape[1] == 13 ) @@ -272,8 +296,12 @@ def test_embedding_features_yoochoose_pretrained_initializer(testing_data: Datas # Calling the first batch, so that embedding tables are build _ = emb_module(mm.sample_batch(testing_data, batch_size=10, include_targets=False)) - assert np.allclose(emb_module.embedding_tables["item_id"].numpy(), pretrained_emb_item_ids) - assert np.allclose(emb_module.embedding_tables["categories"].numpy(), pretrained_emb_categories) + assert np.allclose( + emb_module.embedding_tables["item_id"].embeddings.numpy(), pretrained_emb_item_ids + ) + assert np.allclose( + emb_module.embedding_tables["categories"].embeddings.numpy(), pretrained_emb_categories + ) def test_embedding_features_exporting_and_loading_pretrained_initializer(testing_data: Dataset): @@ -282,7 +310,7 @@ def test_embedding_features_exporting_and_loading_pretrained_initializer(testing # Calling the first batch, so that embedding tables are build _ = emb_module(mm.sample_batch(testing_data, batch_size=10, include_targets=False)) - item_id_embeddings = emb_module.embedding_tables["item_id"] + item_id_embeddings = emb_module.embedding_tables["item_id"].embeddings items_embeddings_dataset = emb_module.embedding_table_dataset(Tags.ITEM_ID, gpu=False) assert np.allclose(