first implementation of wide&deep model #301

Closed · wants to merge 5 commits
110 changes: 109 additions & 1 deletion merlin/models/tf/blocks/core/transformations.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union

import tensorflow as tf
from tensorflow.keras import backend
@@ -32,6 +32,10 @@
@Block.registry.register("as-sparse")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class AsSparseFeatures(TabularBlock):
"""
Convert inputs to sparse tensors.
"""

def call(self, inputs: TabularData, **kwargs) -> TabularData:
outputs = {}
for name, val in inputs.items():
@@ -57,6 +61,14 @@ def compute_output_shape(self, input_shape):
@Block.registry.register("as-dense")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class AsDenseFeatures(TabularBlock):
"""Convert sparse inputs to dense tensors

Parameters
----------
max_seq_length : int, optional
The maximum length of multi-hot features.
"""

def __init__(self, max_seq_length: Optional[int] = None, **kwargs):
super().__init__(**kwargs)
self.max_seq_length = max_seq_length
@@ -99,6 +111,17 @@ def get_config(self):

@tf.keras.utils.register_keras_serializable(package="merlin.models")
class RenameFeatures(TabularBlock):
"""Rename input features

Parameters
----------
renames: dict
Mapping from current feature names (or tags) to new feature names.
schema: Schema, optional
The `Schema` with input features,
by default None
"""

def __init__(
self, renames: Dict[Union[str, Tags], str], schema: Optional[Schema] = None, **kwargs
):
@@ -308,6 +331,8 @@ def compute_output_shape(self, input_shape):
@Block.registry.register_with_multiple_names("l2-norm")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class L2Norm(TabularBlock):
"""Apply L2-normalization to input tensors along a given axis"""

def __init__(self, **kwargs):
super(L2Norm, self).__init__(**kwargs)

@@ -388,6 +413,15 @@ def compute_output_shape(self, input_shape):

@tf.keras.utils.register_keras_serializable(package="merlin_models")
class LogitsTemperatureScaler(Block):
"""Scale the logits higher or lower,
this is often used to reduce the overconfidence of the model.

Parameters
----------
temperature : float
The value by which the logits are divided.
"""

def __init__(self, temperature: float, **kwargs):
super(LogitsTemperatureScaler, self).__init__(**kwargs)
self.temperature = temperature
@@ -415,6 +449,23 @@ def compute_output_shape(self, input_shape):
@Block.registry.register_with_multiple_names("weight-tying")
@tf.keras.utils.register_keras_serializable(package="merlin_models")
class ItemsPredictionWeightTying(Block):
"""Tying the item embedding weights with the output projection layer matrix [1]
The output logits are obtained by multiplying the output vector by the item-ids embeddings.

Parameters
----------
schema : Schema
The `Schema` with the input features
bias_initializer : str, optional
Initializer to use on the bias vector, by default "zeros"

References:
-----------
[1] Inan, Hakan, et al.
"Tying word vectors and word classifiers: A loss framework for language modeling"
arXiv:1611.01462
"""

def __init__(self, schema: Schema, bias_initializer="zeros", **kwargs):
super(ItemsPredictionWeightTying, self).__init__(**kwargs)
self.bias_initializer = bias_initializer
@@ -484,6 +535,8 @@ def get_config(self):
@Block.registry.register_with_multiple_names("label_to_onehot")
@tf.keras.utils.register_keras_serializable(package="merlin_models")
class LabelToOneHot(Block):
"""Transform the categorical encoded labels into a one-hot representation"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

@@ -499,3 +552,58 @@ def call_outputs(
targets = transform_label_to_onehot(targets, num_classes)

return PredictionOutput(predictions, targets, outputs.positive_item_ids)


@tf.keras.utils.register_keras_serializable(package="merlin_models")
class CrossFeatures(TabularBlock):
"""Transformation performing crosses of categorical features
based on the provided keys.

There are some open questions in this experimental implementation:
#TODO : remove the dependency on the TensorFlow feature-column input API
#TODO : add the possibility of loading the cross keys from the schema

Parameters
----------
keys : List[List[str]]
List of lists of column names; each inner list identifies the features to be crossed together
hash_bucket_size : int
The number of buckets for feature crossing.
"""

def __init__(self, keys: List[List[str]], hash_bucket_size: int, **kwargs):
self.keys = keys
self.hash_bucket_size = hash_bucket_size
super().__init__(**kwargs)

def build(self, input_shape):
self.cross_layers = {}
for cross_keys in self.keys:
cross_column = tf.feature_column.crossed_column(
cross_keys, hash_bucket_size=self.hash_bucket_size
)
cross_layer = tf.keras.layers.DenseFeatures(
tf.feature_column.indicator_column(cross_column)
)
self.cross_layers["_".join(cross_keys)] = cross_layer

return super().build(input_shape)

def compute_output_shape(self, input_shape):
batch_size = self.calculate_batch_size_from_input_shapes(input_shape)
out_shapes = {}
for name in self.cross_layers.keys():
out_shapes[name] = tf.TensorShape((batch_size, self.hash_bucket_size))
return out_shapes

def call(self, inputs: TabularData, **kwargs) -> TabularData:
outputs = {}
for name, layer in self.cross_layers.items():
outputs[name] = layer(inputs)
return outputs

def get_config(self):
config = super().get_config()
config["keys"] = self.keys
config["hash_bucket_size"] = self.hash_bucket_size
return config
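Both LogitsTemperatureScaler and ItemsPredictionWeightTying described above boil down to simple tensor operations. A minimal, self-contained TensorFlow sketch of the math the docstrings describe (toy tensors only, not the blocks themselves):

import tensorflow as tf

# Temperature scaling: dividing logits by a temperature > 1 flattens the
# softmax distribution, which reduces the model's confidence.
logits = tf.constant([[4.0, 1.0, 0.5]])
temperature = 2.0
print(tf.nn.softmax(logits))                # sharper distribution
print(tf.nn.softmax(logits / temperature))  # flatter, less confident distribution

# Weight tying: the output logits are the dot products between the model's
# output vectors and every row of the item-id embedding table, plus a bias.
item_embeddings = tf.random.uniform((1000, 64))  # toy item embedding table
output_vectors = tf.random.uniform((8, 64))      # toy batch of output vectors
bias = tf.zeros((1000,))
tied_logits = tf.matmul(output_vectors, item_embeddings, transpose_b=True) + bias
print(tied_logits.shape)  # (8, 1000): one score per item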
6 changes: 4 additions & 2 deletions merlin/models/tf/blocks/retrieval/base.py
@@ -304,9 +304,11 @@ def call_outputs(
for sampler in self.samplers:
input_data = EmbeddingWithMetadata(batch_items_embeddings, batch_items_metadata)
if "item_weights" in sampler._call_fn_args:
neg_items = sampler(input_data.__dict__, item_weights=embedding_table)
neg_items = sampler(
input_data.__dict__, item_weights=embedding_table, training=True
)
else:
neg_items = sampler(input_data.__dict__)
neg_items = sampler(input_data.__dict__, training=True)

if tf.shape(neg_items.embeddings)[0] > 0:
# Accumulates sampled negative items from all samplers
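The change above passes training=True explicitly when the retrieval loss queries its negative samplers, since the samplers now default to training=False. A rough sketch of what the flag controls for an InBatchSampler, assuming the {"embeddings": ..., "metadata": ...} input layout used by the sampler code (toy tensors, not taken from this PR):

import tensorflow as tf
import merlin.models.tf as ml

batch = {
    "embeddings": tf.random.uniform((16, 8)),
    "metadata": {"item_id": tf.range(16, dtype=tf.int64)},
}

sampler = ml.InBatchSampler()
# With training=True (what the retrieval block now passes), the sampler
# caches the current batch and returns it as the pool of negatives.
negatives = sampler(batch, training=True)
print(negatives.embeddings.shape)  # (16, 8)

# With the new default training=False, a direct call no longer updates the
# cache; it only returns whatever batch was stored last.
same_negatives = sampler(batch, training=False)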
2 changes: 1 addition & 1 deletion merlin/models/tf/blocks/sampling/base.py
@@ -33,7 +33,7 @@ def __init__(
self.set_max_num_samples(max_num_samples)

@abc.abstractmethod
def add(self, embeddings: tf.Tensor, items_metadata: TabularData, training=True):
def add(self, embeddings: tf.Tensor, items_metadata: TabularData, training=False):
raise NotImplementedError()

@abc.abstractmethod
10 changes: 5 additions & 5 deletions merlin/models/tf/blocks/sampling/cross_batch.py
@@ -118,7 +118,7 @@ def _check_built(self) -> None:
"so that it is built before calling add() or sample() directly"
)

def call(self, inputs: TabularData, training=True) -> EmbeddingWithMetadata:
def call(self, inputs: TabularData, training=False) -> EmbeddingWithMetadata:
"""Adds the current batch to the FIFO queue cache and samples all items
embeddings from the last N cached batches.

@@ -154,7 +154,7 @@ def call(self, inputs: TabularData, training=True) -> EmbeddingWithMetadata:
def add( # type: ignore
self,
inputs: TabularData,
training: bool = True,
training: bool = False,
) -> None:
self._check_built()

@@ -244,7 +244,7 @@ def _check_inputs(self, inputs):
def add( # type: ignore
self,
inputs: TabularData,
training: bool = True,
training: bool = False,
) -> None:
"""Updates the FIFO queue with batch item embeddings (for items whose ids were
already added to the queue) and adds to the queue the items seen for the first time
@@ -419,11 +419,11 @@ def _check_inputs(self, inputs):
self.item_id_feature_name in inputs["metadata"]
), "The 'item_id' metadata feature is required by PopularityBasedSampler."

def add(self, embeddings: tf.Tensor, items_metadata: TabularData, training=True):
def add(self, embeddings: tf.Tensor, items_metadata: TabularData, training=False):
pass

def call(
self, inputs: TabularData, item_weights: tf.Tensor, training=True
self, inputs: TabularData, item_weights: tf.Tensor, training=False
) -> EmbeddingWithMetadata:
if training:
self._check_inputs(inputs)
4 changes: 2 additions & 2 deletions merlin/models/tf/blocks/sampling/in_batch.py
@@ -65,7 +65,7 @@ def build(self, input_shapes: TabularData) -> None:
if self._batch_size is None:
self.set_batch_size(input_shapes["embeddings"][0])

def call(self, inputs: TabularData, training=True) -> EmbeddingWithMetadata:
def call(self, inputs: TabularData, training=False) -> EmbeddingWithMetadata:
"""Returns the item embeddings and item metadata from
the current batch.
The implementation is very simple, as it just returns the current
@@ -96,7 +96,7 @@ def call(self, inputs: TabularData, training=True) -> EmbeddingWithMetadata:
items_embeddings = self.sample()
return items_embeddings

def add(self, inputs: TabularData, training=True) -> None: # type: ignore
def add(self, inputs: TabularData, training=False) -> None: # type: ignore
self._check_inputs_batch_sizes(inputs)
self._last_batch_items_embeddings = inputs["embeddings"]
self._last_batch_items_metadata = inputs["metadata"]
101 changes: 99 additions & 2 deletions merlin/models/tf/models/ranking.py
@@ -6,7 +6,7 @@
from merlin.models.tf.blocks.core.base import Block
from merlin.models.tf.blocks.core.combinators import ParallelBlock
from merlin.models.tf.blocks.core.inputs import InputBlock
from merlin.models.tf.blocks.core.transformations import CategoricalOneHot
from merlin.models.tf.blocks.core.transformations import CategoricalOneHot, CrossFeatures
from merlin.models.tf.blocks.cross import CrossBlock
from merlin.models.tf.blocks.dlrm import DLRMBlock
from merlin.models.tf.blocks.interaction import FMPairwiseInteraction
@@ -191,7 +191,7 @@ def DeepFMModel(
embedding_dim : int
Dimension of the embeddings
deep_block : Optional[Block]
The `Block` that learns high-order feature interactions
The `Block` that learns high-order feature interactions
Defaults to MLPBlock([64, 128])
prediction_tasks: optional
The prediction tasks to be used, by default this will be inferred from the Schema.
@@ -233,3 +233,100 @@ def DeepFMModel(
model = deep_fm.connect(prediction_tasks)

return model


def WideAndDeepModel(
schema: Schema,
embedding_dim: int,
keys: List[List[str]],
hash_bucket_size: Optional[int] = 1000,
deep_block: Optional[Block] = None,
prediction_tasks: Optional[
Union[PredictionTask, List[PredictionTask], ParallelPredictionBlock]
] = None,
embedding_option_kwargs: dict = {},
**kwargs
) -> Model:
"""Experimental implementation of Wide&Deep model [1];
If deep_block is not provided, the model is equivalent to
a linear model with cross-product interactions.

Example usage::
model = ml.WideAndDeepModel(
schema,
embedding_dim=64,
keys=[['item_category', 'item_intention'], ['country', 'gender', 'device']],
deep_block=ml.MLPBlock([64, 128])
)

#TODO : Implementation of a BlocksOptimizer class that accepts different optimizers
# for different subsets of the model’s blocks

References:
-----------
[1] Cheng, Heng-Tze, et al.
"Wide & Deep Learning for Recommender Systems" arXiv:1606.07792.

Parameters
----------
schema : Schema
The `Schema` with the input features
embedding_dim : int
The embedding dimension
keys : List[List[str]]
List of lists of column names; each inner list defines one cross-product transformation
hash_bucket_size : int, optional
Number of buckets used to hash the feature crosses, by default 1000
deep_block : Block, optional
The `Block` that learns high-order feature interactions, by default None
prediction_tasks : optional
The prediction tasks to be used; by default they are inferred from the Schema.
embedding_option_kwargs : dict, optional
Additional arguments to provide to `EmbeddingOptions` object
for embeddings tables setting, by default {}

Returns
-------
Model
Wide&Deep model class.

Raises
------
ValueError
If any of the cross-feature keys is not present in the schema
"""
if not all([key in schema.column_names for key_pair in keys for key in key_pair]):
raise ValueError("Make sure the cross features keys are present in the schema")
# wide block
cross_features = CrossFeatures(keys, hash_bucket_size)
branches = {
"categorical": CategoricalOneHot(schema),
"continuous": ContinuousFeatures.from_schema(schema),
}
base_features = ParallelBlock(branches, aggregation="concat")
wide_inputs = ParallelBlock(
{"cross": cross_features, "base": base_features}, aggregation="concat"
)
wide_block = wide_inputs.connect(tf.keras.layers.Dense(units=1, activation=None, use_bias=True))

if deep_block is not None:
deep_inputs = InputBlock(
schema,
embedding_options=EmbeddingOptions(
embedding_dim_default=embedding_dim, **embedding_option_kwargs
),
**kwargs
)
deep_block = deep_inputs.connect(deep_block)

wide_and_deep = ParallelBlock(
{"deep": deep_block, "wide": wide_block}, aggregation="concat"
)
else:
wide_and_deep = wide_block

prediction_tasks = parse_prediction_tasks(schema, prediction_tasks)
model = wide_and_deep.connect(prediction_tasks)

return model
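For context, a hedged end-to-end sketch of how the new WideAndDeepModel could be exercised, assuming the synthetic "e-commerce" dataset helper and the usual Keras compile/fit flow used elsewhere in the repository; the crossed columns are picked from the schema rather than hard-coded, since the synthetic column names are not shown in this diff:

import merlin.models.tf as ml
from merlin.datasets.synthetic import generate_data
from merlin.schema import Tags

train, valid = generate_data("e-commerce", 1000, set_sizes=(0.8, 0.2))
schema = train.schema

# Cross the first two categorical columns of the schema (placeholder choice).
cross_keys = schema.select_by_tag(Tags.CATEGORICAL).column_names[:2]

model = ml.WideAndDeepModel(
    schema,
    embedding_dim=32,
    keys=[cross_keys],
    hash_bucket_size=100,
    deep_block=ml.MLPBlock([64, 32]),
)
model.compile(optimizer="adam")
model.fit(train, batch_size=128, epochs=1)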
21 changes: 21 additions & 0 deletions tests/tf/blocks/core/test_transformations.py
@@ -177,3 +177,24 @@ def test_categorical_one_hot_encoding():
assert list(outputs["cat3"].shape) == [NUM_ROWS, MAX_LEN, 51]

assert inputs["cat1"][0].numpy() == tf.where(outputs["cat1"][0, :] == 1).numpy()[0]


def test_cross_product():
from merlin.models.tf.blocks.core.transformations import CrossFeatures

NUM_ROWS = 100

cardinalities = {"cat1": 21, "cat2": 11, "cat3": 6}
inputs = {}
for cat, cardinality in cardinalities.items():
inputs[cat] = tf.random.uniform((NUM_ROWS, 1), minval=1, maxval=cardinality, dtype=tf.int32)

n_buckets = 1000
cross_features = CrossFeatures(
keys=[["cat1", "cat2"], ["cat1", "cat2", "cat3"]], hash_bucket_size=n_buckets
)

outputs = cross_features(inputs)
for name, out in outputs.items():
assert out.shape[-1] == n_buckets
assert name in ["cat1_cat2", "cat1_cat2_cat3"]
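The test above only checks output names and shapes. Conceptually, the hashed cross built by CrossFeatures is close to joining the values of the crossed features into one string per row and hashing the result into hash_bucket_size buckets; a small standalone sketch of that idea (not the block's actual feature-column implementation):

import tensorflow as tf

cat1 = tf.constant([[3], [7], [3]])
cat2 = tf.constant([[1], [1], [5]])
n_buckets = 1000

# Join the values of the crossed features into one string per row ...
crossed_str = tf.strings.join(
    [tf.strings.as_string(cat1), tf.strings.as_string(cat2)], separator="_"
)
# ... hash the joined string into a fixed number of buckets ...
bucket_ids = tf.strings.to_hash_bucket_fast(crossed_str, n_buckets)
# ... and one-hot encode the bucket id, which is what the indicator column emits.
cross_one_hot = tf.one_hot(tf.squeeze(bucket_ids, axis=-1), depth=n_buckets)
print(cross_one_hot.shape)  # (3, 1000)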