From 7efc68d19523e440150bf6d685a5ba0b6cf84aeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 10 Dec 2024 20:07:38 +0800 Subject: [PATCH] add zero collision hash embedding --- requirements/runtime.txt | 4 +- tzrec/datasets/data_parser.py | 1 + tzrec/features/combo_feature.py | 10 +- tzrec/features/feature.py | 49 +++ tzrec/features/id_feature.py | 12 +- tzrec/features/id_feature_test.py | 40 ++ tzrec/features/lookup_feature.py | 7 +- tzrec/features/match_feature.py | 7 +- tzrec/features/sequence_feature.py | 5 +- tzrec/main.py | 7 +- tzrec/modules/embedding.py | 262 +++++++++++-- tzrec/modules/embedding_test.py | 169 ++++++--- tzrec/protos/feature.proto | 55 +++ .../multi_tower_din_zch_fg_mock.config | 358 ++++++++++++++++++ tzrec/tests/match_integration_test.py | 256 +++++++++++++ ...xport_test.py => rank_integration_test.py} | 300 ++------------- tzrec/utils/fx_util.py | 45 +++ tzrec/utils/test_util.py | 2 +- tzrec/version.py | 2 +- 19 files changed, 1228 insertions(+), 363 deletions(-) create mode 100644 tzrec/tests/configs/multi_tower_din_zch_fg_mock.config create mode 100644 tzrec/tests/match_integration_test.py rename tzrec/tests/{train_eval_export_test.py => rank_integration_test.py} (72%) create mode 100644 tzrec/utils/fx_util.py diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 4739af5..93a7298 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -7,8 +7,8 @@ graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1. graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.1-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" grpcio-tools<1.63.0 pandas -pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" -pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" +pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" +pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" pyodps>=0.12.0 scikit-learn tensorboard diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py index b7c38f9..2d1267d 100644 --- a/tzrec/datasets/data_parser.py +++ b/tzrec/datasets/data_parser.py @@ -119,6 +119,7 @@ def _init_fg_hander(self) -> None: if not self._fg_handler: fg_json = create_fg_json(self._features) # pyre-ignore [16] + print(fg_json) self._fg_handler = pyfg.FgArrowHandler(fg_json, self._fg_threads) def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: diff --git a/tzrec/features/combo_feature.py b/tzrec/features/combo_feature.py index 1664f28..f5cd55b 100644 --- a/tzrec/features/combo_feature.py +++ b/tzrec/features/combo_feature.py @@ -20,7 +20,11 @@ ParsedData, SparseData, ) -from tzrec.features.feature import FgMode, _parse_fg_encoded_sparse_feature_impl +from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, + FgMode, + _parse_fg_encoded_sparse_feature_impl, +) from tzrec.features.id_feature import IdFeature from tzrec.protos.feature_pb2 import FeatureConfig from tzrec.utils.logging_util import logger @@ -116,7 +120,9 @@ def fg_json(self) -> List[Dict[str, Any]]: } if self.config.separator != "\x1d": fg_cfg["separator"] = self.config.separator - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size elif len(self.config.vocab_list) > 0: fg_cfg["vocab_list"] = [self.config.default_value, ""] + list( diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 1268d89..5c10617 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -20,12 +20,23 @@ import numpy as np import pyarrow as pa import pyfg +import torch from torch import nn # NOQA from torchrec.modules.embedding_configs import ( EmbeddingBagConfig, EmbeddingConfig, PoolingType, ) +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + LFU_EvictionPolicy, + LRU_EvictionPolicy, + ManagedCollisionModule, + MCHManagedCollisionModule, + average_threshold_filter, # NOQA + dynamic_threshold_filter, # NOQA + probabilistic_threshold_filter, # NOQA +) from tzrec.datasets.utils import ( BASE_DATA_GROUP, @@ -52,6 +63,9 @@ class FgMode(Enum): DAG = 3 +MAX_HASH_BUCKET_SIZE = 2**31 - 1 + + def _parse_fg_encoded_sparse_feature_impl( name: str, feat: pa.Array, @@ -386,6 +400,41 @@ def emb_config(self) -> Optional[EmbeddingConfig]: else: return None + def mc_module(self, device: torch.device) -> Optional[ManagedCollisionModule]: + """Get ManagedCollisionModule.""" + if self.is_sparse: + if hasattr(self.config, "zch") and self.config.HasField("zch"): + evict_type = self.config.zch.WhichOneof("eviction_policy") + evict_config = getattr(self.config.zch, evict_type) + threshold_filtering_func = None + if evict_config.HasField("threshold_filtering_func"): + threshold_filtering_func = eval( + evict_config.threshold_filtering_func + ) + if evict_type == "lfu": + eviction_policy = LFU_EvictionPolicy( + threshold_filtering_func=threshold_filtering_func + ) + elif evict_type == "lru": + eviction_policy = LRU_EvictionPolicy( + decay_exponent=evict_config.decay_exponent, + threshold_filtering_func=threshold_filtering_func, + ) + elif evict_type == "distance_lfu": + eviction_policy = DistanceLFU_EvictionPolicy( + decay_exponent=evict_config.decay_exponent, + threshold_filtering_func=threshold_filtering_func, + ) + else: + raise ValueError("Unknown evict policy type: {evict_type}") + return MCHManagedCollisionModule( + zch_size=self.config.zch.zch_size, + device=device, + eviction_interval=self.config.zch.eviction_interval, + eviction_policy=eviction_policy, + ) + return None + @property def inputs(self) -> List[str]: """Input field names.""" diff --git a/tzrec/features/id_feature.py b/tzrec/features/id_feature.py index 13891a9..5102dd5 100644 --- a/tzrec/features/id_feature.py +++ b/tzrec/features/id_feature.py @@ -20,6 +20,7 @@ SparseData, ) from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, BaseFeature, FgMode, _parse_fg_encoded_sparse_feature_impl, @@ -71,7 +72,9 @@ def is_sparse(self) -> bool: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): num_embeddings = self.config.num_buckets @@ -89,7 +92,7 @@ def num_embeddings(self) -> int: else: raise ValueError( f"{self.__class__.__name__}[{self.name}] must set hash_bucket_size" - " or num_buckets or vocab_list or vocab_dict" + " or num_buckets or vocab_list or vocab_dict or zch.zch_size" ) return num_embeddings @@ -156,7 +159,10 @@ def fg_json(self) -> List[Dict[str, Any]]: } if self.config.separator != "\x1d": fg_cfg["separator"] = self.config.separator - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + fg_cfg["value_type"] = "string" + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size fg_cfg["value_type"] = "string" elif len(self.config.vocab_list) > 0: diff --git a/tzrec/features/id_feature_test.py b/tzrec/features/id_feature_test.py index e35025e..1c12d69 100644 --- a/tzrec/features/id_feature_test.py +++ b/tzrec/features/id_feature_test.py @@ -15,6 +15,7 @@ import numpy as np import pyarrow as pa +import torch from parameterized import parameterized from torch import nn from torchrec.modules.embedding_configs import ( @@ -88,6 +89,45 @@ def test_init_fn_id_feature(self): ) self.assertEqual(repr(id_feat.emb_config), repr(expected_emb_config)) + def test_zch_id_feature(self): + id_feat_cfg = feature_pb2.FeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="id_feat", + embedding_dim=16, + zch=feature_pb2.ZeroCollisionHash( + zch_size=100, + eviction_interval=5, + distance_lfu=feature_pb2.DistanceLFU_EvictionPolicy( + decay_exponent=1.0, + threshold_filtering_func="lambda x:" + " probabilistic_threshold_filter(x,0.05)", + ), + ), + ) + ) + id_feat = id_feature_lib.IdFeature(id_feat_cfg) + expected_emb_bag_config = EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=16, + name="id_feat_emb", + feature_names=["id_feat"], + pooling=PoolingType.SUM, + ) + self.assertEqual(repr(id_feat.emb_bag_config), repr(expected_emb_bag_config)) + expected_emb_config = EmbeddingConfig( + num_embeddings=100, + embedding_dim=16, + name="id_feat_emb", + feature_names=["id_feat"], + ) + self.assertEqual(repr(id_feat.emb_config), repr(expected_emb_config)) + mc_module = id_feat.mc_module(torch.device("meta")) + self.assertEqual(mc_module._zch_size, 100) + self.assertEqual(mc_module._eviction_interval, 5) + self.assertTrue( + mc_module._eviction_policy._threshold_filtering_func is not None + ) + def test_fg_encoded_with_weighted(self): id_feat_cfg = feature_pb2.FeatureConfig( id_feature=feature_pb2.IdFeature( diff --git a/tzrec/features/lookup_feature.py b/tzrec/features/lookup_feature.py index a24bf63..2899c4a 100644 --- a/tzrec/features/lookup_feature.py +++ b/tzrec/features/lookup_feature.py @@ -23,6 +23,7 @@ SparseData, ) from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, BaseFeature, FgMode, _parse_fg_encoded_dense_feature_impl, @@ -201,7 +202,11 @@ def fg_json(self) -> List[Dict[str, Any]]: fg_cfg["separator"] = self.config.separator if self.config.HasField("normalizer"): fg_cfg["normalizer"] = self.config.normalizer - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + fg_cfg["value_type"] = "string" + fg_cfg["needDiscrete"] = True + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size fg_cfg["value_type"] = "string" fg_cfg["needDiscrete"] = True diff --git a/tzrec/features/match_feature.py b/tzrec/features/match_feature.py index 13de909..fcd8136 100644 --- a/tzrec/features/match_feature.py +++ b/tzrec/features/match_feature.py @@ -23,6 +23,7 @@ SparseData, ) from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, BaseFeature, FgMode, _parse_fg_encoded_dense_feature_impl, @@ -174,7 +175,11 @@ def fg_json(self) -> List[Dict[str, Any]]: fg_cfg["separator"] = self.config.separator if self.config.HasField("normalizer"): fg_cfg["normalizer"] = self.config.normalizer - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + fg_cfg["value_type"] = "string" + fg_cfg["needDiscrete"] = True + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size fg_cfg["value_type"] = "string" fg_cfg["needDiscrete"] = True diff --git a/tzrec/features/sequence_feature.py b/tzrec/features/sequence_feature.py index ee1734e..fa7c9c6 100644 --- a/tzrec/features/sequence_feature.py +++ b/tzrec/features/sequence_feature.py @@ -22,6 +22,7 @@ SequenceDenseData, SequenceSparseData, ) +from tzrec.features.feature import MAX_HASH_BUCKET_SIZE from tzrec.features.id_feature import FgMode, IdFeature from tzrec.features.raw_feature import RawFeature from tzrec.protos import feature_pb2 @@ -303,7 +304,9 @@ def fg_json(self) -> List[Dict[str, Any]]: fg_cfg["sequence_length"] = self.config.sequence_length if self.config.separator != "\x1d": fg_cfg["separator"] = self.config.separator - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): fg_cfg["num_buckets"] = self.config.num_buckets diff --git a/tzrec/main.py b/tzrec/main.py index 6d68374..962f09a 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -35,7 +35,6 @@ # NOQA from torchrec.distributed.train_pipeline import TrainPipelineSparseDist -from torchrec.fx import symbolic_trace from torchrec.inference.modules import quantize_embeddings from torchrec.inference.state_dict_transform import ( state_dict_gather, @@ -85,6 +84,7 @@ from tzrec.protos.pipeline_pb2 import EasyRecConfig from tzrec.protos.train_pb2 import TrainConfig from tzrec.utils import checkpoint_util, config_util +from tzrec.utils.fx_util import symbolic_trace from tzrec.utils.logging_util import ProgressLogger, logger from tzrec.utils.plan_util import create_planner, get_default_sharders from tzrec.version import __version__ as tzrec_version @@ -751,7 +751,7 @@ def _script_model( model.eval() if is_trt_convert: - data_cuda = batch.to_dict(sparse_dtype=torch.int32) + data_cuda = batch.to_dict(sparse_dtype=torch.int64) result = model(data_cuda, "cuda:0") result_info = {k: (v.size(), v.dtype) for k, v in result.items()} logger.info(f"Model Outputs: {result_info}") @@ -1008,7 +1008,6 @@ def predict( device_and_backend = init_process_group() device: torch.device = device_and_backend[0] - sparse_dtype: torch.dtype = torch.int32 if device.type == "cuda" else torch.int64 is_rank_zero = int(os.environ.get("RANK", 0)) == 0 is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0 @@ -1077,7 +1076,7 @@ def predict( def _forward(batch: Batch) -> Tuple[Dict[str, torch.Tensor], RecordBatchTensor]: with torch.no_grad(): - parsed_inputs = batch.to_dict(sparse_dtype=sparse_dtype) + parsed_inputs = batch.to_dict(sparse_dtype=torch.int64) # when predicting with a model exported using INPUT_TILE, # we set the batch size tensor to 1 to disable tiling. parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64) diff --git a/tzrec/modules/embedding.py b/tzrec/modules/embedding.py index 6f4f7cc..deb71e0 100644 --- a/tzrec/modules/embedding.py +++ b/tzrec/modules/embedding.py @@ -19,6 +19,15 @@ EmbeddingBagCollection, EmbeddingCollection, ) +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) +from torchrec.modules.mc_modules import ( + ManagedCollisionCollection, + ManagedCollisionModule, + MCHManagedCollisionModule, +) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor from tzrec.acc.utils import is_input_tile, is_input_tile_emb @@ -464,17 +473,14 @@ def predict( return values_list -def add_embedding_bag_config( +def _add_embedding_bag_config( emb_bag_configs: Dict[str, EmbeddingBagConfig], emb_bag_config: EmbeddingBagConfig ) -> None: - """Add embedding bag config . + """Add embedding bag config to a dict of embedding bag config. Args: - emb_bag_configs: Dict[str, EmbeddingBagConfig]: a dict contains emb_bag_configs - emb_bag_config: EmbeddingBagConfig - - Returns: - None + emb_bag_configs(Dict[str, EmbeddingBagConfig]): a dict contains emb_bag_configs + emb_bag_config(EmbeddingBagConfig): an instance of EmbeddingBagConfig """ if emb_bag_config.name in emb_bag_configs: existed_emb_bag_config = emb_bag_configs[emb_bag_config.name] @@ -494,17 +500,14 @@ def add_embedding_bag_config( emb_bag_configs[emb_bag_config.name] = emb_bag_config -def add_embedding_config( +def _add_embedding_config( emb_configs: Dict[str, EmbeddingConfig], emb_config: EmbeddingConfig ) -> None: - """Add embedding config . + """Add embedding config to a dict of embedding config. Args: - emb_configs: Dict[str, EmbeddingConfig]: a dict contains emb_configs - emb_config: EmbeddingConfig - - Returns: - None + emb_configs(Dict[str, EmbeddingConfig]): a dict contains emb_configs + emb_config(EmbeddingConfig): an instance of EmbeddingConfig """ if emb_config.name in emb_configs: existed_emb_config = emb_configs[emb_config.name] @@ -523,6 +526,28 @@ def add_embedding_config( emb_configs[emb_config.name] = emb_config +def _add_mc_module( + mc_modules: Dict[str, ManagedCollisionModule], + emb_name: str, + mc_module: ManagedCollisionModule, +): + """Add ManagedCollisionModule to a dict of ManagedCollisionModule. + + Args: + mc_modules(Dict[str, ManagedCollisionModule]): a dict of ManagedCollisionModule. + emb_name(str): embedding_name. + mc_module(ManagedCollisionModule): an instance of ManagedCollisionModule. + """ + if emb_name in mc_modules: + existed_mc_module = mc_modules[emb_name] + if isinstance(mc_module, MCHManagedCollisionModule): + assert isinstance(existed_mc_module, MCHManagedCollisionModule) + assert mc_module._zch_size == existed_mc_module._zch_size + assert mc_module._eviction_interval == existed_mc_module._eviction_interval + assert repr(mc_module._eviction_policy) == repr(mc_module._eviction_policy) + mc_modules[emb_name] = mc_module + + class EmbeddingGroupImpl(nn.Module): """Applies embedding lookup transformation for feature group. @@ -545,9 +570,13 @@ def __init__( device = torch.device("meta") name_to_feature = {x.name: x for x in features} emb_bag_configs = OrderedDict() + mc_emb_bag_configs = OrderedDict() + mc_modules = OrderedDict() self.has_sparse = False self.has_sparse_user = False + self.has_mc_sparse = False + self.has_mc_sparse_user = False self.has_dense = False self.has_dense_user = False @@ -580,6 +609,10 @@ def __init__( input_tile = is_input_tile() emb_bag_configs_user = OrderedDict() emb_bag_configs_item = OrderedDict() + mc_emb_bag_configs_user = OrderedDict() + mc_emb_bag_configs_item = OrderedDict() + mc_modules_item = OrderedDict() + mc_modules_user = OrderedDict() for feature_group in feature_groups: total_dim = 0 @@ -595,6 +628,7 @@ def __init__( if feature.is_sparse: output_dim = feature.output_dim emb_bag_config = feature.emb_bag_config + mc_module = feature.mc_module(device) assert emb_bag_config is not None if is_wide: # TODO(hongsheng.jhs): change to embedding_dim to 1 @@ -607,24 +641,45 @@ def __init__( if input_tile_emb: if feature.is_user_feat: - add_embedding_bag_config( - emb_bag_configs=emb_bag_configs_user, + _add_embedding_bag_config( + emb_bag_configs=mc_emb_bag_configs_user + if mc_module + else emb_bag_configs_user, emb_bag_config=emb_bag_config, ) + if mc_module: + _add_mc_module( + mc_modules_user, emb_bag_config.name, mc_module + ) + self.has_mc_sparse_user = True + else: + self.has_sparse_user = True else: - add_embedding_bag_config( - emb_bag_configs=emb_bag_configs_item, + _add_embedding_bag_config( + emb_bag_configs=mc_emb_bag_configs_item + if mc_module + else emb_bag_configs_item, emb_bag_config=emb_bag_config, ) + if mc_module: + _add_mc_module( + mc_modules_item, emb_bag_config.name, mc_module + ) + self.has_mc_sparse = True + else: + self.has_sparse = True else: - add_embedding_bag_config( - emb_bag_configs=emb_bag_configs, + _add_embedding_bag_config( + emb_bag_configs=mc_emb_bag_configs + if mc_module + else emb_bag_configs, emb_bag_config=emb_bag_config, ) - if input_tile_emb and feature.is_user_feat: - self.has_sparse_user = True - else: - self.has_sparse = True + if mc_module: + _add_mc_module(mc_modules, emb_bag_config.name, mc_module) + self.has_mc_sparse = True + else: + self.has_sparse = True if shared_feature_flag[name]: shared_name = shared_name + "@" + emb_bag_config.name @@ -656,10 +711,37 @@ def __init__( self.ebc_item = EmbeddingBagCollection( list(emb_bag_configs_item.values()), device=device ) + if self.has_mc_sparse_user: + self.mc_ebc_user = ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + list(mc_emb_bag_configs_user.values()), device=device + ), + ManagedCollisionCollection( + mc_modules_user, list(mc_emb_bag_configs_user.values()) + ), + ) + if self.has_mc_sparse: + self.mc_ebc_item = ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + list(mc_emb_bag_configs_item.values()), device=device + ), + ManagedCollisionCollection( + mc_modules_item, list(mc_emb_bag_configs_item.values()) + ), + ) else: self.ebc = EmbeddingBagCollection( list(emb_bag_configs.values()), device=device ) + if self.has_mc_sparse: + self.mc_ebc = ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + list(mc_emb_bag_configs.values()), device=device + ), + ManagedCollisionCollection( + mc_modules, list(mc_emb_bag_configs.values()) + ), + ) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group.""" @@ -689,6 +771,12 @@ def forward( else: kts.append(self.ebc(sparse_feature)) + if self.has_mc_sparse: + if is_input_tile_emb(): + kts.append(self.mc_ebc_item(sparse_feature)[0]) + else: + kts.append(self.mc_ebc(sparse_feature)[0]) + if self.has_sparse_user: keyed_tensor_user = self.ebc_user(sparse_feature_user) values_tile = keyed_tensor_user.values().tile(batch_size, 1) @@ -699,6 +787,16 @@ def forward( ) kts.append(keyed_tensor_user_tile) + if self.has_mc_sparse_user: + keyed_tensor_user = self.mc_ebc_user(sparse_feature_user)[0] + values_tile = keyed_tensor_user.values().tile(batch_size, 1) + keyed_tensor_user_tile = KeyedTensor( + keys=keyed_tensor_user.keys(), + length_per_key=keyed_tensor_user.length_per_key(), + values=values_tile, + ) + kts.append(keyed_tensor_user_tile) + if self.has_dense: kts.append(dense_feature) @@ -734,9 +832,13 @@ def __init__( device = torch.device("meta") name_to_feature = {x.name: x for x in features} dim_to_emb_configs = defaultdict(OrderedDict) + dim_to_mc_emb_configs = defaultdict(OrderedDict) + dim_to_mc_modules = defaultdict(OrderedDict) self.has_sparse = False self.has_sparse_user = False + self.has_mc_sparse = False + self.has_mc_sparse_user = False self.has_dense = False self.has_dense_user = False self.has_sequence_dense = False @@ -768,6 +870,10 @@ def __init__( need_input_tile_emb = is_input_tile_emb() dim_to_emb_configs_user = defaultdict(OrderedDict) dim_to_emb_configs_item = defaultdict(OrderedDict) + dim_to_mc_emb_configs_user = defaultdict(OrderedDict) + dim_to_mc_emb_configs_item = defaultdict(OrderedDict) + dim_to_mc_modules_user = defaultdict(OrderedDict) + dim_to_mc_modules_item = defaultdict(OrderedDict) for feature_group in feature_groups: query_dim = 0 @@ -785,35 +891,71 @@ def __init__( if feature.is_sparse: output_dim = feature.output_dim emb_config = feature.emb_config + mc_module = feature.mc_module(device) + assert emb_config is not None # we may/could modify ec name at feat_to_group_to_emb_name emb_config.name = feat_to_group_to_emb_name[name][group_name] - assert emb_config is not None - emb_configs = dim_to_emb_configs[emb_config.embedding_dim] + embedding_dim = emb_config.embedding_dim if need_input_tile_emb: if feature.is_user_feat: - add_embedding_config( - emb_configs=dim_to_emb_configs_user[ - emb_config.embedding_dim - ], + emb_configs = ( + dim_to_mc_emb_configs_user[embedding_dim] + if mc_module + else dim_to_emb_configs_user[embedding_dim] + ) + _add_embedding_config( + emb_configs=emb_configs, emb_config=emb_config, ) + if mc_module: + _add_mc_module( + dim_to_mc_modules_user[embedding_dim], + emb_config.name, + mc_module, + ) + self.has_mc_sparse_user = True + else: + self.has_sparse_user = True else: - add_embedding_config( - emb_configs=dim_to_emb_configs_item[ - emb_config.embedding_dim - ], + emb_configs = ( + dim_to_mc_emb_configs_item[embedding_dim] + if mc_module + else dim_to_emb_configs_item[embedding_dim] + ) + _add_embedding_config( + emb_configs=emb_configs, emb_config=emb_config, ) + if mc_module: + _add_mc_module( + dim_to_mc_modules_item[embedding_dim], + emb_config.name, + mc_module, + ) + self.has_mc_sparse = True + else: + self.has_sparse = True else: - add_embedding_config( + emb_configs = ( + dim_to_mc_emb_configs[embedding_dim] + if mc_module + else dim_to_emb_configs[embedding_dim] + ) + _add_embedding_config( emb_configs=emb_configs, emb_config=emb_config, ) - if need_input_tile_emb and feature.is_user_feat: - self.has_sparse_user = True - else: - self.has_sparse = True + if mc_module: + _add_mc_module( + dim_to_mc_modules[embedding_dim], + emb_config.name, + mc_module, + ) + self.has_mc_sparse = True + else: + self.has_sparse = True + if shared_feature_flag[name]: shared_name = shared_name + "@" + emb_config.name else: @@ -857,12 +999,42 @@ def __init__( self.ec_list_item.append( EmbeddingCollection(list(emb_configs.values()), device=device) ) + self.mc_ec_list_user = nn.ModuleList() + for k, emb_configs in dim_to_mc_emb_configs_user.items(): + self.mc_ec_list_user.append( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection(list(emb_configs.values()), device=device), + ManagedCollisionCollection( + dim_to_mc_modules_user[k], list(emb_configs.values()) + ), + ) + ) + self.mc_ec_list_item = nn.ModuleList() + for k, emb_configs in dim_to_mc_emb_configs_item.items(): + self.mc_ec_list_item.append( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection(list(emb_configs.values()), device=device), + ManagedCollisionCollection( + dim_to_mc_modules_item[k], list(emb_configs.values()) + ), + ) + ) else: self.ec_list = nn.ModuleList() for _, emb_configs in dim_to_emb_configs.items(): self.ec_list.append( EmbeddingCollection(list(emb_configs.values()), device=device) ) + self.mc_ec_list = nn.ModuleList() + for k, emb_configs in dim_to_mc_emb_configs.items(): + self.mc_ec_list.append( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection(list(emb_configs.values()), device=device), + ManagedCollisionCollection( + dim_to_mc_modules[k], list(emb_configs.values()) + ), + ) + ) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group.""" @@ -911,10 +1083,22 @@ def forward( for ec in self.ec_list: sparse_jt_dict_list.append(ec(sparse_feature)) + if self.has_mc_sparse: + if need_input_tile_emb: + for ec in self.mc_ec_list_item: + sparse_jt_dict_list.append(ec(sparse_feature)[0]) + else: + for ec in self.mc_ec_list: + sparse_jt_dict_list.append(ec(sparse_feature)[0]) + if self.has_sparse_user: for ec in self.ec_list_user: sparse_jt_dict_list.append(ec(sparse_feature_user)) + if self.has_mc_sparse_user: + for ec in self.mc_ec_list_user: + sparse_jt_dict_list.append(ec(sparse_feature_user)[0]) + sparse_jt_dict = _merge_list_of_jt_dict(sparse_jt_dict_list) if self.has_dense: diff --git a/tzrec/modules/embedding_test.py b/tzrec/modules/embedding_test.py index 265a98d..5ec1b87 100644 --- a/tzrec/modules/embedding_test.py +++ b/tzrec/modules/embedding_test.py @@ -31,16 +31,28 @@ from tzrec.utils.test_util import TestGraphType, create_test_module -def _create_test_features(): +def _create_test_features(has_zch=False): + cat_a_kwargs = {} + cat_b_kwargs = {} + if has_zch: + cat_a_kwargs["zch"] = feature_pb2.ZeroCollisionHash( + zch_size=100, lfu=feature_pb2.LFU_EvictionPolicy() + ) + cat_b_kwargs["zch"] = feature_pb2.ZeroCollisionHash( + zch_size=1000, lru=feature_pb2.LRU_EvictionPolicy() + ) + else: + cat_a_kwargs["num_buckets"] = 100 + cat_b_kwargs["num_buckets"] = 1000 feature_cfgs = [ feature_pb2.FeatureConfig( id_feature=feature_pb2.IdFeature( - feature_name="cat_a", embedding_dim=16, num_buckets=100 + feature_name="cat_a", embedding_dim=16, **cat_a_kwargs ) ), feature_pb2.FeatureConfig( id_feature=feature_pb2.IdFeature( - feature_name="cat_b", embedding_dim=8, num_buckets=1000 + feature_name="cat_b", embedding_dim=8, **cat_b_kwargs ) ), feature_pb2.FeatureConfig( @@ -56,22 +68,34 @@ def _create_test_features(): return features -def _create_test_sequence_features(): +def _create_test_sequence_features(has_zch=False): + cat_a_kwargs = {} + cat_b_kwargs = {} + if has_zch: + cat_a_kwargs["zch"] = feature_pb2.ZeroCollisionHash( + zch_size=100, lfu=feature_pb2.LFU_EvictionPolicy() + ) + cat_b_kwargs["zch"] = feature_pb2.ZeroCollisionHash( + zch_size=1000, lru=feature_pb2.LRU_EvictionPolicy() + ) + else: + cat_a_kwargs["num_buckets"] = 100 + cat_b_kwargs["num_buckets"] = 1000 feature_cfgs = [ feature_pb2.FeatureConfig( id_feature=feature_pb2.IdFeature( feature_name="cat_a", embedding_dim=16, - num_buckets=100, expression="item:cat_a", + **cat_a_kwargs, ) ), feature_pb2.FeatureConfig( id_feature=feature_pb2.IdFeature( feature_name="cat_b", embedding_dim=8, - num_buckets=1000, expression="item:cat_b", + **cat_b_kwargs, ) ), feature_pb2.FeatureConfig( @@ -88,7 +112,7 @@ def _create_test_sequence_features(): feature_name="cat_a", expression="item:cat_a", embedding_dim=16, - num_buckets=100, + **cat_a_kwargs, ) ), feature_pb2.SeqFeatureConfig( @@ -96,7 +120,7 @@ def _create_test_sequence_features(): feature_name="cat_b", expression="item:cat_b", embedding_dim=8, - num_buckets=1000, + **cat_b_kwargs, ) ), feature_pb2.SeqFeatureConfig( @@ -117,7 +141,7 @@ def _create_test_sequence_features(): feature_name="cat_a", expression="item:cat_a", embedding_dim=16, - num_buckets=100, + **cat_b_kwargs, ) ), feature_pb2.SeqFeatureConfig( @@ -172,7 +196,7 @@ def tearDown(self): [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] ) def test_embedding_group_impl(self, graph_type) -> None: - features = _create_test_sequence_features() + features = _create_test_features() feature_groups = [ model_pb2.FeatureGroupConfig( group_name="wide", @@ -183,27 +207,6 @@ def test_embedding_group_impl(self, graph_type) -> None: group_name="deep", feature_names=["cat_a", "cat_b", "int_a"], group_type=model_pb2.FeatureGroupType.DEEP, - sequence_groups=[ - model_pb2.SeqGroupConfig( - group_name="click", - feature_names=[ - "cat_a", - "cat_b", - "int_a", - "click_seq__cat_a", - "click_seq__cat_b", - "click_seq__int_a", - ], - ), - ], - sequence_encoders=[ - seq_encoder_pb2.SeqEncoderConfig( - din_encoder=seq_encoder_pb2.DINEncoder( - input="click", - attn_mlp=module_pb2.MLP(hidden_units=[128, 64]), - ) - ), - ], ), ] embedding_group = EmbeddingGroupImpl( @@ -221,20 +224,12 @@ def test_embedding_group_impl(self, graph_type) -> None: self.assertDictEqual( embedding_group.group_feature_dims("deep"), deep_feature_dims ) - embedding_group = create_test_module(embedding_group, graph_type) sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=[ - "cat_a", - "cat_b", - "click_seq__cat_a", - "click_seq__cat_b", - "buy_seq__cat_a", - "buy_seq__cat_b", - ], - values=torch.tensor(list(range(24))), - lengths=torch.tensor([1, 1, 1, 1, 3, 3, 3, 3, 2, 2, 2, 2]), + keys=["cat_a", "cat_b"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7]), + lengths=torch.tensor([1, 2, 1, 3]), ) dense_feature = KeyedTensor.from_tensor_list( keys=["int_a"], tensors=[torch.tensor([[0.2], [0.3]])] @@ -252,8 +247,58 @@ def test_embedding_group_impl(self, graph_type) -> None: @parameterized.expand( [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] ) - def test_sequence_embedding_group_impl(self, graph_type) -> None: - features = _create_test_sequence_features() + def test_zch_embedding_group_impl(self, graph_type) -> None: + features = _create_test_features(has_zch=True) + # TODO(hongsheng.jhs) zch not support wide group now. + feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name="deep", + feature_names=["cat_a", "cat_b", "int_a"], + group_type=model_pb2.FeatureGroupType.DEEP, + ), + ] + embedding_group = EmbeddingGroupImpl( + features, feature_groups, device=torch.device("cpu") + ) + self.assertEqual(embedding_group.group_dims("deep"), [16, 8, 1]) + self.assertEqual(embedding_group.group_total_dim("deep"), 25) + deep_feature_dims = OrderedDict({"cat_a": 16, "cat_b": 8, "int_a": 1}) + self.assertDictEqual( + embedding_group.group_feature_dims("deep"), deep_feature_dims + ) + + if graph_type != TestGraphType.NORMAL: + embedding_group.eval() + embedding_group = create_test_module(embedding_group, graph_type) + + sparse_feature = KeyedJaggedTensor.from_lengths_sync( + keys=["cat_a", "cat_b"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7]), + lengths=torch.tensor([1, 2, 1, 3]), + ) + dense_feature = KeyedTensor.from_tensor_list( + keys=["int_a"], tensors=[torch.tensor([[0.2], [0.3]])] + ) + result = embedding_group( + sparse_feature, + dense_feature, + EMPTY_KJT, + EMPTY_KT, + ) + self.assertEqual(result["deep"].size(), (2, 25)) + + @parameterized.expand( + [ + [TestGraphType.NORMAL, False], + [TestGraphType.FX_TRACE, False], + [TestGraphType.JIT_SCRIPT, False], + [TestGraphType.NORMAL, True], + [TestGraphType.FX_TRACE, True], + [TestGraphType.JIT_SCRIPT, True], + ] + ) + def test_sequence_embedding_group_impl(self, graph_type, has_zch=False) -> None: + features = _create_test_sequence_features(has_zch) feature_groups = [ model_pb2.FeatureGroupConfig( group_name="click", @@ -341,6 +386,8 @@ def test_sequence_embedding_group_impl(self, graph_type) -> None: embedding_group.group_total_dim("deep___click_no_query.query"), 0 ) + if has_zch and graph_type != TestGraphType.NORMAL: + embedding_group = embedding_group.eval() embedding_group = create_test_module(embedding_group, graph_type) sparse_feature = KeyedJaggedTensor.from_lengths_sync( @@ -563,11 +610,20 @@ def test_embedding_group(self, graph_type) -> None: self.assertEqual(result["buy.sequence_length"].size(), (2,)) @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] + [ + [TestGraphType.NORMAL, False], + [TestGraphType.FX_TRACE, False], + [TestGraphType.JIT_SCRIPT, False], + [TestGraphType.NORMAL, True], + [TestGraphType.FX_TRACE, True], + [TestGraphType.JIT_SCRIPT, True], + ] ) - def test_sequence_embedding_group_impl_input_tile(self, graph_type) -> None: + def test_sequence_embedding_group_impl_input_tile( + self, graph_type, has_zch=False + ) -> None: os.environ["INPUT_TILE"] = "2" - features = _create_test_sequence_features() + features = _create_test_sequence_features(has_zch=has_zch) feature_groups = [ model_pb2.FeatureGroupConfig( group_name="click", @@ -599,6 +655,8 @@ def test_sequence_embedding_group_impl_input_tile(self, graph_type) -> None: self.assertEqual(embedding_group.group_total_dim("buy.sequence"), 17) self.assertEqual(embedding_group.group_total_dim("buy.query"), 17) + if has_zch and graph_type != TestGraphType.NORMAL: + embedding_group = embedding_group.eval() embedding_group = create_test_module(embedding_group, graph_type) sparse_feature = KeyedJaggedTensor.from_lengths_sync( @@ -638,11 +696,20 @@ def test_sequence_embedding_group_impl_input_tile(self, graph_type) -> None: self.assertEqual(result["buy.sequence_length"].size(), (2,)) @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] + [ + [TestGraphType.NORMAL, False], + [TestGraphType.FX_TRACE, False], + [TestGraphType.JIT_SCRIPT, False], + [TestGraphType.NORMAL, True], + [TestGraphType.FX_TRACE, True], + [TestGraphType.JIT_SCRIPT, True], + ] ) - def test_sequence_embedding_group_impl_input_tile_emb(self, graph_type) -> None: + def test_sequence_embedding_group_impl_input_tile_emb( + self, graph_type, has_zch=False + ) -> None: os.environ["INPUT_TILE"] = "3" - features = _create_test_sequence_features() + features = _create_test_sequence_features(has_zch=has_zch) feature_groups = [ model_pb2.FeatureGroupConfig( group_name="click", @@ -674,6 +741,8 @@ def test_sequence_embedding_group_impl_input_tile_emb(self, graph_type) -> None: self.assertEqual(embedding_group.group_total_dim("buy.sequence"), 17) self.assertEqual(embedding_group.group_total_dim("buy.query"), 17) + if has_zch and graph_type != TestGraphType.NORMAL: + embedding_group = embedding_group.eval() embedding_group = create_test_module(embedding_group, graph_type) sparse_feature = KeyedJaggedTensor.from_lengths_sync( diff --git a/tzrec/protos/feature.proto b/tzrec/protos/feature.proto index c2d8a30..d02a4bb 100644 --- a/tzrec/protos/feature.proto +++ b/tzrec/protos/feature.proto @@ -2,6 +2,51 @@ syntax = "proto2"; package tzrec.protos; +// LFU: evict_score = access_cnt +message LFU_EvictionPolicy { + // lambda function string used to filter incoming ids before update/eviction. experimental feature. + // [input: Tensor] the function takes as input a 1-d tensor of unique id counts. + // [output1: Tensor] the function returns a boolean_mask or index array of corresponding elements in the input tensor that pass the filter. + // [output2: float, Tensor] the function returns the threshold that will be used to filter ids before update/eviction. all values <= this value will be filtered out. + optional string threshold_filtering_func = 1; +} + +// LRU: evict_score = 1 / pow((current_iter - last_access_iter), decay_exponent) +message LRU_EvictionPolicy { + // decay rate is access step + optional float decay_exponent = 1 [default = 1.0]; + // lambda function used to filter incoming ids before update/eviction. experimental feature. + // [input: Tensor] the function takes as input a 1-d tensor of unique id counts. + // [output1: Tensor] the function returns a boolean_mask or index array of corresponding elements in the input tensor that pass the filter. + // [output2: float, Tensor] the function returns the threshold that will be used to filter ids before update/eviction. all values <= this value will be filtered out. + optional string threshold_filtering_func = 2; +} + +// DistanceLFU: evict_score = access_cnt / pow((current_iter - last_access_iter), decay_exponent) +message DistanceLFU_EvictionPolicy { + // decay rate is access step + optional float decay_exponent = 1 [default = 1.0]; + // lambda function string used to filter incoming ids before update/eviction. experimental feature. + // [input: Tensor] the function takes as input a 1-d tensor of unique id counts. + // [output1: Tensor] the function returns a boolean_mask or index array of corresponding elements in the input tensor that pass the filter. + // [output2: float, Tensor] the function returns the threshold that will be used to filter ids before update/eviction. all values <= this value will be filtered out. + optional string threshold_filtering_func = 2; +} + +message ZeroCollisionHash { + // zero collision size + required uint64 zch_size = 1; + // evict interval steps + optional uint32 eviction_interval = 2 [default = 5]; + // evict policy + oneof eviction_policy { + LFU_EvictionPolicy lfu = 101; + LRU_EvictionPolicy lru = 102; + DistanceLFU_EvictionPolicy distance_lfu = 103; + } +} + + message IdFeature { // feature name, e.g. item_id required string feature_name = 1; @@ -34,6 +79,8 @@ message IdFeature { optional string init_fn = 14; // mask value in training progress optional bool use_mask = 15; + // zero collision hash + optional ZeroCollisionHash zch = 16; // default value when fg_encoded = true, // when use pai-fg, you do not need to set the param. @@ -108,6 +155,8 @@ message ComboFeature { optional string init_fn = 13; // mask value in training progress optional bool use_mask = 14; + // zero collision hash + optional ZeroCollisionHash zch = 15; // default value when fg_encoded = true, // when use pai-fg, you do not need to set the param. @@ -166,6 +215,8 @@ message LookupFeature { // mask value in training progress optional bool use_mask = 25; + // zero collision hash + optional ZeroCollisionHash zch = 26; // default value when fg_encoded = true, // when use pai-fg, you do not need to set the param. @@ -226,6 +277,8 @@ message MatchFeature { optional uint32 value_dim = 20; // mask value in training progress optional bool use_mask = 25; + // zero collision hash + optional ZeroCollisionHash zch = 26; // default value when fg_encoded = true, // when use pai-fg, you do not need to set the param. @@ -393,6 +446,8 @@ message SequenceIdFeature { optional uint32 value_dim = 15 [default = 1]; // mask value in training progress optional bool use_mask = 20; + // zero collision hash + optional ZeroCollisionHash zch = 21; // default value when fg_encoded = true, // when use pai-fg, you do not need to set the param. diff --git a/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config b/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config new file mode 100644 index 0000000..0789a39 --- /dev/null +++ b/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config @@ -0,0 +1,358 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/multi_tower_din_mock_fg" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: ParquetDataset + fg_encoded: false + label_fields: "clk" + num_workers: 8 +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + zch { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + zch { + zch_size: 10000 + eviction_interval: 2 + lru {} + } + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "id_3" + expression: "item:id_3" + vocab_list: ["a", "b", "c"] + embedding_dim: 8 + } +} +feature_configs { + id_feature { + feature_name: "id_4" + expression: "item:id_4" + hash_bucket_size: 100 + embedding_dim: 16 + embedding_name: "id_4_emb" + } +} +feature_configs { + id_feature { + feature_name: "id_5" + expression: "item:id_5" + hash_bucket_size: 100 + embedding_dim: 16 + embedding_name: "id_4_emb" + } +} +feature_configs { + raw_feature { + feature_name: "raw_1" + expression: "item:raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "raw_2" + expression: "item:raw_2" + } +} +feature_configs { + raw_feature { + feature_name: "raw_3" + expression: "user:raw_3" + value_dim: 4 + } +} +feature_configs { + raw_feature { + feature_name: "raw_4" + expression: "user:raw_4" + value_dim: 4 + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "raw_5" + expression: "user:raw_5" + } +} +feature_configs { + raw_feature { + feature_name: "raw_6_id" + expression: "item:raw_6" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + combo_feature { + feature_name: "combo_1" + expression: ["user:id_1", "item:id_2"] + hash_bucket_size: 1000000 + embedding_dim: 16 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_1" + map: "user:map_1" + key: "item:id_2" + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_2" + map: "user:map_2" + key: "item:id_2" + hash_bucket_size: 10000 + embedding_dim: 8 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_3" + map: "user:map_3" + key: "item:id_2" + num_buckets: 1000 + embedding_dim: 8 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_4" + map: "user:map_4" + key: "item:id_2" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_5" + map: "user:map_5" + key: "feature:raw_6_id" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + match_feature { + feature_name: "match_1" + nested_map: "user:nested_map_1" + pkey: "item:id_2" + skey: "item:id_3" + } +} +feature_configs { + match_feature { + feature_name: "match_2" + nested_map: "user:nested_map_2" + pkey: "item:id_2" + skey: "item:id_3" + hash_bucket_size: 100000 + embedding_dim: 16 + } +} +feature_configs { + match_feature { + feature_name: "match_3" + nested_map: "user:nested_map_3" + pkey: "item:id_2" + skey: "item:id_3" + num_buckets: 10000 + embedding_dim: 8 + } +} +feature_configs { + match_feature { + feature_name: "match_4" + nested_map: "user:nested_map_4" + pkey: "item:id_2" + skey: "item:id_3" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + expr_feature { + feature_name: "expr_1" + expression: "raw_1 + raw_2" + variables: ["item:raw_1", "item:raw_2"] + } +} +feature_configs { + sequence_feature { + sequence_name: "click_50_seq" + sequence_length: 50 + sequence_delim: "|" + features { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + zch { + zch_size: 10000 + eviction_interval: 2 + lru {} + } + embedding_dim: 16 + } + } + features { + id_feature { + feature_name: "id_3" + expression: "item:id_3" + vocab_list: ["a", "b", "c"] + embedding_dim: 8 + } + } + features { + raw_feature { + feature_name: "raw_1" + expression: "item:raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } + } + features { + raw_feature { + feature_name: "raw_2" + expression: "item:raw_2" + } + } + } +} +feature_configs { + sequence_id_feature { + feature_name: "buy_50_user_id_seq" + sequence_length: 50 + sequence_delim: "|" + expression: "item:buy_50_user_id_seq" + zch { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + embedding_dim: 16 + } +} +feature_configs { + sequence_raw_feature { + feature_name: "buy_50_raw_5_seq" + sequence_length: 50 + sequence_delim: "|" + expression: "item:buy_50_raw_5_seq" + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "user_id" + feature_names: "item_id" + feature_names: "id_3" + feature_names: "id_4" + feature_names: "id_5" + feature_names: "raw_1" + feature_names: "raw_2" + feature_names: "raw_3" + feature_names: "raw_4" + feature_names: "raw_5" + feature_names: "combo_1" + feature_names: "lookup_1" + feature_names: "lookup_2" + feature_names: "lookup_3" + feature_names: "lookup_4" + feature_names: "lookup_5" + feature_names: "match_1" + feature_names: "match_2" + feature_names: "match_3" + feature_names: "match_4" + feature_names: "expr_1" + group_type: DEEP + } + feature_groups { + group_name: "seq" + feature_names: "item_id" + feature_names: "id_3" + feature_names: "raw_1" + feature_names: "raw_2" + feature_names: "click_50_seq__item_id" + feature_names: "click_50_seq__id_3" + feature_names: "click_50_seq__raw_1" + feature_names: "click_50_seq__raw_2" + group_type: SEQUENCE + } + feature_groups { + group_name: "seq_item" + feature_names: "user_id" + feature_names: "raw_5" + feature_names: "buy_50_user_id_seq" + feature_names: "buy_50_raw_5_seq" + group_type: SEQUENCE + } + multi_tower_din { + towers { + input: 'deep' + mlp { + hidden_units: [512, 256, 128] + } + } + din_towers { + input: 'seq' + attn_mlp { + hidden_units: [256, 64] + } + } + din_towers { + input: 'seq_item' + attn_mlp { + hidden_units: [256, 64] + } + } + final { + hidden_units: [64] + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py new file mode 100644 index 0000000..442c51a --- /dev/null +++ b/tzrec/tests/match_integration_test.py @@ -0,0 +1,256 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest + +from tzrec.tests import utils + + +class MatchIntegrationTest(unittest.TestCase): + def setUp(self): + self.success = False + if not os.path.exists("./tmp"): + os.makedirs("./tmp") + self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") + os.chmod(self.test_dir, 0o755) + + def tearDown(self): + if self.success: + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def test_dssm_fg_encoded_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_mock.config", self.test_dir, item_id="item_id" + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "predict_result"), + reserved_columns="item_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) + ) + + def test_dssm_fg_encoded_variational_dropout(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_variational_dropout_mock.config", + self.test_dir, + item_id="item_id", + ) + if self.success: + self.success = utils.test_feature_selection( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "output_dir/pipeline.config")) + ) + + def test_dssm_with_fg_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_fg_mock.config", + self.test_dir, + user_id="user_id", + item_id="item_id", + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "item_emb"), + reserved_columns="item_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_create_faiss_index( + embedding_input_path=os.path.join( + self.test_dir, r"item_emb/\*.parquet" + ), + index_output_dir=os.path.join(self.test_dir, "export/user"), + id_field="item_id", + embedding_field="item_tower_emb", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/user"), + predict_input_path=os.path.join(self.test_dir, r"user_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "user_emb"), + reserved_columns="user_id,click_50_seq__item_id", + output_columns="user_tower_emb", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_hitrate( + user_gt_input=os.path.join(self.test_dir, r"user_emb/\*.parquet"), + item_embedding_input=os.path.join( + self.test_dir, r"item_emb/\*.parquet" + ), + total_hitrate_output=os.path.join(self.test_dir, "total_hitrate"), + hitrate_details_output=os.path.join(self.test_dir, "hitrate_details"), + request_id_field="user_id", + gt_items_field="click_50_seq__item_id", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_create_fg_json( + os.path.join(self.test_dir, "pipeline.config"), + fg_output_dir=os.path.join(self.test_dir, "fg_output"), + reserves="clk", + test_dir=self.test_dir, + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/faiss_index")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/id_mapping")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "fg_output/fg.json")) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.test_dir, "fg_output/item_title_tokenizer.json") + ) + ) + + def test_dssm_v2_with_fg_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_v2_fg_mock.config", + self.test_dir, + user_id="user_id", + item_id="item_id", + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) + ) + + def test_tdm_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/tdm_fg_mock.config", + self.test_dir, + user_id="user_id", + item_id="item_id", + cate_id="cate_id", + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + asset_files=os.path.join(self.test_dir, "init_tree/serving_tree"), + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/embedding"), + predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "item_emb"), + reserved_columns="item_id,cate_id,id_4,id_5,raw_1,raw_2", + output_columns="item_emb", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_tdm_cluster_train_eval( + pipeline_config_path=os.path.join(self.test_dir, "pipeline.config"), + test_dir=self.test_dir, + item_input_path=os.path.join(self.test_dir, r"item_emb/\*.parquet"), + item_id="item_id", + embedding_field="item_emb", + ) + if self.success: + self.success = utils.test_tdm_retrieval( + scripted_model_path=os.path.join(self.test_dir, "export/model"), + eval_data_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), + retrieval_output_path=os.path.join(self.test_dir, "retrieval_result"), + reserved_columns="user_id,item_id", + test_dir=self.test_dir, + ) + + self.assertTrue(self.success) + self.assertTrue(os.path.exists(os.path.join(self.test_dir, "learnt_tree"))) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.test_dir, "export/embedding/scripted_model.pt") + ) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.test_dir, "export/model/scripted_model.pt") + ) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/model/serving_tree")) + ) + self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/rank_integration_test.py similarity index 72% rename from tzrec/tests/train_eval_export_test.py rename to tzrec/tests/rank_integration_test.py index 5e0ea15..d9df533 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/rank_integration_test.py @@ -31,7 +31,7 @@ from tzrec.utils.test_util import dfs_are_close -class TrainEvalExportTest(unittest.TestCase): +class RankIntegrationTest(unittest.TestCase): def setUp(self): self.success = False if not os.path.exists("./tmp"): @@ -89,9 +89,12 @@ def test_multi_tower_din_fg_encoded_finetune(self): ) self.assertTrue(self.success) - def test_dssm_fg_encoded_train_eval_export(self): + def _test_rank_with_fg(self, pipeline_config_path, comp_cpu_gpu_pred_result=False): self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_mock.config", self.test_dir, item_id="item_id" + pipeline_config_path, + self.test_dir, + user_id="user_id", + item_id="item_id", ) if self.success: self.success = utils.test_eval( @@ -103,65 +106,20 @@ def test_dssm_fg_encoded_train_eval_export(self): ) if self.success: self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/item"), + scripted_model_path=os.path.join(self.test_dir, "export"), predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), predict_output_path=os.path.join(self.test_dir, "predict_result"), - reserved_columns="item_id", - output_columns="item_tower_emb", + reserved_columns="user_id,item_id", test_dir=self.test_dir, ) self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) - ) - - def test_dssm_fg_encoded_variational_dropout(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_variational_dropout_mock.config", - self.test_dir, - item_id="item_id", - ) - if self.success: - self.success = utils.test_feature_selection( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "output_dir/pipeline.config")) - ) - - def test_multi_tower_din_with_fg_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/multi_tower_din_fg_mock.config", - self.test_dir, - user_id="user_id", - item_id="item_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - self.assertTrue(self.success) self.assertTrue( os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) ) self.assertTrue( os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) ) - if torch.cuda.is_available(): + if comp_cpu_gpu_pred_result and torch.cuda.is_available(): pipeline_config = config_util.load_pipeline_config( os.path.join(self.test_dir, "pipeline.config") ) @@ -187,120 +145,16 @@ def test_multi_tower_din_with_fg_train_eval_export(self): os.path.join(self.test_dir, "export/scripted_model.pt"), map_location=device, ) - result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int32), device) + result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int64), device) for k, v in result_gpu.items(): torch.testing.assert_close( result_cpu[k].to(device), v, rtol=5e-3, atol=1e-5 ) - def test_dssm_with_fg_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_fg_mock.config", - self.test_dir, - user_id="user_id", - item_id="item_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/item"), - predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "item_emb"), - reserved_columns="item_id", - output_columns="item_tower_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_create_faiss_index( - embedding_input_path=os.path.join( - self.test_dir, r"item_emb/\*.parquet" - ), - index_output_dir=os.path.join(self.test_dir, "export/user"), - id_field="item_id", - embedding_field="item_tower_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/user"), - predict_input_path=os.path.join(self.test_dir, r"user_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "user_emb"), - reserved_columns="user_id,click_50_seq__item_id", - output_columns="user_tower_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_hitrate( - user_gt_input=os.path.join(self.test_dir, r"user_emb/\*.parquet"), - item_embedding_input=os.path.join( - self.test_dir, r"item_emb/\*.parquet" - ), - total_hitrate_output=os.path.join(self.test_dir, "total_hitrate"), - hitrate_details_output=os.path.join(self.test_dir, "hitrate_details"), - request_id_field="user_id", - gt_items_field="click_50_seq__item_id", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_create_fg_json( - os.path.join(self.test_dir, "pipeline.config"), - fg_output_dir=os.path.join(self.test_dir, "fg_output"), - reserves="clk", - test_dir=self.test_dir, - ) - self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/faiss_index")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/id_mapping")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "fg_output/fg.json")) - ) - self.assertTrue( - os.path.exists( - os.path.join(self.test_dir, "fg_output/item_title_tokenizer.json") - ) - ) - - def test_dssm_v2_with_fg_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_v2_fg_mock.config", - self.test_dir, - user_id="user_id", - item_id="item_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) - ) - - def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): + def _test_rank_with_fg_input_tile( + self, + pipeline_config_path, + ): self.success = utils.test_train_eval( "tzrec/tests/configs/multi_tower_din_fg_mock.config", self.test_dir, @@ -463,7 +317,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): os.path.join(self.test_dir, "export/scripted_model.pt"), map_location=device, ) - result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int32), device) + result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int64), device) result_dict_json_path = os.path.join(self.test_dir, "result_gpu.json") utils.save_predict_result_json(result_gpu, result_dict_json_path) for k, v in result_gpu.items(): @@ -477,7 +331,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): map_location=device, ) result_gpu_no_quant = model_gpu_no_quant( - data.to_dict(sparse_dtype=torch.int32), device + data.to_dict(sparse_dtype=torch.int64), device ) result_dict_json_path = os.path.join( self.test_dir, "result_gpu_no_quant.json" @@ -499,7 +353,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): iterator_input_tile = iter(dataloader) data_input_tile = next(iterator_input_tile) result_gpu_input_tile = model_gpu_input_tile( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32), + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64), device, ) result_dict_json_path = os.path.join( @@ -513,7 +367,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): map_location=device, ) result_gpu_input_tile_emb = model_gpu_input_tile_emb( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32), + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64), device, ) @@ -531,35 +385,29 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): result_gpu_input_tile[k].to(device), v, rtol=1e-4, atol=1e-4 ) - def test_dbmtl_has_sequence_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dbmtl_has_sequence_mock.config", - self.test_dir, + def test_multi_tower_din_with_fg_train_eval_export(self): + self._test_rank_with_fg( + "tzrec/tests/configs/multi_tower_din_fg_mock.config", + comp_cpu_gpu_pred_result=True, ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export"), - predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "predict_result"), - reserved_columns="clk,buy", - output_columns="probs_ctr,probs_cvr", - test_dir=self.test_dir, - ) - self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) + def test_multi_tower_din_zch_with_fg_train_eval_export(self): + self._test_rank_with_fg( + "tzrec/tests/configs/multi_tower_din_zch_fg_mock.config", + comp_cpu_gpu_pred_result=True, ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) + + def test_dbmtl_has_sequence_train_eval_export(self): + self._test_rank_with_fg("tzrec/tests/configs/dbmtl_has_sequence_mock.config") + + def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): + self._test_rank_with_fg_input_tile( + "tzrec/tests/configs/multi_tower_din_fg_mock.config" + ) + + def test_multi_tower_din_zch_with_fg_train_eval_export_input_tile(self): + self._test_rank_with_fg_input_tile( + "tzrec/tests/configs/multi_tower_din_zch_fg_mock.config" ) def test_dbmtl_has_sequence_variational_dropout_train_eval_export(self): @@ -590,70 +438,6 @@ def test_dbmtl_has_sequence_variational_dropout_train_eval_export(self): os.path.exists(os.path.join(self.test_dir, "output_dir/pipeline.config")) ) - def test_tdm_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/tdm_fg_mock.config", - self.test_dir, - user_id="user_id", - item_id="item_id", - cate_id="cate_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), - self.test_dir, - asset_files=os.path.join(self.test_dir, "init_tree/serving_tree"), - ) - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/embedding"), - predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "item_emb"), - reserved_columns="item_id,cate_id,id_4,id_5,raw_1,raw_2", - output_columns="item_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_tdm_cluster_train_eval( - pipeline_config_path=os.path.join(self.test_dir, "pipeline.config"), - test_dir=self.test_dir, - item_input_path=os.path.join(self.test_dir, r"item_emb/\*.parquet"), - item_id="item_id", - embedding_field="item_emb", - ) - if self.success: - self.success = utils.test_tdm_retrieval( - scripted_model_path=os.path.join(self.test_dir, "export/model"), - eval_data_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), - retrieval_output_path=os.path.join(self.test_dir, "retrieval_result"), - reserved_columns="user_id,item_id", - test_dir=self.test_dir, - ) - - self.assertTrue(self.success) - self.assertTrue(os.path.exists(os.path.join(self.test_dir, "learnt_tree"))) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) - ) - self.assertTrue( - os.path.exists( - os.path.join(self.test_dir, "export/embedding/scripted_model.pt") - ) - ) - self.assertTrue( - os.path.exists( - os.path.join(self.test_dir, "export/model/scripted_model.pt") - ) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/model/serving_tree")) - ) - self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result"))) - @unittest.skipIf(not torch.cuda.is_available(), "cuda not found") def test_multi_tower_with_fg_train_eval_export_trt(self): self.success = utils.test_train_eval( @@ -839,7 +623,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): model_gpu = torch.jit.load( os.path.join(self.test_dir, "export/scripted_model.pt"), map_location=device ) - result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int32), device) + result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int64), device) result_dict_json_path = os.path.join(self.test_dir, "result_gpu.json") utils.save_predict_result_json(result_gpu, result_dict_json_path) @@ -849,7 +633,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): os.path.join(self.test_dir, "trt/export/scripted_model.pt"), map_location=device, ) - result_gpu_trt = model_gpu_trt(data.to_dict(sparse_dtype=torch.int32)) + result_gpu_trt = model_gpu_trt(data.to_dict(sparse_dtype=torch.int64)) result_dict_json_path = os.path.join(self.test_dir, "result_gpu_trt.json") utils.save_predict_result_json(result_gpu_trt, result_dict_json_path) @@ -868,7 +652,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): iterator_input_tile = iter(dataloader) data_input_tile = next(iterator_input_tile) result_gpu_input_tile = model_gpu_input_tile( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32) + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64) ) result_dict_json_path = os.path.join( self.test_dir, "result_gpu_input_tile.json" @@ -881,7 +665,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): map_location=device, ) result_gpu_input_tile_emb = model_gpu_input_tile_emb( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32) + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64) ) # trt is all same sa no-trt diff --git a/tzrec/utils/fx_util.py b/tzrec/utils/fx_util.py new file mode 100644 index 0000000..919f3ce --- /dev/null +++ b/tzrec/utils/fx_util.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from torchrec.fx import symbolic_trace as _symbolic_trace + + +def symbolic_trace( + # pyre-ignore[24] + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[Dict[str, Any]] = None, + leaf_modules: Optional[List[str]] = None, +) -> torch.fx.GraphModule: + """Symbolic tracing API. + + Given an `nn.Module` or function instance `root`, this function will return a + `GraphModule` constructed by recording operations seen while tracing through `root`. + + `concrete_args` allows you to partially specialize your function, whether it's to + remove control flow or data structures. + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and + converted into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized + leaf_modules (Optional[List[str]]): modules do not trace + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. + """ + # ComputeJTDictToKJT could not be traced + _leaf_modules = ["ComputeJTDictToKJT"] + if leaf_modules: + _leaf_modules.extend(leaf_modules) + return _symbolic_trace(root, concrete_args, _leaf_modules) diff --git a/tzrec/utils/test_util.py b/tzrec/utils/test_util.py index ba0b986..e2d576d 100644 --- a/tzrec/utils/test_util.py +++ b/tzrec/utils/test_util.py @@ -17,9 +17,9 @@ import torch from torch import nn from torch.fx import GraphModule -from torchrec.fx import symbolic_trace from tzrec.models.model import ScriptWrapper +from tzrec.utils.fx_util import symbolic_trace class TestGraphType(Enum): diff --git a/tzrec/version.py b/tzrec/version.py index 5c07ea6..64d31d9 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.6.6" +__version__ = "0.6.7"