From 593d78d184af5d68a67c7f0267efdfd48f8b51f4 Mon Sep 17 00:00:00 2001
From: Marc Romeyn
Date: Mon, 3 Jul 2023 14:04:06 +0200
Subject: [PATCH 1/2] Adding CrossBlock (used in DCN-v2) (#1172)

* Adding DCN-V2

* Merging CrossBlock & CrossLink

* Merging CrossBlock & CrossLink

* Adding with_low_rank to CrossBlock

---
 merlin/models/torch/blocks/cross.py   | 189 ++++++++++++++++++++++++++
 tests/unit/torch/blocks/test_cross.py | 100 ++++++++++++++
 2 files changed, 289 insertions(+)
 create mode 100644 merlin/models/torch/blocks/cross.py
 create mode 100644 tests/unit/torch/blocks/test_cross.py

diff --git a/merlin/models/torch/blocks/cross.py b/merlin/models/torch/blocks/cross.py
new file mode 100644
index 0000000000..ebfa215bcf
--- /dev/null
+++ b/merlin/models/torch/blocks/cross.py
@@ -0,0 +1,189 @@
+from copy import deepcopy
+from typing import Dict, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn.modules.lazy import LazyModuleMixin
+
+from merlin.models.torch.block import Block
+from merlin.models.torch.transforms.agg import Concat
+from merlin.models.utils.doc_utils import docstring_parameter
+
+_DCNV2_REF = """
+    References
+    ----------
+    .. [1] Wang, Ruoxi, et al. "DCN V2: Improved deep & cross network and
+       practical lessons for web-scale learning to rank systems." Proceedings
+       of the Web Conference 2021. 2021. https://arxiv.org/pdf/2008.13535.pdf
+
+"""
+
+
+class LazyMirrorLinear(LazyModuleMixin, nn.Linear):
+    """A :class:`torch.nn.Linear` module where both `in_features` and
+    `out_features` are inferred from the input
+    (i.e. `out_features` = `in_features`).
+
+    Parameters
+    ----------
+    bias : bool, default = True
+        If set to ``False``, the layer will not learn an additive bias.
+    """
+
+    cls_to_become = nn.Linear  # type: ignore[assignment]
+    weight: nn.parameter.UninitializedParameter
+    bias: nn.parameter.UninitializedParameter  # type: ignore[assignment]
+
+    def __init__(self, bias: bool = True, device=None, dtype=None) -> None:
+        # This code is taken from torch.nn.LazyLinear.__init__
+        factory_kwargs = {"device": device, "dtype": dtype}
+        # bias is hardcoded to False to avoid creating a tensor
+        # that will soon be overwritten.
+        super().__init__(0, 0, False)
+        self.weight = nn.parameter.UninitializedParameter(**factory_kwargs)
+        if bias:
+            self.bias = nn.parameter.UninitializedParameter(**factory_kwargs)
+
+    def reset_parameters(self) -> None:
+        if not self.has_uninitialized_params() and self.in_features != 0:
+            super().reset_parameters()
+
+    def initialize_parameters(self, input) -> None:  # type: ignore[override]
+        if self.has_uninitialized_params():
+            with torch.no_grad():
+                self.in_features = input.shape[-1]
+                if not hasattr(self, "out_features") or self.out_features == 0:
+                    self.out_features = self.in_features
+                self.weight.materialize((self.out_features, self.in_features))
+                if self.bias is not None:
+                    self.bias.materialize((self.out_features,))
+                self.reset_parameters()
+
+
+@docstring_parameter(dcn_reference=_DCNV2_REF)
+class CrossBlock(Block):
+    """
+    This block creates high-order feature interactions through a stack of
+    cross layers, from DCN V2: Improved Deep & Cross Network [1].
+    See Eq. (1) in [1] for the full-rank version and Eq. (2) for the
+    low-rank version.
+
+    Parameters
+    ----------
+    *module : nn.Module
+        Variable length argument list of PyTorch modules to be contained in the block.
+    name : Optional[str], default = None
+        The name of the block. If None, no name is assigned.
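+
+    Examples
+    --------
+    A minimal usage sketch; the input size (16) is illustrative only, since
+    sizes are inferred from the data on the first forward pass:
+
+    >>> block = CrossBlock.with_depth(depth=2)
+    >>> outputs = block(torch.randn(8, 16))
+    >>> outputs.shape
+    torch.Size([8, 16])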
+
+    {dcn_reference}
+    """
+
+    def __init__(self, *module, name: Optional[str] = None):
+        super().__init__(*module, name=name)
+        self.concat = Concat()
+        self.init_hook_handle = self.register_forward_pre_hook(self.initialize)
+
+    @classmethod
+    def with_depth(cls, depth: int) -> "CrossBlock":
+        """Creates a CrossBlock with a given depth.
+
+        Parameters
+        ----------
+        depth : int
+            Depth of the CrossBlock.
+
+        Returns
+        -------
+        CrossBlock
+            A CrossBlock of the given depth.
+
+        Raises
+        ------
+        ValueError
+            If depth is less than or equal to 0.
+        """
+        if depth <= 0:
+            raise ValueError(f"`depth` must be greater than 0, got {depth}")
+
+        return cls(*Block(LazyMirrorLinear()).repeat(depth))
+
+    @classmethod
+    def with_low_rank(cls, depth: int, low_rank: nn.Module) -> "CrossBlock":
+        """
+        Creates a CrossBlock of a given depth with low-rank cross layers.
+        See Eq. (2) in [1].
+
+        Parameters
+        ----------
+        depth : int
+            Depth of the CrossBlock.
+        low_rank : nn.Module
+            Low-rank module that is deep-copied and prepended to each cross layer.
+
+        Returns
+        -------
+        CrossBlock
+            A low-rank CrossBlock of the given depth.
+        """
+
+        return cls(*(Block(deepcopy(low_rank), *block) for block in cls.with_depth(depth)))
+
+    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
+        """Forward pass of the CrossBlock.
+
+        Parameters
+        ----------
+        inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
+            The input data: either a single tensor or a dictionary of tensors,
+            which is concatenated first.
+
+        Returns
+        -------
+        torch.Tensor
+            The output tensor after the forward pass.
+
+        Raises
+        ------
+        RuntimeError
+            If the output from a module is not a Tensor.
+        """
+
+        if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
+            x = self.concat(inputs)
+        else:
+            x = inputs
+
+        x0 = x
+        current = x
+        for module in self.values:
+            module_out = module(current)
+            if not isinstance(module_out, torch.Tensor):
+                raise RuntimeError("CrossBlock expects a Tensor as output")
+
+            current = x0 * module_out + current
+
+        return current
+
+    def initialize(self, module, inputs):
+        """
+        Initialize the block by setting the output features of all LazyMirrorLinear children.
+
+        This is meant to be used as a forward pre-hook.
+
+        Parameters
+        ----------
+        module : nn.Module
+            The module to initialize.
+        inputs : tuple
+            The inputs to the forward method.
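+
+        Notes
+        -----
+        The hook is expected to fire only once: after the first forward pass
+        it removes itself via ``self.init_hook_handle.remove()``.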
+ """ + + if torch.jit.isinstance(inputs[0], Dict[str, torch.Tensor]): + _inputs = self.concat(inputs[0]) + else: + _inputs = inputs[0] + + def set_out_features_lazy_mirror_linear(m): + if isinstance(m, LazyMirrorLinear): + m.out_features = _inputs.shape[-1] + + self.apply(set_out_features_lazy_mirror_linear) + self.init_hook_handle.remove() # Clear hook once block is initialized diff --git a/tests/unit/torch/blocks/test_cross.py b/tests/unit/torch/blocks/test_cross.py new file mode 100644 index 0000000000..2fb2f3d1dc --- /dev/null +++ b/tests/unit/torch/blocks/test_cross.py @@ -0,0 +1,100 @@ +from typing import Tuple + +import pytest +import torch +from torch import nn + +import merlin.models.torch as mm +from merlin.models.torch.blocks.cross import CrossBlock, LazyMirrorLinear +from merlin.models.torch.utils import module_utils + + +class TestLazyMirrorLinear: + def test_init(self): + module = LazyMirrorLinear(bias=True) + assert isinstance(module.weight, nn.parameter.UninitializedParameter) + assert isinstance(module.bias, nn.parameter.UninitializedParameter) + + def test_no_bias_init(self): + module = LazyMirrorLinear(bias=False) + assert isinstance(module.weight, nn.parameter.UninitializedParameter) + assert module.bias is None + + def test_reset_parameters(self): + module = LazyMirrorLinear(bias=True) + input = torch.randn(10, 20) + module.initialize_parameters(input) + assert module.in_features == 20 + assert module.out_features == 20 + assert module.weight.shape == (20, 20) + assert module.bias.shape == (20,) + + def test_forward(self): + module = LazyMirrorLinear(bias=True) + input = torch.randn(10, 20) + output = module_utils.module_test(module, input) + assert output.shape == (10, 20) + + def test_no_bias_forward(self): + module = LazyMirrorLinear(bias=False) + input = torch.randn(10, 20) + output = module_utils.module_test(module, input) + assert output.shape == (10, 20) + + +class TestCrossBlock: + def test_with_depth(self): + crossblock = CrossBlock.with_depth(depth=1) + assert len(crossblock) == 1 + assert isinstance(crossblock[0][0], LazyMirrorLinear) + + def test_with_multiple_depth(self): + crossblock = CrossBlock.with_depth(depth=3) + assert len(crossblock) == 3 + for module in crossblock: + assert isinstance(module[0], LazyMirrorLinear) + + def test_crossblock_invalid_depth(self): + with pytest.raises(ValueError): + CrossBlock.with_depth(depth=0) + + def test_forward_tensor(self): + crossblock = CrossBlock.with_depth(depth=1) + input = torch.randn(5, 10) + output = module_utils.module_test(crossblock, input) + assert output.shape == (5, 10) + + def test_forward_dict(self): + crossblock = CrossBlock.with_depth(depth=1) + inputs = {"a": torch.randn(5, 10), "b": torch.randn(5, 10)} + output = module_utils.module_test(crossblock, inputs) + assert output.shape == (5, 20) + + def test_forward_multiple_depth(self): + crossblock = CrossBlock.with_depth(depth=3) + input = torch.randn(5, 10) + output = module_utils.module_test(crossblock, input) + assert output.shape == (5, 10) + + def test_with_low_rank(self): + crossblock = CrossBlock.with_low_rank(depth=2, low_rank=mm.MLPBlock([5])) + assert len(crossblock) == 2 + + input = torch.randn(5, 10) + output = module_utils.module_test(crossblock, input) + assert output.shape == (5, 10) + + assert crossblock[0][0][1].in_features == 10 + assert crossblock[0][0][1].out_features == 5 + assert crossblock[0][1].in_features == 5 + assert crossblock[0][1].out_features == 10 + + def test_exception(self): + class ToTuple(nn.Module): + def 
+                return input, input
+
+        crossblock = CrossBlock(ToTuple())
+
+        with pytest.raises(RuntimeError):
+            module_utils.module_test(crossblock, torch.randn(5, 10))

From 1782b44db351e1c6d0df085826a4bdd10d531f71 Mon Sep 17 00:00:00 2001
From: edknv <109497216+edknv@users.noreply.github.com>
Date: Mon, 3 Jul 2023 22:59:24 +0900
Subject: [PATCH 2/2] Add DLRM Model (#1171)

* Add DLRM Model

* make model a class rather than a function

---------

Co-authored-by: Marc Romeyn
---
 merlin/models/torch/__init__.py           |  2 +
 merlin/models/torch/models/base.py        |  5 ++
 merlin/models/torch/models/ranking.py     | 76 +++++++++++++++++++++++
 merlin/models/torch/utils/module_utils.py |  2 +-
 tests/unit/torch/models/test_base.py      | 36 ++++++-----
 tests/unit/torch/models/test_ranking.py   | 38 ++++++++++++
 6 files changed, 144 insertions(+), 15 deletions(-)
 create mode 100644 merlin/models/torch/models/ranking.py
 create mode 100644 tests/unit/torch/models/test_ranking.py

diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py
index 988897ef44..025c8ba0dc 100644
--- a/merlin/models/torch/__init__.py
+++ b/merlin/models/torch/__init__.py
@@ -23,6 +23,7 @@
 from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
 from merlin.models.torch.inputs.tabular import TabularInputBlock
 from merlin.models.torch.models.base import Model
+from merlin.models.torch.models.ranking import DLRMModel
 from merlin.models.torch.outputs.base import ModelOutput
 from merlin.models.torch.outputs.classification import BinaryOutput
 from merlin.models.torch.outputs.regression import RegressionOutput
@@ -55,4 +56,5 @@
     "Stack",
     "schema",
     "DLRMBlock",
+    "DLRMModel",
 ]
diff --git a/merlin/models/torch/models/base.py b/merlin/models/torch/models/base.py
index 56851d285a..df1826746c 100644
--- a/merlin/models/torch/models/base.py
+++ b/merlin/models/torch/models/base.py
@@ -196,6 +196,11 @@ def compute_loss(
         else:
             raise ValueError(f"Unknown 'predictions' type: {type(predictions)}")

+        if _targets.size() != _predictions.size():
+            _targets = _targets.view(_predictions.size())
+        if _targets.type() != _predictions.type():
+            _targets = _targets.type_as(_predictions)
+
         results["loss"] = results["loss"] + model_out.loss(_predictions, _targets) / len(
             model_outputs
         )
diff --git a/merlin/models/torch/models/ranking.py b/merlin/models/torch/models/ranking.py
new file mode 100644
index 0000000000..292abebbd8
--- /dev/null
+++ b/merlin/models/torch/models/ranking.py
@@ -0,0 +1,76 @@
+from typing import Optional
+
+from torch import nn
+
+from merlin.models.torch.block import Block
+from merlin.models.torch.blocks.dlrm import DLRMBlock
+from merlin.models.torch.models.base import Model
+from merlin.models.torch.outputs.tabular import TabularOutputBlock
+from merlin.schema import Schema
+
+
+class DLRMModel(Model):
+    """
+    The Deep Learning Recommendation Model (DLRM) as proposed in Naumov et al. [1].
+
+    Parameters
+    ----------
+    schema : Schema
+        The schema to use for selection.
+    dim : int
+        The dimensionality of the output vectors.
+    bottom_block : Block
+        Block to pass the continuous features to.
+        Note that the output dimensionality of this block must be equal to ``dim``.
+    top_block : Block, optional
+        An optional upper-level block of the model.
+    interaction : nn.Module, optional
+        Interaction module for DLRM.
+        If not provided, DLRMInteraction will be used by default.
+    output_block : Block, optional
+        The output block of the model, by default None.
+        If None, a TabularOutputBlock with schema and default initializations is used.
+
+    Returns
+    -------
+    Model
+        An instance of Model class representing the fully formed DLRM.
+
+    Example usage
+    -------------
+    >>> model = mm.DLRMModel(
+    ...     schema,
+    ...     dim=64,
+    ...     bottom_block=mm.MLPBlock([256, 64]),
+    ...     output_block=BinaryOutput(ColumnSchema("target")))
+    >>> trainer = pl.Trainer()
+    >>> model.initialize(dataloader)
+    >>> trainer.fit(model, dataloader)
+
+    References
+    ----------
+    [1] Naumov, Maxim, et al. "Deep learning recommendation model for
+        personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
+    """
+
+    def __init__(
+        self,
+        schema: Schema,
+        dim: int,
+        bottom_block: Block,
+        top_block: Optional[Block] = None,
+        interaction: Optional[nn.Module] = None,
+        output_block: Optional[Block] = None,
+    ) -> None:
+        if output_block is None:
+            output_block = TabularOutputBlock(schema, init="defaults")
+
+        dlrm_body = DLRMBlock(
+            schema,
+            dim,
+            bottom_block,
+            top_block=top_block,
+            interaction=interaction,
+        )
+
+        super().__init__(dlrm_body, output_block)
diff --git a/merlin/models/torch/utils/module_utils.py b/merlin/models/torch/utils/module_utils.py
index d80cc649ba..4dcc965c51 100644
--- a/merlin/models/torch/utils/module_utils.py
+++ b/merlin/models/torch/utils/module_utils.py
@@ -236,7 +236,7 @@ def initialize(module, data: Union[Dataset, Loader, Batch], dtype=torch.float32)
     if hasattr(module, "model_outputs"):
         for model_out in module.model_outputs():
             for metric in model_out.metrics:
-                metric.to(batch.device())
+                metric.to(device=batch.device())

    from merlin.models.torch import schema
diff --git a/tests/unit/torch/models/test_base.py b/tests/unit/torch/models/test_base.py
index ab329b8ca1..d51589b8f1 100644
--- a/tests/unit/torch/models/test_base.py
+++ b/tests/unit/torch/models/test_base.py
@@ -15,6 +15,7 @@
 #
 import pandas as pd
 import pytest
+import pytorch_lightning as pl
 import torch
 from torch import nn
 from torchmetrics import AUROC, Accuracy, Precision, Recall
@@ -22,7 +23,7 @@
 import merlin.models.torch as mm
 from merlin.dataloader.torch import Loader
 from merlin.io import Dataset
-from merlin.models.torch.batch import Batch
+from merlin.models.torch.batch import Batch, sample_batch
 from merlin.models.torch.models.base import compute_loss
 from merlin.models.torch.utils import module_utils
 from merlin.schema import ColumnSchema
@@ -200,22 +201,29 @@
     def test_no_output_schema(self):
         with pytest.raises(ValueError, match="Could not get output schema of PlusOne()"):
             mm.schema.output(model)

-    # def test_train_classification(self, music_streaming_data):
-    #     schema = music_streaming_data.schema.without(["user_genres", "like", "item_genres"])
-    #     music_streaming_data.schema = schema
+    def test_train_classification_with_lightning_trainer(self, music_streaming_data, batch_size=16):
+        schema = music_streaming_data.schema.select_by_name(
+            ["item_id", "user_id", "user_age", "item_genres", "click"]
+        )
+        music_streaming_data.schema = schema

-    #     model = mm.Model(
-    #         mm.TabularInputBlock(schema),
-    #         mm.MLPBlock([4, 2]),
-    #         mm.BinaryOutput(schema.select_by_name("click").first),
-    #         schema=schema,
-    #     )
+        model = mm.Model(
+            mm.TabularInputBlock(schema, init="defaults"),
+            mm.MLPBlock([4, 2]),
+            mm.BinaryOutput(schema.select_by_name("click").first),
+        )
+
+        trainer = pl.Trainer(max_epochs=1, devices=1)
+
+        with Loader(music_streaming_data, batch_size=batch_size) as loader:
+            model.initialize(loader)
+            trainer.fit(model, loader)

-    #     trainer = pl.Trainer(max_epochs=1)

+        assert trainer.logged_metrics["train_loss"] > 0.0
+        assert trainer.num_training_batches == 7  # 100 rows // 16 per batch + 1 for last batch

-    #     with Loader(music_streaming_data, batch_size=16) as loader:
-    #         model.initialize(loader)
-    #         trainer.fit(model, loader)
+
+        batch = sample_batch(music_streaming_data, batch_size)
+        _ = module_utils.module_test(model, batch)


 class TestComputeLoss:
diff --git a/tests/unit/torch/models/test_ranking.py b/tests/unit/torch/models/test_ranking.py
new file mode 100644
index 0000000000..0fb463e0ef
--- /dev/null
+++ b/tests/unit/torch/models/test_ranking.py
@@ -0,0 +1,38 @@
+import pytest
+import pytorch_lightning as pl
+
+import merlin.models.torch as mm
+from merlin.dataloader.torch import Loader
+from merlin.models.torch.batch import sample_batch
+from merlin.models.torch.utils import module_utils
+from merlin.schema import ColumnSchema
+
+
+@pytest.mark.parametrize("output_block", [None, mm.BinaryOutput(ColumnSchema("click"))])
+class TestDLRMModel:
+    def test_train_dlrm_with_lightning_loader(
+        self, music_streaming_data, output_block, dim=2, batch_size=16
+    ):
+        schema = music_streaming_data.schema.select_by_name(
+            ["item_id", "user_id", "user_age", "item_genres", "click"]
+        )
+        music_streaming_data.schema = schema
+
+        model = mm.DLRMModel(
+            schema,
+            dim=dim,
+            bottom_block=mm.MLPBlock([4, 2]),
+            top_block=mm.MLPBlock([4, 2]),
+            output_block=output_block,
+        )
+
+        trainer = pl.Trainer(max_epochs=1, devices=1)
+
+        with Loader(music_streaming_data, batch_size=batch_size) as train_loader:
+            model.initialize(train_loader)
+            trainer.fit(model, train_loader)
+
+        assert trainer.logged_metrics["train_loss"] > 0.0
+
+        batch = sample_batch(music_streaming_data, batch_size)
+        _ = module_utils.module_test(model, batch)