
first implementation of wide&deep model #301

Closed
wants to merge 5 commits

Conversation

sararb
Contributor

@sararb sararb commented Mar 28, 2022

Fixes #154

Goals ⚽

  • This PR includes an experimental implementation of the Wide&Deep model, where both components are jointly trained using a single optimizer.
  • A cross-feature layer is defined to add higher-order interaction features to the wide component.
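The joint training described above can be sketched as follows. This is an illustrative NumPy toy, not the merlin.models implementation: the layer sizes and weights are made up, and the real model is built from Keras layers.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Toy dimensions, chosen arbitrarily for illustration.
batch, n_wide, n_deep, hidden = 4, 8, 16, 32

# Wide component: a linear model over (crossed) sparse features.
W_wide = rng.normal(size=(n_wide, 1))

# Deep component: a small MLP over dense/embedded features.
W1 = rng.normal(size=(n_deep, hidden))
W2 = rng.normal(size=(hidden, 1))

def wide_and_deep(x_wide, x_deep):
    wide_logit = x_wide @ W_wide
    deep_logit = relu(x_deep @ W1) @ W2
    # The two logits are summed before the final activation, so a
    # single loss (and, in this PR, a single optimizer) trains both
    # components jointly.
    return sigmoid(wide_logit + deep_logit)

x_wide = rng.normal(size=(batch, n_wide))
x_deep = rng.normal(size=(batch, n_deep))
out = wide_and_deep(x_wide, x_deep)
print(out.shape)  # (4, 1)
```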

Testing Details 🔍

  • Unit test to check the training/evaluation of the Wide&Deep model.
  • Unit test to catch the ValueError raised when wrong key names are provided to CrossFeatures.

@sararb sararb force-pushed the tf/wide_and_deep branch from 5ab4b69 to 5610f92 on March 31, 2022 at 12:22
@github-actions

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-301

@sararb
Contributor Author

sararb commented Mar 31, 2022

blocked by #308

@EvenOldridge
Member

Can we merge this and flag the optimizer issue?

@nvidia-merlin-bot

Click to view CI Results
GitHub pull request #301 of commit c7fbae56a07caa1eb6de1a8533d8f16e024e4778, has merge conflicts.
Running as SYSTEM
Setting status of c7fbae56a07caa1eb6de1a8533d8f16e024e4778 to PENDING with url https://10.20.13.93:8080/job/merlin_models/235/console and message: 'Pending'
Using context: Jenkins
Building on master in workspace /var/jenkins_home/workspace/merlin_models
using credential nvidia-merlin-bot
 > git rev-parse --is-inside-work-tree # timeout=10
Fetching changes from the remote Git repository
 > git config remote.origin.url https://github.com/NVIDIA-Merlin/models/ # timeout=10
Fetching upstream changes from https://github.com/NVIDIA-Merlin/models/
 > git --version # timeout=10
using GIT_ASKPASS to set credentials This is the bot credentials for our CI/CD
 > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/models/ +refs/pull/301/*:refs/remotes/origin/pr/301/* # timeout=10
 > git rev-parse c7fbae56a07caa1eb6de1a8533d8f16e024e4778^{commit} # timeout=10
Checking out Revision c7fbae56a07caa1eb6de1a8533d8f16e024e4778 (detached)
 > git config core.sparsecheckout # timeout=10
 > git checkout -f c7fbae56a07caa1eb6de1a8533d8f16e024e4778 # timeout=10
Commit message: "Merge branch 'main' into tf/wide_and_deep"
 > git rev-list --no-walk 736348756714771db0da34cfa43eb995512dc3e4 # timeout=10
[merlin_models] $ /bin/bash /tmp/jenkins175405788639646245.sh
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Requirement already satisfied: testbook in /usr/local/lib/python3.8/dist-packages (0.4.2)
============================= test session starts ==============================
platform linux -- Python 3.8.10, pytest-7.1.2, pluggy-1.0.0
rootdir: /var/jenkins_home/workspace/merlin_models/models, configfile: pyproject.toml
plugins: xdist-2.5.0, forked-1.4.0, cov-3.0.0
collected 360 items / 2 skipped

tests/data/test_synthetic.py .. [ 0%]
tests/data/testing/test_dataset.py ..... [ 1%]
tests/tf/test_core.py ................ [ 6%]
tests/tf/test_dataset.py .............. [ 10%]
tests/tf/test_public_api.py . [ 10%]
tests/tf/blocks/test_cross.py ............ [ 13%]
tests/tf/blocks/test_dlrm.py ........ [ 16%]
tests/tf/blocks/test_interactions.py . [ 16%]
tests/tf/blocks/test_mlp.py ............................. [ 24%]
tests/tf/blocks/core/test_aggregation.py ......... [ 26%]
tests/tf/blocks/core/test_base.py . [ 27%]
tests/tf/blocks/core/test_index.py .. [ 27%]
tests/tf/blocks/core/test_masking.py ....... [ 29%]
tests/tf/blocks/core/test_transformations.py ......... [ 32%]
tests/tf/blocks/retrieval/test_matrix_factorization.py .. [ 32%]
tests/tf/blocks/retrieval/test_two_tower.py .......... [ 35%]
tests/tf/features/test_continuous.py ..... [ 36%]
tests/tf/features/test_embedding.py ......... [ 39%]
tests/tf/features/test_tabular.py ....... [ 41%]
tests/tf/layers/test_queue.py .............. [ 45%]
tests/tf/losses/test_losses.py ....................... [ 51%]
tests/tf/metrics/test_metrics_ranking.py ................. [ 56%]
tests/tf/models/test_benchmark.py . [ 56%]
tests/tf/models/test_ranking.py ......................... [ 63%]
tests/tf/models/test_retrieval.py ......... [ 66%]
tests/tf/prediction/test_classification.py .. [ 66%]
tests/tf/prediction/test_multi_task.py ....... [ 68%]
tests/tf/prediction/test_next_item.py .................... [ 74%]
tests/tf/prediction/test_regression.py .. [ 74%]
tests/tf/prediction/test_sampling.py .................... [ 80%]
tests/tf/utils/test_batch.py .... [ 81%]
tests/torch/test_dataset.py ......... [ 83%]
tests/torch/test_public_api.py . [ 84%]
tests/torch/block/test_base.py .... [ 85%]
tests/torch/block/test_mlp.py . [ 85%]
tests/torch/features/test_continuous.py .. [ 86%]
tests/torch/features/test_embedding.py .............. [ 90%]
tests/torch/features/test_tabular.py .... [ 91%]
tests/torch/model/test_head.py ............ [ 94%]
tests/torch/model/test_model.py .. [ 95%]
tests/torch/tabular/test_aggregation.py ........ [ 97%]
tests/torch/tabular/test_tabular.py ... [ 98%]
tests/torch/tabular/test_transformations.py ....... [100%]

=============================== warnings summary ===============================
(86 warnings elided: DeprecationWarnings for Pillow resampling constants raised via keras_preprocessing and torchvision, deprecated imp/collections imports, cuDF and Keras deprecation notices, and a torchvision image-extension load failure)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========== 360 passed, 2 skipped, 86 warnings in 454.98s (0:07:34) ============
Performing Post build task...
Match found for : : True
Logical operation result is TRUE
Running script : #!/bin/bash
cd /var/jenkins_home/
CUDA_VISIBLE_DEVICES=1 python test_res_push.py "https://api.GitHub.com/repos/NVIDIA-Merlin/models/issues/$ghprbPullId/comments" "/var/jenkins_home/jobs/$JOB_NAME/builds/$BUILD_NUMBER/log"
[merlin_models] $ /bin/bash /tmp/jenkins3736865485173270852.sh

@marcromeyn
Contributor

This PR uses the feature-column API for feature crosses, which is not something we want to merge now. From talking to @benfred, it seems we actually support feature crosses in NVT, although it sounds like it's a bit hidden. I think we need to create an example of how to do it and then incorporate it into the wide&deep model. It might also make sense to create a tag for a feature cross that NVT could output.

@karlhigley
Contributor

karlhigley commented May 3, 2022

From my reading of the Wide & Deep paper, the feature crosses are performed as part of the model architecture, so I don't think we want to move the responsibility for feature crossing to NVT. It seems that moving this PR forward would require an implementation of feature crossing in Merlin Models that doesn't rely on the feature-column API. What would be involved in creating one?

EDIT: To make my reasoning more explicit: forcing people to do model-specific feature preprocessing violates the promise we make with dataset schemas, which is that you can quickly try out a bunch of models on the same data while we do the work to make that happen behind the scenes.

@marcromeyn
Contributor

I agree that the ideal case is to make feature crosses part of Merlin Models, but that will require extra work. Since we already seem to support it in NVT, I was thinking that would be the minimum scope for having a wide&deep model in Merlin Models.

We could create a separate ticket to add feature crosses to Merlin Models. I'm not sure yet what priority that ticket should have, though.

Curious to hear your thoughts (cc @EvenOldridge @benfred)

@karlhigley
Contributor

I'm working on closing out unfinished scope from 22.04, which includes Wide & Deep, so it's not extra work; it's the work we previously committed to. If we need to put more people on it in order to finish, maybe it's something @jperez999 and I can help with?

@marcromeyn
Contributor

I don’t think we ever made a proper definition of done, so I'm not sure it’s completely fair to say that we committed to a W&D model that contains feature crosses (as opposed to doing the feature crossing in NVT).

But if you have time to help, adding it seems totally doable. I was under the impression that the string-hashing ops were CPU-only in TF, but it seems that’s no longer the case (PR). This means we could concatenate the ids of the columns you want to cross, IDfeature1_IDfeature2_IDfeature3, and hash the result. We could infer the cardinality of the cross by multiplying the cardinalities of the individual features. The user should be able to override this, though (to enable fewer buckets, where embeddings are shared).

We still have to figure out a nice API for the user to specify the feature crosses.
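The hashing scheme sketched above can be illustrated in plain Python. `hashed_cross` is a hypothetical helper (not part of this PR); in TensorFlow the same idea could be expressed with `tf.strings.join` followed by `tf.strings.to_hash_bucket_fast`:

```python
import zlib

def hashed_cross(ids, cardinalities, num_buckets=None):
    """Cross categorical ids by concatenating and hashing them.

    By default the number of buckets is the product of the individual
    cardinalities; pass num_buckets explicitly to use fewer buckets,
    in which case some crosses share an embedding.
    """
    if num_buckets is None:
        num_buckets = 1
        for card in cardinalities:
            num_buckets *= card
    # e.g. ids (3, 7) -> "3_7", mirroring IDfeature1_IDfeature2.
    key = "_".join(str(i) for i in ids)
    return zlib.crc32(key.encode()) % num_buckets

# 10 x 50 categories -> 500 buckets by default.
bucket = hashed_cross((3, 7), (10, 50))
assert 0 <= bucket < 500

# User-overridden bucket count: smaller table, shared embeddings.
small = hashed_cross((3, 7), (10, 50), num_buckets=64)
assert 0 <= small < 64
```

Hashing is lossy: distinct crosses can land in the same bucket, which is exactly the trade-off the cardinality override exposes.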

@karlhigley
Contributor

Until we either have Wide&Deep merged or we've explicitly decided not to add it, whatever is required to make it happen seems to me like part of the scope we've committed ourselves to. 🤷🏻

Do we specifically need to use hashing, or can we exhaustively compute the pairwise feature crosses?
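For reference, the exhaustive (collision-free) alternative asked about here is straightforward when the combined cardinality is small; `exact_cross` is a hypothetical helper for illustration:

```python
def exact_cross(id_a, card_a, id_b, card_b):
    """Collision-free pairwise cross.

    Encodes the pair (id_a, id_b) as a single id in
    [0, card_a * card_b). Unlike hashing, no two pairs collide,
    but the embedding table grows with the product of the
    cardinalities, which is why hashing is used for large crosses.
    """
    assert 0 <= id_a < card_a and 0 <= id_b < card_b
    return id_a * card_b + id_b

print(exact_cross(3, 10, 7, 50))  # 157  (= 3 * 50 + 7)
```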

@marcromeyn
Contributor

The ideal end state of the wide&deep model:

  • One line to create the model
  • Different optimizers for the wide and deep parts
  • Feature crossing as part of the modeling code

We could either commit to doing all of these in one go or add them gradually. Up to this point we have informally talked about the gradual approach, but it's totally fine by me if we want to do it all in one go.

@sararb
Contributor Author

sararb commented Jun 13, 2022

I am closing this PR, as a new PR was opened by @Timmy00 that solves the first point (one line to create the model). I also added the three points needed to complete the Wide&Deep model to this RMP ticket.

@gabrielspmoreira
Member

I agree that feature crossing should be done on the modeling side, so that a data scientist doesn't have to preprocess multiple versions of the dataset to try out different feature interactions.

Development

Successfully merging this pull request may close these issues.

Add Wide&Deep model
6 participants