
[FEA] Support feeding pre-trained embeddings to TF4Rec model with high-level api #475

Open · 2 of 3 tasks

rnyak opened this issue Aug 18, 2022 · 8 comments

@rnyak (Contributor) commented Aug 18, 2022

🚀 Feature request

Currently we do not have out-of-the-box support for adding pre-trained embeddings to the embedding layer, freezing them, and training a TF4Rec model. We have embedding_initializer, but we have never tested whether it works accurately and as expected. Maybe we can create in PyTorch a class like TensorInitializer (TF), as we did in Merlin Models, and expose the embedding initializer and trainable args to the user.

We need to:

  • Expose the definition of the embeddings module in the input blocks: TabularFeatures and TabularSequenceFeatures
  • Support feeding pre-trained embeddings to the TF4Rec model with the high-level API (users should be able to add them to the embedding layer and freeze them, i.e., set trainable=False (TF API) or requires_grad=False (PyTorch API)); see the sketch after this list
  • Create an example notebook showcasing that functionality
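
For reference, a minimal plain-PyTorch sketch of the freezing behavior described above, outside of T4Rec; the shapes and values are hypothetical placeholders:

    import torch
    import torch.nn as nn

    # Hypothetical pre-trained item embeddings: shape (item_cardinality, embedding_dim).
    pretrained_weights = torch.rand(10_000, 64)

    # freeze=True sets weight.requires_grad = False, so the table is not updated
    # during training; this is the PyTorch counterpart of trainable=False in TF.
    item_embedding = nn.Embedding.from_pretrained(pretrained_weights, freeze=True)

    item_ids = torch.tensor([1, 5, 42])
    item_vectors = item_embedding(item_ids)  # shape (3, 64)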

Motivation

This is a feature request coming from our customers and users.

@karlhigley (Contributor)

Is this related to or part of NVIDIA-Merlin/Merlin#211?

@rnyak (Contributor, Author) commented Aug 18, 2022

Is this related to or part of NVIDIA-Merlin/Merlin#211?

@karlhigley more related to NVIDIA-Merlin/Merlin#471. Not sure about the link to 211.

@gabrielspmoreira (Member)

When the embedding tables are not huge and fit in GPU memory, the new PretrainedEmbeddingsInitializer (#572) can be used to initialize the embedding matrix with pre-trained embeddings and set it to trainable or not.
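
A minimal sketch of constructing that initializer, assuming the constructor takes the weight matrix plus a trainable flag as in the unit tests added with #572 (argument names and the weight file are assumptions and may differ by version):

    import torch
    import transformers4rec.torch as tr

    # Assumed: pre-trained vectors for every item ID, shape (item_cardinality, dim);
    # the file name is a hypothetical placeholder.
    pretrained_item_emb = torch.load("item_embeddings.pt")

    # trainable=False keeps the copied weights frozen during training
    # (argument names assumed from the tests accompanying #572).
    item_emb_init = tr.PretrainedEmbeddingsInitializer(pretrained_item_emb, trainable=False)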

@rnyak rnyak self-assigned this Jan 30, 2023
@rnyak rnyak modified the milestones: Merlin 23.01, Merlin 23.02 Jan 30, 2023
@karlhigley karlhigley modified the milestones: Merlin 23.02, Merlin 23.04 Apr 4, 2023
@karunaahuja

Is there an example notebook showing usage of PretrainedEmbeddingsInitializer to initialize the embedding matrix?

@rnyak (Contributor, Author) commented Jul 8, 2024

Is there an example notebook showing usage of PretrainedEmbeddingsInitializer to initialize the embedding matrix?

We don't have an example for this feature, but you can refer to the unit test and try to implement it.

@karunaahuja

Is there an example notebook showing usage of PretrainedEmbeddingsInitializer to initialize the embedding matrix?

We don't have an example for this feature, but you can refer to the unit test and try to implement it.

Thanks. I guess what I am looking for is how to use this along with the input block defined by a model schema (TabularSequenceFeatures with a series of categorical and continuous features), tr.NextItemPredictionTask, and an Electra config. Here's my pseudocode without using the embeddings:

    import transformers4rec.torch as tr

    # schema, max_sequence_length, d_model, n_head, n_layer, embedding_dim_default,
    # and PAD_TOKEN are assumed to be defined elsewhere.
    input_module = tr.TabularSequenceFeatures.from_schema(
        schema,
        max_sequence_length=max_sequence_length,
        aggregation="concat",
        d_output=d_model,
        masking="mlm",
        embedding_dim_default=embedding_dim_default,
    )

    metrics = [
        tr.ranking_metric.NDCGAt(top_ks=[10, 20, 50, 100, 150, 200], labels_onehot=True),
        tr.ranking_metric.AvgPrecisionAt(
            top_ks=[10, 20, 50, 100, 150, 200], labels_onehot=True
        ),
        tr.ranking_metric.RecallAt(top_ks=[10, 20, 50, 100, 150, 200], labels_onehot=True),
    ]

    prediction_task = tr.NextItemPredictionTask(weight_tying=True, metrics=metrics)

    transformer_config = tr.ElectraConfig.build(
        d_model=d_model,
        n_head=n_head,
        n_layer=n_layer,
        total_seq_length=max_sequence_length,
        pad_token=PAD_TOKEN,
    )

    model = transformer_config.to_torch_model(input_module, prediction_task)
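
One possible way to wire pre-trained embeddings into this setup, assuming from_schema forwards an embeddings_initializers mapping to the underlying embedding block as exercised in the unit tests for #572 (the argument name, the "item_id" column name, and the weight file below are assumptions, not a confirmed API):

    import torch
    import transformers4rec.torch as tr

    # Hypothetical pre-trained item embeddings, shape (item_cardinality, embedding_dim).
    pretrained_item_emb = torch.load("item_embeddings.pt")

    input_module = tr.TabularSequenceFeatures.from_schema(
        schema,  # Merlin schema with an "item_id" categorical column (assumed name)
        max_sequence_length=max_sequence_length,
        aggregation="concat",
        d_output=d_model,
        masking="mlm",
        embedding_dim_default=embedding_dim_default,
        # Assumption: per-column initializers, keyed by the categorical column name.
        embeddings_initializers={
            "item_id": tr.PretrainedEmbeddingsInitializer(pretrained_item_emb, trainable=False)
        },
    )

The rest of the pipeline (metrics, NextItemPredictionTask, transformer config, to_torch_model) would stay as above.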

@karunaahuja

following up on this ^

@Tottowich

Any progress? @karunaahuja
Looking to do the same thing :D
