Skip to content

Commit

Permalink
remove wide&deep from public API
Browse files Browse the repository at this point in the history
  • Loading branch information
sararb committed Mar 31, 2022
1 parent 094a0f6 commit 5610f92
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
5 changes: 1 addition & 4 deletions merlin/models/tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
AsDenseFeatures,
AsSparseFeatures,
CategoricalOneHot,
CrossFeatures,
ExpandDims,
LabelToOneHot,
StochasticSwapNoise,
Expand Down Expand Up @@ -96,7 +95,7 @@
)
from merlin.models.tf.models import benchmark
from merlin.models.tf.models.base import Model, RetrievalModel
from merlin.models.tf.models.ranking import DCNModel, DeepFMModel, DLRMModel, WideAndDeepModel
from merlin.models.tf.models.ranking import DCNModel, DeepFMModel, DLRMModel
from merlin.models.tf.models.retrieval import (
MatrixFactorizationModel,
TwoTowerModel,
Expand Down Expand Up @@ -156,7 +155,6 @@
"AsDenseFeatures",
"AsSparseFeatures",
"CategoricalOneHot",
"CrossFeatures",
"ElementwiseSum",
"ElementwiseSumItemMulti",
"AsTabular",
Expand Down Expand Up @@ -201,7 +199,6 @@
"DLRMModel",
"DCNModel",
"DeepFMModel",
"WideAndDeepModel",
"losses",
"LossType",
"sample_batch",
Expand Down
4 changes: 4 additions & 0 deletions merlin/models/tf/blocks/core/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,10 @@ class CrossFeatures(TabularBlock):
"""Transformation performing crosses of categorical features
based on the provided keys.
There are some open questions in this experimental implementation:
#TODO : remove the dependency with tensorflow input api
#TODO : Add possibility of loading cross keys from schema
Parameters
----------
keys : List[List[str]]
Expand Down
5 changes: 3 additions & 2 deletions merlin/models/tf/models/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,6 @@ def WideAndDeepModel(
**kwargs
) -> Model:
"""Experimental implementation of Wide&Deep model [1];
The wide and deep components are trained with a single
optimizer in this first implementation.
If deep_block is not provided, the model is equivalent to
a linear model with cross-product interactions.
Expand All @@ -261,6 +259,9 @@ def WideAndDeepModel(
deep_block=ml.MLPBlock([64, 128])
)
#TODO : Implementation of a BlocksOptimizer class that accepts different optimizers
# for different subsets of the model’s blocks
References:
-----------
[1] Heng-Tze, Cheng et al.
Expand Down
8 changes: 6 additions & 2 deletions tests/tf/models/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def test_deep_fm_model_single_task_from_pred_task(ecommerce_data, num_epochs=5,

@pytest.mark.parametrize("run_eagerly", [True, False])
def test_wide_and_deep_model_single_task_from_pred_task(ecommerce_data, run_eagerly, num_epochs=5):
model = ml.WideAndDeepModel(
from merlin.models.tf.models.ranking import WideAndDeepModel

model = WideAndDeepModel(
ecommerce_data.schema,
embedding_dim=64,
keys=[["item_category", "item_intention"]],
Expand All @@ -105,8 +107,10 @@ def test_wide_and_deep_model_single_task_from_pred_task(ecommerce_data, run_eage


def test_wide_and_deep_model_wrong_keys(ecommerce_data):
from merlin.models.tf.models.ranking import WideAndDeepModel

with pytest.raises(ValueError) as exc_info:
ml.WideAndDeepModel(ecommerce_data.schema, embedding_dim=64, keys=[["item_1", "item_2"]])
WideAndDeepModel(ecommerce_data.schema, embedding_dim=64, keys=[["item_1", "item_2"]])
assert "Make sure the cross features keys are present in the schema" in str(exc_info.value)


Expand Down

0 comments on commit 5610f92

Please sign in to comment.