diff --git a/docs/source/feature/feature.md b/docs/source/feature/feature.md index d50481c..cf5ad03 100644 --- a/docs/source/feature/feature.md +++ b/docs/source/feature/feature.md @@ -58,6 +58,17 @@ feature_configs { embedding_dim: 32 vocab_dict: [{key:"a" value:2}, {key:"b" value:3}, {key:"c" value:2}] } +feature_configs { + id_feature { + feature_name: "cate" + expression: "item:cate" + embedding_dim: 32 + zch: { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + } } ``` @@ -75,6 +86,8 @@ feature_configs { - **vocab_dict**: specifies the vocabulary as a dict; suitable when multiple words need to be encoded to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary words +- **zch**: zero-collision hash; id admission and eviction policies can be configured, see the [documentation](../zch.md) + - **weighted**: whether this is a weighted id feature, with input in the form `k1:v1\x1dk2:v2` - **value_dim**: defaults to 0 and can be set to 1; when value_dim=0, multi-value id output is supported @@ -207,6 +220,7 @@ feature_configs: { - **num_buckets**: number of buckets; num_buckets can only be used when the input is an integer type - **vocab_list**: specifies the vocabulary; suitable for features with a small, enumerable set of values. - **vocab_dict**: specifies the vocabulary as a dict; suitable when multiple words need to be encoded to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary words +- **zch**: zero-collision hash; id admission and eviction policies can be configured, see the [documentation](../zch.md) - **value_dim**: defaults to 0 and can be set to 1; when value_dim=0, multi-value id output is supported If the Map values are continuous, you can set: @@ -247,6 +261,7 @@ feature_configs: { - **num_buckets**: number of buckets; num_buckets can only be used when the input is an integer type - **vocab_list**: specifies the vocabulary; suitable for features with a small, enumerable set of values. - **vocab_dict**: specifies the vocabulary as a dict; suitable when multiple words need to be encoded to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary words +- **zch**: zero-collision hash; id admission and eviction policies can be configured, see the [documentation](../zch.md) - **value_dim**: defaults to 0 and can be set to 1; when value_dim=0, multi-value id output is supported If the Map values are continuous, you can set: diff --git a/docs/source/feature/zch.md b/docs/source/feature/zch.md new file mode 100644 index 0000000..779f1f2 --- /dev/null +++ b/docs/source/feature/zch.md @@ -0,0 +1,151 @@ +# Zero Collision Hash Embedding + +Zero Collision Hash (zch) is one way to map feature values to ids. Compared with setting `hash_bucket_size`, it reduces hash collisions; compared with setting `vocab_dict` or `vocab_list`, it admits and evicts ids more flexibly and dynamically. Zero-collision hash is commonly used in feature configs with an extremely large number of distinct ids, such as user id, item id, and combo features. + +Taking the id_feature config as an example, zero-collision hash only requires adding a zch field to the id_feature: + +``` +feature_configs { + id_feature { + feature_name: "cate" + expression: "item:cate" + embedding_dim: 32 + zch: 
{ + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + } +} +``` + +- **zch_size**: bucket size of the zero-collision hash; once the number of ids exceeds it, ids are evicted according to the eviction policy + +- **eviction_interval**: how often the id admission and eviction policies run (interval in training steps) + +- **eviction_policy**: eviction policy, one of `lfu`, `lru`, `distance_lfu`; see Eviction Policies below + +- **threshold_filtering_func**: lambda function for the admission policy; by default all ids are admitted, see Admission Policy below + +## Eviction Policies + +### LFU_EvictionPolicy + +Evicts the ids with the lowest access count. id_score = access_cnt + +``` +lfu {} +``` + +### LRU_EvictionPolicy + +Evicts the least recently accessed ids. id_score = 1 / pow((current_iter - last_access_iter), decay_exponent) + +``` +lru { + decay_exponent: 1.0 +} +``` + +### DistanceLFU_EvictionPolicy + +Combines access count and recency, evicting the ids with the smallest id_score. id_score = access_cnt / pow((current_iter - last_access_iter), decay_exponent) + +``` +distance_lfu { + decay_exponent: 1.0 +} +``` + +## Admission Policy + +The admission policy is specified as a lambda function expression whose input and output must follow this format: + +- Input: a 1-D IntTensor with the occurrence count of each id over the last `eviction_interval` batches +- Output: a 1-D BoolTensor marking the positions of ids to keep, and a float threshold on id occurrence counts + +The function can be written directly with the torch tensor library, for example: + +``` +zch: { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + threshold_filtering_func: "lambda x: (x > 10, 10)" +} +``` + +The function can also call the built-in functions `dynamic_threshold_filter`, `average_threshold_filter` and `probabilistic_threshold_filter`, for example: + +``` +zch: { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + threshold_filtering_func: "lambda x: dynamic_threshold_filter(x, 1.0)" +} +``` + +The implementation details of these built-in functions are as follows: + +```python +@torch.no_grad() +def dynamic_threshold_filter( + id_counts: torch.Tensor, + threshold_skew_multiplier: float = 10.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is total_count / num_ids * threshold_skew_multiplier. An id is + added if its count is strictly greater than the threshold. 
+ """ + + num_ids = id_counts.numel() + total_count = id_counts.sum() + + BASE_THRESHOLD = 1 / num_ids + threshold_mass = BASE_THRESHOLD * threshold_skew_multiplier + + threshold = threshold_mass * total_count + threshold_mask = id_counts > threshold + + return threshold_mask, threshold + + +@torch.no_grad() +def average_threshold_filter( + id_counts: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Threshold is average of id_counts. An id is added if its count is strictly + greater than the mean. + """ + if id_counts.dtype != torch.float: + id_counts = id_counts.float() + threshold = id_counts.mean() + threshold_mask = id_counts > threshold + + return threshold_mask, threshold + + +@torch.no_grad() +def probabilistic_threshold_filter( + id_counts: torch.Tensor, + per_id_probability: float = 0.01, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Each id has probability per_id_probability of being added. For example, + if per_id_probability is 0.01 and an id appears 100 times, then it has a 60% + chance of being added. More precisely, the id score is 1 - (1 - per_id_probability) ^ id_count, + and for a randomly generated threshold, the id score is the chance of it being added. + """ + probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float) + id_scores = 1 - torch.pow(probability, id_counts) + + threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device) + threshold_mask = id_scores > threshold + + return threshold_mask, threshold +``` diff --git a/docs/source/index.rst b/docs/source/index.rst index a643e2f..3c82600 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,6 +14,7 @@ Welcome to TorchEasyRec's documentation! feature/data feature/feature + feature/zch .. 
toctree:: :maxdepth: 2 diff --git a/requirements.txt b/requirements.txt index b519733..5ffa1aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -r requirements/runtime.txt -r requirements/test.txt -r requirements/docs.txt +-r requirements/gpu.txt diff --git a/requirements/gpu.txt b/requirements/gpu.txt new file mode 100644 index 0000000..d50dd84 --- /dev/null +++ b/requirements/gpu.txt @@ -0,0 +1,2 @@ +torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" +torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 815bb87..93a7298 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -7,13 +7,11 @@ graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1. graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.1-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" grpcio-tools<1.63.0 pandas -pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" -pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" +pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" +pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" pyodps>=0.12.0 scikit-learn tensorboard torch==2.5.0 -torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" -torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp310-cp310-linux_x86_64.whl ; 
python_version=="3.10" torchmetrics==1.0.3 torchrec==1.0.0 diff --git a/setup.py b/setup.py index 398e908..d749cff 100644 --- a/setup.py +++ b/setup.py @@ -78,5 +78,6 @@ def parse_require_file(fpath): extras_require={ "all": parse_requirements("requirements.txt"), "tests": parse_requirements("requirements/test.txt"), + "gpu": parse_requirements("requirements/gpu.txt"), }, ) diff --git a/tzrec/acc/trt_utils.py b/tzrec/acc/trt_utils.py index 49e8e7a..89d32ef 100644 --- a/tzrec/acc/trt_utils.py +++ b/tzrec/acc/trt_utils.py @@ -21,10 +21,10 @@ pass from torch import nn from torch.profiler import ProfilerActivity, profile, record_function -from torchrec.fx import symbolic_trace from tzrec.acc.utils import is_debug_trt from tzrec.models.model import ScriptWrapper +from tzrec.utils.fx_util import symbolic_trace from tzrec.utils.logging_util import logger diff --git a/tzrec/acc/utils.py b/tzrec/acc/utils.py index 1f353b7..9a4263a 100644 --- a/tzrec/acc/utils.py +++ b/tzrec/acc/utils.py @@ -115,25 +115,26 @@ def write_mapping_file_for_input_tile( state_dict (Dict[str, torch.Tensor]): model state_dict remap_file_path (str) : store new_params_name\told_params_name\n """ - input_tile_keys = [ - ".ebc_user.embedding_bags.", - ".ebc_item.embedding_bags.", - ] - input_tile_keys_ec = [ - ".ec_list_user.", - ".ec_list_item.", - ] + input_tile_mapping = { + ".ebc_user.embedding_bags.": ".ebc.embedding_bags.", + ".ebc_item.embedding_bags.": ".ebc.embedding_bags.", + ".mc_ebc_user._embedding_module.": ".mc_ebc._embedding_module.", + ".mc_ebc_item._embedding_module.": ".mc_ebc._embedding_module.", + ".mc_ebc_user._managed_collision_collection.": ".mc_ebc._managed_collision_collection.", # NOQA + ".mc_ebc_item._managed_collision_collection.": ".mc_ebc._managed_collision_collection.", # NOQA + ".ec_list_user.": ".ec_list.", + ".ec_list_item.": ".ec_list.", + ".mc_ec_list_user.": ".mc_ec_list.", + ".mc_ec_list_item.": ".mc_ec_list.", + } remap_str = "" for key, _ in state_dict.items(): - 
for input_tile_key in input_tile_keys: + for input_tile_key in input_tile_mapping: if input_tile_key in key: - src_key = key.replace(input_tile_key, ".ebc.embedding_bags.") - remap_str += key + "\t" + src_key + "\n" - - for input_tile_key in input_tile_keys_ec: - if input_tile_key in key: - src_key = key.replace(input_tile_key, ".ec_list.") + src_key = key.replace( + input_tile_key, input_tile_mapping[input_tile_key] + ) remap_str += key + "\t" + src_key + "\n" with open(remap_file_path, "w") as f: @@ -142,7 +143,8 @@ def write_mapping_file_for_input_tile( def export_acc_config() -> Dict[str, str]: """Export acc config for model online inference.""" - acc_config = dict() + # use int64 sparse id as input + acc_config = {"SPARSE_INT64": "1"} if "INPUT_TILE" in os.environ: acc_config["INPUT_TILE"] = os.environ["INPUT_TILE"] if "QUANT_EMB" in os.environ: diff --git a/tzrec/features/combo_feature.py b/tzrec/features/combo_feature.py index 1664f28..c608c1d 100644 --- a/tzrec/features/combo_feature.py +++ b/tzrec/features/combo_feature.py @@ -20,7 +20,11 @@ ParsedData, SparseData, ) -from tzrec.features.feature import FgMode, _parse_fg_encoded_sparse_feature_impl +from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, + FgMode, + _parse_fg_encoded_sparse_feature_impl, +) from tzrec.features.id_feature import IdFeature from tzrec.protos.feature_pb2 import FeatureConfig from tzrec.utils.logging_util import logger @@ -53,7 +57,9 @@ def is_neg(self, value: bool) -> None: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif len(self.config.vocab_list) > 0: num_embeddings = len(self.config.vocab_list) + 2 @@ -69,7 +75,7 @@ def num_embeddings(self) -> int: else: raise ValueError( f"{self.__class__.__name__}[{self.name}] must set 
hash_bucket_size" - " or vocab_list or vocab_dict" + " or vocab_list or vocab_dict or zch.zch_size" ) return num_embeddings @@ -116,7 +122,9 @@ def fg_json(self) -> List[Dict[str, Any]]: } if self.config.separator != "\x1d": fg_cfg["separator"] = self.config.separator - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size elif len(self.config.vocab_list) > 0: fg_cfg["vocab_list"] = [self.config.default_value, ""] + list( diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 1268d89..3a895c3 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -20,12 +20,23 @@ import numpy as np import pyarrow as pa import pyfg +import torch from torch import nn # NOQA from torchrec.modules.embedding_configs import ( EmbeddingBagConfig, EmbeddingConfig, PoolingType, ) +from torchrec.modules.mc_modules import ( + DistanceLFU_EvictionPolicy, + LFU_EvictionPolicy, + LRU_EvictionPolicy, + ManagedCollisionModule, + MCHManagedCollisionModule, + average_threshold_filter, # NOQA + dynamic_threshold_filter, # NOQA + probabilistic_threshold_filter, # NOQA +) from tzrec.datasets.utils import ( BASE_DATA_GROUP, @@ -52,6 +63,9 @@ class FgMode(Enum): DAG = 3 +MAX_HASH_BUCKET_SIZE = 2**63 - 1 + + def _parse_fg_encoded_sparse_feature_impl( name: str, feat: pa.Array, @@ -386,6 +400,41 @@ def emb_config(self) -> Optional[EmbeddingConfig]: else: return None + def mc_module(self, device: torch.device) -> Optional[ManagedCollisionModule]: + """Get ManagedCollisionModule.""" + if self.is_sparse: + if hasattr(self.config, "zch") and self.config.HasField("zch"): + evict_type = self.config.zch.WhichOneof("eviction_policy") + evict_config = getattr(self.config.zch, evict_type) + threshold_filtering_func = None + if self.config.zch.HasField("threshold_filtering_func"): + 
threshold_filtering_func = eval( + self.config.zch.threshold_filtering_func + ) + if evict_type == "lfu": + eviction_policy = LFU_EvictionPolicy( + threshold_filtering_func=threshold_filtering_func + ) + elif evict_type == "lru": + eviction_policy = LRU_EvictionPolicy( + decay_exponent=evict_config.decay_exponent, + threshold_filtering_func=threshold_filtering_func, + ) + elif evict_type == "distance_lfu": + eviction_policy = DistanceLFU_EvictionPolicy( + decay_exponent=evict_config.decay_exponent, + threshold_filtering_func=threshold_filtering_func, + ) + else: + raise ValueError(f"Unknown evict policy type: {evict_type}") + return MCHManagedCollisionModule( + zch_size=self.config.zch.zch_size, + device=device, + eviction_interval=self.config.zch.eviction_interval, + eviction_policy=eviction_policy, + ) + return None + @property def inputs(self) -> List[str]: """Input field names.""" diff --git a/tzrec/features/id_feature.py b/tzrec/features/id_feature.py index 13891a9..5102dd5 100644 --- a/tzrec/features/id_feature.py +++ b/tzrec/features/id_feature.py @@ -20,6 +20,7 @@ SparseData, ) from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, BaseFeature, FgMode, _parse_fg_encoded_sparse_feature_impl, @@ -71,7 +72,9 @@ def is_sparse(self) -> bool: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): num_embeddings = self.config.num_buckets @@ -89,7 +92,7 @@ def num_embeddings(self) -> int: else: raise ValueError( f"{self.__class__.__name__}[{self.name}] must set hash_bucket_size" - " or num_buckets or vocab_list or vocab_dict" + " or num_buckets or vocab_list or vocab_dict or zch.zch_size" ) return num_embeddings @@ -156,7 +159,10 @@ def fg_json(self) -> List[Dict[str, Any]]: } if 
self.config.separator != "\x1d": fg_cfg["separator"] = self.config.separator - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + fg_cfg["value_type"] = "string" + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size fg_cfg["value_type"] = "string" elif len(self.config.vocab_list) > 0: diff --git a/tzrec/features/id_feature_test.py b/tzrec/features/id_feature_test.py index e35025e..389ce9c 100644 --- a/tzrec/features/id_feature_test.py +++ b/tzrec/features/id_feature_test.py @@ -15,6 +15,7 @@ import numpy as np import pyarrow as pa +import torch from parameterized import parameterized from torch import nn from torchrec.modules.embedding_configs import ( @@ -88,6 +89,51 @@ def test_init_fn_id_feature(self): ) self.assertEqual(repr(id_feat.emb_config), repr(expected_emb_config)) + @parameterized.expand( + [ + ["lambda x: probabilistic_threshold_filter(x,0.05)"], + ["lambda x: (x > 10, 10)"], + ], + name_func=test_util.parameterized_name_func, + ) + def test_zch_id_feature(self, threshold_filtering_func): + id_feat_cfg = feature_pb2.FeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="id_feat", + embedding_dim=16, + zch=feature_pb2.ZeroCollisionHash( + zch_size=100, + eviction_interval=5, + distance_lfu=feature_pb2.DistanceLFU_EvictionPolicy( + decay_exponent=1.0, + ), + threshold_filtering_func=threshold_filtering_func, + ), + ) + ) + id_feat = id_feature_lib.IdFeature(id_feat_cfg) + expected_emb_bag_config = EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=16, + name="id_feat_emb", + feature_names=["id_feat"], + pooling=PoolingType.SUM, + ) + self.assertEqual(repr(id_feat.emb_bag_config), repr(expected_emb_bag_config)) + expected_emb_config = EmbeddingConfig( + num_embeddings=100, + embedding_dim=16, + name="id_feat_emb", + feature_names=["id_feat"], + ) + self.assertEqual(repr(id_feat.emb_config), 
repr(expected_emb_config)) + mc_module = id_feat.mc_module(torch.device("meta")) + self.assertEqual(mc_module._zch_size, 100) + self.assertEqual(mc_module._eviction_interval, 5) + self.assertTrue( + mc_module._eviction_policy._threshold_filtering_func is not None + ) + def test_fg_encoded_with_weighted(self): id_feat_cfg = feature_pb2.FeatureConfig( id_feature=feature_pb2.IdFeature( diff --git a/tzrec/features/lookup_feature.py b/tzrec/features/lookup_feature.py index a24bf63..9d4aac9 100644 --- a/tzrec/features/lookup_feature.py +++ b/tzrec/features/lookup_feature.py @@ -23,6 +23,7 @@ SparseData, ) from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, BaseFeature, FgMode, _parse_fg_encoded_dense_feature_impl, @@ -84,7 +85,9 @@ def is_sparse(self) -> bool: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): num_embeddings = self.config.num_buckets @@ -201,7 +204,11 @@ def fg_json(self) -> List[Dict[str, Any]]: fg_cfg["separator"] = self.config.separator if self.config.HasField("normalizer"): fg_cfg["normalizer"] = self.config.normalizer - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + fg_cfg["value_type"] = "string" + fg_cfg["needDiscrete"] = True + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size fg_cfg["value_type"] = "string" fg_cfg["needDiscrete"] = True diff --git a/tzrec/features/match_feature.py b/tzrec/features/match_feature.py index 13de909..4fbc557 100644 --- a/tzrec/features/match_feature.py +++ b/tzrec/features/match_feature.py @@ -23,6 +23,7 @@ SparseData, ) from tzrec.features.feature import ( + MAX_HASH_BUCKET_SIZE, 
BaseFeature, FgMode, _parse_fg_encoded_dense_feature_impl, @@ -86,7 +87,9 @@ def is_sparse(self) -> bool: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): num_embeddings = self.config.num_buckets @@ -174,7 +177,11 @@ def fg_json(self) -> List[Dict[str, Any]]: fg_cfg["separator"] = self.config.separator if self.config.HasField("normalizer"): fg_cfg["normalizer"] = self.config.normalizer - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + fg_cfg["value_type"] = "string" + fg_cfg["needDiscrete"] = True + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size fg_cfg["value_type"] = "string" fg_cfg["needDiscrete"] = True diff --git a/tzrec/features/sequence_feature.py b/tzrec/features/sequence_feature.py index ee1734e..fa7c9c6 100644 --- a/tzrec/features/sequence_feature.py +++ b/tzrec/features/sequence_feature.py @@ -22,6 +22,7 @@ SequenceDenseData, SequenceSparseData, ) +from tzrec.features.feature import MAX_HASH_BUCKET_SIZE from tzrec.features.id_feature import FgMode, IdFeature from tzrec.features.raw_feature import RawFeature from tzrec.protos import feature_pb2 @@ -303,7 +304,9 @@ def fg_json(self) -> List[Dict[str, Any]]: fg_cfg["sequence_length"] = self.config.sequence_length if self.config.separator != "\x1d": fg_cfg["separator"] = self.config.separator - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE + elif self.config.HasField("hash_bucket_size"): fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): 
fg_cfg["num_buckets"] = self.config.num_buckets diff --git a/tzrec/main.py b/tzrec/main.py index 1d5c56d..b176d4a 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -31,7 +31,6 @@ # NOQA from torchrec.distributed.train_pipeline import TrainPipelineSparseDist -from torchrec.fx import symbolic_trace from torchrec.inference.modules import quantize_embeddings from torchrec.inference.state_dict_transform import ( state_dict_gather, @@ -81,6 +80,7 @@ from tzrec.protos.pipeline_pb2 import EasyRecConfig from tzrec.protos.train_pb2 import TrainConfig from tzrec.utils import checkpoint_util, config_util +from tzrec.utils.fx_util import symbolic_trace from tzrec.utils.logging_util import ProgressLogger, logger from tzrec.utils.plan_util import create_planner, get_default_sharders from tzrec.version import __version__ as tzrec_version @@ -747,7 +747,7 @@ def _script_model( model.eval() if is_trt_convert: - data_cuda = batch.to_dict(sparse_dtype=torch.int32) + data_cuda = batch.to_dict(sparse_dtype=torch.int64) result = model(data_cuda, "cuda:0") result_info = {k: (v.size(), v.dtype) for k, v in result.items()} logger.info(f"Model Outputs: {result_info}") @@ -1004,7 +1004,6 @@ def predict( device_and_backend = init_process_group() device: torch.device = device_and_backend[0] - sparse_dtype: torch.dtype = torch.int32 if device.type == "cuda" else torch.int64 is_rank_zero = int(os.environ.get("RANK", 0)) == 0 is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0 @@ -1070,7 +1069,7 @@ def predict( def _forward(batch: Batch) -> Tuple[Dict[str, torch.Tensor], RecordBatchTensor]: with torch.no_grad(): - parsed_inputs = batch.to_dict(sparse_dtype=sparse_dtype) + parsed_inputs = batch.to_dict(sparse_dtype=torch.int64) # when predicting with a model exported using INPUT_TILE, # we set the batch size tensor to 1 to disable tiling. 
parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64) diff --git a/tzrec/modules/embedding.py b/tzrec/modules/embedding.py index 6f4f7cc..cd24668 100644 --- a/tzrec/modules/embedding.py +++ b/tzrec/modules/embedding.py @@ -19,6 +19,15 @@ EmbeddingBagCollection, EmbeddingCollection, ) +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) +from torchrec.modules.mc_modules import ( + ManagedCollisionCollection, + ManagedCollisionModule, + MCHManagedCollisionModule, +) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor from tzrec.acc.utils import is_input_tile, is_input_tile_emb @@ -464,17 +473,14 @@ def predict( return values_list -def add_embedding_bag_config( +def _add_embedding_bag_config( emb_bag_configs: Dict[str, EmbeddingBagConfig], emb_bag_config: EmbeddingBagConfig ) -> None: - """Add embedding bag config . + """Add embedding bag config to a dict of embedding bag config. Args: - emb_bag_configs: Dict[str, EmbeddingBagConfig]: a dict contains emb_bag_configs - emb_bag_config: EmbeddingBagConfig - - Returns: - None + emb_bag_configs(Dict[str, EmbeddingBagConfig]): a dict contains emb_bag_configs + emb_bag_config(EmbeddingBagConfig): an instance of EmbeddingBagConfig """ if emb_bag_config.name in emb_bag_configs: existed_emb_bag_config = emb_bag_configs[emb_bag_config.name] @@ -494,17 +500,14 @@ def add_embedding_bag_config( emb_bag_configs[emb_bag_config.name] = emb_bag_config -def add_embedding_config( +def _add_embedding_config( emb_configs: Dict[str, EmbeddingConfig], emb_config: EmbeddingConfig ) -> None: - """Add embedding config . + """Add embedding config to a dict of embedding config. 
Args: - emb_configs: Dict[str, EmbeddingConfig]: a dict contains emb_configs - emb_config: EmbeddingConfig - - Returns: - None + emb_configs(Dict[str, EmbeddingConfig]): a dict contains emb_configs + emb_config(EmbeddingConfig): an instance of EmbeddingConfig """ if emb_config.name in emb_configs: existed_emb_config = emb_configs[emb_config.name] @@ -523,6 +526,28 @@ add_embedding_config( emb_configs[emb_config.name] = emb_config +def _add_mc_module( + mc_modules: Dict[str, ManagedCollisionModule], + emb_name: str, + mc_module: ManagedCollisionModule, +) -> None: + """Add ManagedCollisionModule to a dict of ManagedCollisionModule. + + Args: + mc_modules(Dict[str, ManagedCollisionModule]): a dict of ManagedCollisionModule. + emb_name(str): embedding_name. + mc_module(ManagedCollisionModule): an instance of ManagedCollisionModule. + """ + if emb_name in mc_modules: + existed_mc_module = mc_modules[emb_name] + if isinstance(mc_module, MCHManagedCollisionModule): + assert isinstance(existed_mc_module, MCHManagedCollisionModule) + assert mc_module._zch_size == existed_mc_module._zch_size + assert mc_module._eviction_interval == existed_mc_module._eviction_interval + assert repr(mc_module._eviction_policy) == repr(existed_mc_module._eviction_policy) + mc_modules[emb_name] = mc_module + + class EmbeddingGroupImpl(nn.Module): """Applies embedding lookup transformation for feature group. 
@@ -545,9 +570,13 @@ def __init__( device = torch.device("meta") name_to_feature = {x.name: x for x in features} emb_bag_configs = OrderedDict() + mc_emb_bag_configs = OrderedDict() + mc_modules = OrderedDict() self.has_sparse = False self.has_sparse_user = False + self.has_mc_sparse = False + self.has_mc_sparse_user = False self.has_dense = False self.has_dense_user = False @@ -580,6 +609,10 @@ def __init__( input_tile = is_input_tile() emb_bag_configs_user = OrderedDict() emb_bag_configs_item = OrderedDict() + mc_emb_bag_configs_user = OrderedDict() + mc_emb_bag_configs_item = OrderedDict() + mc_modules_item = OrderedDict() + mc_modules_user = OrderedDict() for feature_group in feature_groups: total_dim = 0 @@ -595,6 +628,7 @@ def __init__( if feature.is_sparse: output_dim = feature.output_dim emb_bag_config = feature.emb_bag_config + mc_module = feature.mc_module(device) assert emb_bag_config is not None if is_wide: # TODO(hongsheng.jhs): change to embedding_dim to 1 @@ -607,24 +641,45 @@ def __init__( if input_tile_emb: if feature.is_user_feat: - add_embedding_bag_config( - emb_bag_configs=emb_bag_configs_user, + _add_embedding_bag_config( + emb_bag_configs=mc_emb_bag_configs_user + if mc_module + else emb_bag_configs_user, emb_bag_config=emb_bag_config, ) + if mc_module: + _add_mc_module( + mc_modules_user, emb_bag_config.name, mc_module + ) + self.has_mc_sparse_user = True + else: + self.has_sparse_user = True else: - add_embedding_bag_config( - emb_bag_configs=emb_bag_configs_item, + _add_embedding_bag_config( + emb_bag_configs=mc_emb_bag_configs_item + if mc_module + else emb_bag_configs_item, emb_bag_config=emb_bag_config, ) + if mc_module: + _add_mc_module( + mc_modules_item, emb_bag_config.name, mc_module + ) + self.has_mc_sparse = True + else: + self.has_sparse = True else: - add_embedding_bag_config( - emb_bag_configs=emb_bag_configs, + _add_embedding_bag_config( + emb_bag_configs=mc_emb_bag_configs + if mc_module + else emb_bag_configs, 
emb_bag_config=emb_bag_config, ) - if input_tile_emb and feature.is_user_feat: - self.has_sparse_user = True - else: - self.has_sparse = True + if mc_module: + _add_mc_module(mc_modules, emb_bag_config.name, mc_module) + self.has_mc_sparse = True + else: + self.has_sparse = True if shared_feature_flag[name]: shared_name = shared_name + "@" + emb_bag_config.name @@ -656,10 +711,37 @@ def __init__( self.ebc_item = EmbeddingBagCollection( list(emb_bag_configs_item.values()), device=device ) + if self.has_mc_sparse_user: + self.mc_ebc_user = ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + list(mc_emb_bag_configs_user.values()), device=device + ), + ManagedCollisionCollection( + mc_modules_user, list(mc_emb_bag_configs_user.values()) + ), + ) + if self.has_mc_sparse: + self.mc_ebc_item = ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + list(mc_emb_bag_configs_item.values()), device=device + ), + ManagedCollisionCollection( + mc_modules_item, list(mc_emb_bag_configs_item.values()) + ), + ) else: self.ebc = EmbeddingBagCollection( list(emb_bag_configs.values()), device=device ) + if self.has_mc_sparse: + self.mc_ebc = ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + list(mc_emb_bag_configs.values()), device=device + ), + ManagedCollisionCollection( + mc_modules, list(mc_emb_bag_configs.values()) + ), + ) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group.""" @@ -689,6 +771,12 @@ def forward( else: kts.append(self.ebc(sparse_feature)) + if self.has_mc_sparse: + if is_input_tile_emb(): + kts.append(self.mc_ebc_item(sparse_feature)[0]) + else: + kts.append(self.mc_ebc(sparse_feature)[0]) + if self.has_sparse_user: keyed_tensor_user = self.ebc_user(sparse_feature_user) values_tile = keyed_tensor_user.values().tile(batch_size, 1) @@ -699,6 +787,16 @@ def forward( ) kts.append(keyed_tensor_user_tile) + if self.has_mc_sparse_user: + keyed_tensor_user = 
self.mc_ebc_user(sparse_feature_user)[0] + values_tile = keyed_tensor_user.values().tile(batch_size, 1) + keyed_tensor_user_tile = KeyedTensor( + keys=keyed_tensor_user.keys(), + length_per_key=keyed_tensor_user.length_per_key(), + values=values_tile, + ) + kts.append(keyed_tensor_user_tile) + if self.has_dense: kts.append(dense_feature) @@ -734,9 +832,13 @@ def __init__( device = torch.device("meta") name_to_feature = {x.name: x for x in features} dim_to_emb_configs = defaultdict(OrderedDict) + dim_to_mc_emb_configs = defaultdict(OrderedDict) + dim_to_mc_modules = defaultdict(OrderedDict) self.has_sparse = False self.has_sparse_user = False + self.has_mc_sparse = False + self.has_mc_sparse_user = False self.has_dense = False self.has_dense_user = False self.has_sequence_dense = False @@ -768,6 +870,10 @@ def __init__( need_input_tile_emb = is_input_tile_emb() dim_to_emb_configs_user = defaultdict(OrderedDict) dim_to_emb_configs_item = defaultdict(OrderedDict) + dim_to_mc_emb_configs_user = defaultdict(OrderedDict) + dim_to_mc_emb_configs_item = defaultdict(OrderedDict) + dim_to_mc_modules_user = defaultdict(OrderedDict) + dim_to_mc_modules_item = defaultdict(OrderedDict) for feature_group in feature_groups: query_dim = 0 @@ -785,35 +891,71 @@ def __init__( if feature.is_sparse: output_dim = feature.output_dim emb_config = feature.emb_config + mc_module = feature.mc_module(device) + assert emb_config is not None # we may/could modify ec name at feat_to_group_to_emb_name emb_config.name = feat_to_group_to_emb_name[name][group_name] - assert emb_config is not None - emb_configs = dim_to_emb_configs[emb_config.embedding_dim] + embedding_dim = emb_config.embedding_dim if need_input_tile_emb: if feature.is_user_feat: - add_embedding_config( - emb_configs=dim_to_emb_configs_user[ - emb_config.embedding_dim - ], + emb_configs = ( + dim_to_mc_emb_configs_user[embedding_dim] + if mc_module + else dim_to_emb_configs_user[embedding_dim] + ) + _add_embedding_config( + 
emb_configs=emb_configs, emb_config=emb_config, ) + if mc_module: + _add_mc_module( + dim_to_mc_modules_user[embedding_dim], + emb_config.name, + mc_module, + ) + self.has_mc_sparse_user = True + else: + self.has_sparse_user = True else: - add_embedding_config( - emb_configs=dim_to_emb_configs_item[ - emb_config.embedding_dim - ], + emb_configs = ( + dim_to_mc_emb_configs_item[embedding_dim] + if mc_module + else dim_to_emb_configs_item[embedding_dim] + ) + _add_embedding_config( + emb_configs=emb_configs, emb_config=emb_config, ) + if mc_module: + _add_mc_module( + dim_to_mc_modules_item[embedding_dim], + emb_config.name, + mc_module, + ) + self.has_mc_sparse = True + else: + self.has_sparse = True else: - add_embedding_config( + emb_configs = ( + dim_to_mc_emb_configs[embedding_dim] + if mc_module + else dim_to_emb_configs[embedding_dim] + ) + _add_embedding_config( emb_configs=emb_configs, emb_config=emb_config, ) - if need_input_tile_emb and feature.is_user_feat: - self.has_sparse_user = True - else: - self.has_sparse = True + if mc_module: + _add_mc_module( + dim_to_mc_modules[embedding_dim], + emb_config.name, + mc_module, + ) + self.has_mc_sparse = True + else: + self.has_sparse = True + if shared_feature_flag[name]: shared_name = shared_name + "@" + emb_config.name else: @@ -857,12 +999,42 @@ def __init__( self.ec_list_item.append( EmbeddingCollection(list(emb_configs.values()), device=device) ) + self.mc_ec_list_user = nn.ModuleList() + for k, emb_configs in dim_to_mc_emb_configs_user.items(): + self.mc_ec_list_user.append( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection(list(emb_configs.values()), device=device), + ManagedCollisionCollection( + dim_to_mc_modules_user[k], list(emb_configs.values()) + ), + ) + ) + self.mc_ec_list_item = nn.ModuleList() + for k, emb_configs in dim_to_mc_emb_configs_item.items(): + self.mc_ec_list_item.append( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection(list(emb_configs.values()), device=device), 
+ ManagedCollisionCollection( + dim_to_mc_modules_item[k], list(emb_configs.values()) + ), + ) + ) else: self.ec_list = nn.ModuleList() for _, emb_configs in dim_to_emb_configs.items(): self.ec_list.append( EmbeddingCollection(list(emb_configs.values()), device=device) ) + self.mc_ec_list = nn.ModuleList() + for k, emb_configs in dim_to_mc_emb_configs.items(): + self.mc_ec_list.append( + ManagedCollisionEmbeddingCollection( + EmbeddingCollection(list(emb_configs.values()), device=device), + ManagedCollisionCollection( + dim_to_mc_modules[k], list(emb_configs.values()) + ), + ) + ) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group.""" @@ -911,10 +1083,22 @@ def forward( for ec in self.ec_list: sparse_jt_dict_list.append(ec(sparse_feature)) + if self.has_mc_sparse: + if need_input_tile_emb: + for ec in self.mc_ec_list_item: + sparse_jt_dict_list.append(ec(sparse_feature)[0]) + else: + for ec in self.mc_ec_list: + sparse_jt_dict_list.append(ec(sparse_feature)[0]) + if self.has_sparse_user: for ec in self.ec_list_user: sparse_jt_dict_list.append(ec(sparse_feature_user)) + if self.has_mc_sparse_user: + for ec in self.mc_ec_list_user: + sparse_jt_dict_list.append(ec(sparse_feature_user)[0]) + sparse_jt_dict = _merge_list_of_jt_dict(sparse_jt_dict_list) if self.has_dense: diff --git a/tzrec/modules/embedding_test.py b/tzrec/modules/embedding_test.py index 265a98d..5ec1b87 100644 --- a/tzrec/modules/embedding_test.py +++ b/tzrec/modules/embedding_test.py @@ -31,16 +31,28 @@ from tzrec.utils.test_util import TestGraphType, create_test_module -def _create_test_features(): +def _create_test_features(has_zch=False): + cat_a_kwargs = {} + cat_b_kwargs = {} + if has_zch: + cat_a_kwargs["zch"] = feature_pb2.ZeroCollisionHash( + zch_size=100, lfu=feature_pb2.LFU_EvictionPolicy() + ) + cat_b_kwargs["zch"] = feature_pb2.ZeroCollisionHash( + zch_size=1000, lru=feature_pb2.LRU_EvictionPolicy() + ) + else: + 
cat_a_kwargs["num_buckets"] = 100
+        cat_b_kwargs["num_buckets"] = 1000
     feature_cfgs = [
         feature_pb2.FeatureConfig(
             id_feature=feature_pb2.IdFeature(
-                feature_name="cat_a", embedding_dim=16, num_buckets=100
+                feature_name="cat_a", embedding_dim=16, **cat_a_kwargs
             )
         ),
         feature_pb2.FeatureConfig(
             id_feature=feature_pb2.IdFeature(
-                feature_name="cat_b", embedding_dim=8, num_buckets=1000
+                feature_name="cat_b", embedding_dim=8, **cat_b_kwargs
             )
         ),
         feature_pb2.FeatureConfig(
@@ -56,22 +68,34 @@ def _create_test_sequence_features():
     return features


-def _create_test_sequence_features():
+def _create_test_sequence_features(has_zch=False):
+    cat_a_kwargs = {}
+    cat_b_kwargs = {}
+    if has_zch:
+        cat_a_kwargs["zch"] = feature_pb2.ZeroCollisionHash(
+            zch_size=100, lfu=feature_pb2.LFU_EvictionPolicy()
+        )
+        cat_b_kwargs["zch"] = feature_pb2.ZeroCollisionHash(
+            zch_size=1000, lru=feature_pb2.LRU_EvictionPolicy()
+        )
+    else:
+        cat_a_kwargs["num_buckets"] = 100
+        cat_b_kwargs["num_buckets"] = 1000
     feature_cfgs = [
         feature_pb2.FeatureConfig(
             id_feature=feature_pb2.IdFeature(
                 feature_name="cat_a",
                 embedding_dim=16,
-                num_buckets=100,
                 expression="item:cat_a",
+                **cat_a_kwargs,
             )
         ),
         feature_pb2.FeatureConfig(
             id_feature=feature_pb2.IdFeature(
                 feature_name="cat_b",
                 embedding_dim=8,
-                num_buckets=1000,
                 expression="item:cat_b",
+                **cat_b_kwargs,
             )
         ),
         feature_pb2.FeatureConfig(
@@ -88,7 +112,7 @@ def _create_test_sequence_features():
                 feature_name="cat_a",
                 expression="item:cat_a",
                 embedding_dim=16,
-                num_buckets=100,
+                **cat_a_kwargs,
             )
         ),
         feature_pb2.SeqFeatureConfig(
@@ -96,7 +120,7 @@
                 feature_name="cat_b",
                 expression="item:cat_b",
                 embedding_dim=8,
-                num_buckets=1000,
+                **cat_b_kwargs,
             )
         ),
         feature_pb2.SeqFeatureConfig(
@@ -117,7 +141,7 @@
                 feature_name="cat_a",
                 expression="item:cat_a",
                 embedding_dim=16,
-                num_buckets=100,
+                **cat_a_kwargs,
             )
         ),
         feature_pb2.SeqFeatureConfig(
@@ -172,7 +196,7 @@ def tearDown(self):
[[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] ) def test_embedding_group_impl(self, graph_type) -> None: - features = _create_test_sequence_features() + features = _create_test_features() feature_groups = [ model_pb2.FeatureGroupConfig( group_name="wide", @@ -183,27 +207,6 @@ def test_embedding_group_impl(self, graph_type) -> None: group_name="deep", feature_names=["cat_a", "cat_b", "int_a"], group_type=model_pb2.FeatureGroupType.DEEP, - sequence_groups=[ - model_pb2.SeqGroupConfig( - group_name="click", - feature_names=[ - "cat_a", - "cat_b", - "int_a", - "click_seq__cat_a", - "click_seq__cat_b", - "click_seq__int_a", - ], - ), - ], - sequence_encoders=[ - seq_encoder_pb2.SeqEncoderConfig( - din_encoder=seq_encoder_pb2.DINEncoder( - input="click", - attn_mlp=module_pb2.MLP(hidden_units=[128, 64]), - ) - ), - ], ), ] embedding_group = EmbeddingGroupImpl( @@ -221,20 +224,12 @@ def test_embedding_group_impl(self, graph_type) -> None: self.assertDictEqual( embedding_group.group_feature_dims("deep"), deep_feature_dims ) - embedding_group = create_test_module(embedding_group, graph_type) sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=[ - "cat_a", - "cat_b", - "click_seq__cat_a", - "click_seq__cat_b", - "buy_seq__cat_a", - "buy_seq__cat_b", - ], - values=torch.tensor(list(range(24))), - lengths=torch.tensor([1, 1, 1, 1, 3, 3, 3, 3, 2, 2, 2, 2]), + keys=["cat_a", "cat_b"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7]), + lengths=torch.tensor([1, 2, 1, 3]), ) dense_feature = KeyedTensor.from_tensor_list( keys=["int_a"], tensors=[torch.tensor([[0.2], [0.3]])] @@ -252,8 +247,58 @@ def test_embedding_group_impl(self, graph_type) -> None: @parameterized.expand( [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] ) - def test_sequence_embedding_group_impl(self, graph_type) -> None: - features = _create_test_sequence_features() + def test_zch_embedding_group_impl(self, graph_type) -> None: + features = 
_create_test_features(has_zch=True) + # TODO(hongsheng.jhs) zch not support wide group now. + feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name="deep", + feature_names=["cat_a", "cat_b", "int_a"], + group_type=model_pb2.FeatureGroupType.DEEP, + ), + ] + embedding_group = EmbeddingGroupImpl( + features, feature_groups, device=torch.device("cpu") + ) + self.assertEqual(embedding_group.group_dims("deep"), [16, 8, 1]) + self.assertEqual(embedding_group.group_total_dim("deep"), 25) + deep_feature_dims = OrderedDict({"cat_a": 16, "cat_b": 8, "int_a": 1}) + self.assertDictEqual( + embedding_group.group_feature_dims("deep"), deep_feature_dims + ) + + if graph_type != TestGraphType.NORMAL: + embedding_group.eval() + embedding_group = create_test_module(embedding_group, graph_type) + + sparse_feature = KeyedJaggedTensor.from_lengths_sync( + keys=["cat_a", "cat_b"], + values=torch.tensor([1, 2, 3, 4, 5, 6, 7]), + lengths=torch.tensor([1, 2, 1, 3]), + ) + dense_feature = KeyedTensor.from_tensor_list( + keys=["int_a"], tensors=[torch.tensor([[0.2], [0.3]])] + ) + result = embedding_group( + sparse_feature, + dense_feature, + EMPTY_KJT, + EMPTY_KT, + ) + self.assertEqual(result["deep"].size(), (2, 25)) + + @parameterized.expand( + [ + [TestGraphType.NORMAL, False], + [TestGraphType.FX_TRACE, False], + [TestGraphType.JIT_SCRIPT, False], + [TestGraphType.NORMAL, True], + [TestGraphType.FX_TRACE, True], + [TestGraphType.JIT_SCRIPT, True], + ] + ) + def test_sequence_embedding_group_impl(self, graph_type, has_zch=False) -> None: + features = _create_test_sequence_features(has_zch) feature_groups = [ model_pb2.FeatureGroupConfig( group_name="click", @@ -341,6 +386,8 @@ def test_sequence_embedding_group_impl(self, graph_type) -> None: embedding_group.group_total_dim("deep___click_no_query.query"), 0 ) + if has_zch and graph_type != TestGraphType.NORMAL: + embedding_group = embedding_group.eval() embedding_group = create_test_module(embedding_group, graph_type) 
sparse_feature = KeyedJaggedTensor.from_lengths_sync( @@ -563,11 +610,20 @@ def test_embedding_group(self, graph_type) -> None: self.assertEqual(result["buy.sequence_length"].size(), (2,)) @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] + [ + [TestGraphType.NORMAL, False], + [TestGraphType.FX_TRACE, False], + [TestGraphType.JIT_SCRIPT, False], + [TestGraphType.NORMAL, True], + [TestGraphType.FX_TRACE, True], + [TestGraphType.JIT_SCRIPT, True], + ] ) - def test_sequence_embedding_group_impl_input_tile(self, graph_type) -> None: + def test_sequence_embedding_group_impl_input_tile( + self, graph_type, has_zch=False + ) -> None: os.environ["INPUT_TILE"] = "2" - features = _create_test_sequence_features() + features = _create_test_sequence_features(has_zch=has_zch) feature_groups = [ model_pb2.FeatureGroupConfig( group_name="click", @@ -599,6 +655,8 @@ def test_sequence_embedding_group_impl_input_tile(self, graph_type) -> None: self.assertEqual(embedding_group.group_total_dim("buy.sequence"), 17) self.assertEqual(embedding_group.group_total_dim("buy.query"), 17) + if has_zch and graph_type != TestGraphType.NORMAL: + embedding_group = embedding_group.eval() embedding_group = create_test_module(embedding_group, graph_type) sparse_feature = KeyedJaggedTensor.from_lengths_sync( @@ -638,11 +696,20 @@ def test_sequence_embedding_group_impl_input_tile(self, graph_type) -> None: self.assertEqual(result["buy.sequence_length"].size(), (2,)) @parameterized.expand( - [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] + [ + [TestGraphType.NORMAL, False], + [TestGraphType.FX_TRACE, False], + [TestGraphType.JIT_SCRIPT, False], + [TestGraphType.NORMAL, True], + [TestGraphType.FX_TRACE, True], + [TestGraphType.JIT_SCRIPT, True], + ] ) - def test_sequence_embedding_group_impl_input_tile_emb(self, graph_type) -> None: + def test_sequence_embedding_group_impl_input_tile_emb( + self, graph_type, 
has_zch=False
+    ) -> None:
         os.environ["INPUT_TILE"] = "3"
-        features = _create_test_sequence_features()
+        features = _create_test_sequence_features(has_zch=has_zch)
         feature_groups = [
             model_pb2.FeatureGroupConfig(
                 group_name="click",
@@ -674,6 +741,8 @@ def test_sequence_embedding_group_impl_input_tile_emb(self, graph_type) -> None:
         self.assertEqual(embedding_group.group_total_dim("buy.sequence"), 17)
         self.assertEqual(embedding_group.group_total_dim("buy.query"), 17)

+        if has_zch and graph_type != TestGraphType.NORMAL:
+            embedding_group = embedding_group.eval()
         embedding_group = create_test_module(embedding_group, graph_type)

         sparse_feature = KeyedJaggedTensor.from_lengths_sync(
diff --git a/tzrec/protos/feature.proto b/tzrec/protos/feature.proto
index c2d8a30..ce580dc 100644
--- a/tzrec/protos/feature.proto
+++ b/tzrec/protos/feature.proto
@@ -2,6 +2,41 @@
 syntax = "proto2";

 package tzrec.protos;

+// LFU: evict_score = access_cnt
+message LFU_EvictionPolicy {
+}
+
+// LRU: evict_score = 1 / pow((current_iter - last_access_iter), decay_exponent)
+message LRU_EvictionPolicy {
+  // exponent applied to the access-step distance (current_iter - last_access_iter)
+  optional float decay_exponent = 1 [default = 1.0];
+}
+
+// DistanceLFU: evict_score = access_cnt / pow((current_iter - last_access_iter), decay_exponent)
+message DistanceLFU_EvictionPolicy {
+  // exponent applied to the access-step distance (current_iter - last_access_iter)
+  optional float decay_exponent = 1 [default = 1.0];
+}
+
+message ZeroCollisionHash {
+  // zero-collision vocabulary size (number of embedding rows)
+  required uint64 zch_size = 1;
+  // eviction interval in training steps
+  optional uint32 eviction_interval = 2 [default = 5];
+  // eviction policy
+  oneof eviction_policy {
+    LFU_EvictionPolicy lfu = 101;
+    LRU_EvictionPolicy lru = 102;
+    DistanceLFU_EvictionPolicy distance_lfu = 103;
+  }
+  // lambda function string used to filter incoming ids before update/eviction. experimental feature.
+  // [input: Tensor] the function takes a 1-d tensor of unique id counts as input.
+  // [output1: Tensor] the function returns a boolean mask or index array of the elements in the input tensor that pass the filter.
+  // [output2: float, Tensor] the function returns the threshold used to filter ids before update/eviction; all values <= this threshold are filtered out.
+  optional string threshold_filtering_func = 3;
+}
+
+
 message IdFeature {
   // feature name, e.g. item_id
   required string feature_name = 1;
@@ -34,6 +69,8 @@ message IdFeature {
   optional string init_fn = 14;
   // mask value in training progress
   optional bool use_mask = 15;
+  // zero collision hash
+  optional ZeroCollisionHash zch = 16;

   // default value when fg_encoded = true,
   // when use pai-fg, you do not need to set the param.
@@ -108,6 +145,8 @@ message ComboFeature {
   optional string init_fn = 13;
   // mask value in training progress
   optional bool use_mask = 14;
+  // zero collision hash
+  optional ZeroCollisionHash zch = 15;

   // default value when fg_encoded = true,
   // when use pai-fg, you do not need to set the param.
@@ -166,6 +205,8 @@ message LookupFeature {
   // mask value in training progress
   optional bool use_mask = 25;

+  // zero collision hash
+  optional ZeroCollisionHash zch = 26;
   // default value when fg_encoded = true,
   // when use pai-fg, you do not need to set the param.
@@ -226,6 +267,8 @@ message MatchFeature {
   optional uint32 value_dim = 20;
   // mask value in training progress
   optional bool use_mask = 25;
+  // zero collision hash
+  optional ZeroCollisionHash zch = 26;

   // default value when fg_encoded = true,
   // when use pai-fg, you do not need to set the param.
@@ -393,6 +436,8 @@ message SequenceIdFeature {
   optional uint32 value_dim = 15 [default = 1];
   // mask value in training progress
   optional bool use_mask = 20;
+  // zero collision hash
+  optional ZeroCollisionHash zch = 21;

   // default value when fg_encoded = true,
   // when use pai-fg, you do not need to set the param.
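The three eviction policies above only state their scoring formulas in comments. A rough, stdlib-only sketch of that scoring logic follows; the names `evict_score` and `evict_candidates` and the dict-based stats are illustrative, not tzrec or torchrec APIs — the real managed-collision modules track these statistics in tensors and also handle id admission.

```python
import math


def evict_score(policy, access_cnt, current_iter, last_access_iter, decay_exponent=1.0):
    """Score one id under a policy; ids with the lowest scores are evicted first."""
    # Clamp the distance so an id accessed this step does not divide by zero.
    distance = max(current_iter - last_access_iter, 1)
    if policy == "lfu":
        # LFU: evict_score = access_cnt
        return float(access_cnt)
    if policy == "lru":
        # LRU: evict_score = 1 / pow(distance, decay_exponent)
        return 1.0 / math.pow(distance, decay_exponent)
    if policy == "distance_lfu":
        # DistanceLFU: evict_score = access_cnt / pow(distance, decay_exponent)
        return access_cnt / math.pow(distance, decay_exponent)
    raise ValueError(f"unknown eviction policy: {policy}")


def evict_candidates(policy, stats, current_iter, zch_size):
    """Ids to drop when more than `zch_size` ids are tracked.

    `stats` maps id -> (access_cnt, last_access_iter).
    """
    ranked = sorted(
        stats, key=lambda i: evict_score(policy, stats[i][0], current_iter, stats[i][1])
    )
    return ranked[: max(len(stats) - zch_size, 0)]


stats = {1001: (50, 9), 1002: (6, 10), 1003: (40, 2)}
print(evict_candidates("lfu", stats, current_iter=10, zch_size=2))           # [1002]
print(evict_candidates("distance_lfu", stats, current_iter=10, zch_size=2))  # [1003]
```

The example shows why DistanceLFU exists: plain LFU evicts the rarely seen but fresh id 1002, while DistanceLFU evicts the frequently seen but stale id 1003; raising `decay_exponent` discounts stale ids more aggressively, pushing the behavior toward LRU.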
diff --git a/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config b/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config new file mode 100644 index 0000000..0789a39 --- /dev/null +++ b/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config @@ -0,0 +1,358 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/multi_tower_din_mock_fg" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: ParquetDataset + fg_encoded: false + label_fields: "clk" + num_workers: 8 +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + zch { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + zch { + zch_size: 10000 + eviction_interval: 2 + lru {} + } + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "id_3" + expression: "item:id_3" + vocab_list: ["a", "b", "c"] + embedding_dim: 8 + } +} +feature_configs { + id_feature { + feature_name: "id_4" + expression: "item:id_4" + hash_bucket_size: 100 + embedding_dim: 16 + embedding_name: "id_4_emb" + } +} +feature_configs { + id_feature { + feature_name: "id_5" + expression: "item:id_5" + hash_bucket_size: 100 + embedding_dim: 16 + embedding_name: "id_4_emb" + } +} +feature_configs { + raw_feature { + feature_name: "raw_1" + expression: "item:raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "raw_2" + expression: "item:raw_2" + } +} +feature_configs { + raw_feature { + feature_name: "raw_3" + expression: "user:raw_3" + value_dim: 4 + } +} +feature_configs { + raw_feature { + feature_name: "raw_4" + expression: "user:raw_4" + value_dim: 4 + 
boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "raw_5" + expression: "user:raw_5" + } +} +feature_configs { + raw_feature { + feature_name: "raw_6_id" + expression: "item:raw_6" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + combo_feature { + feature_name: "combo_1" + expression: ["user:id_1", "item:id_2"] + hash_bucket_size: 1000000 + embedding_dim: 16 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_1" + map: "user:map_1" + key: "item:id_2" + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_2" + map: "user:map_2" + key: "item:id_2" + hash_bucket_size: 10000 + embedding_dim: 8 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_3" + map: "user:map_3" + key: "item:id_2" + num_buckets: 1000 + embedding_dim: 8 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_4" + map: "user:map_4" + key: "item:id_2" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_5" + map: "user:map_5" + key: "feature:raw_6_id" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + match_feature { + feature_name: "match_1" + nested_map: "user:nested_map_1" + pkey: "item:id_2" + skey: "item:id_3" + } +} +feature_configs { + match_feature { + feature_name: "match_2" + nested_map: "user:nested_map_2" + pkey: "item:id_2" + skey: "item:id_3" + hash_bucket_size: 100000 + embedding_dim: 16 + } +} +feature_configs { + match_feature { + feature_name: "match_3" + nested_map: "user:nested_map_3" + pkey: "item:id_2" + skey: "item:id_3" + num_buckets: 10000 + embedding_dim: 8 + } +} +feature_configs { + match_feature { + feature_name: "match_4" + nested_map: "user:nested_map_4" + pkey: "item:id_2" + skey: "item:id_3" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + expr_feature { + feature_name: "expr_1" + expression: 
"raw_1 + raw_2" + variables: ["item:raw_1", "item:raw_2"] + } +} +feature_configs { + sequence_feature { + sequence_name: "click_50_seq" + sequence_length: 50 + sequence_delim: "|" + features { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + zch { + zch_size: 10000 + eviction_interval: 2 + lru {} + } + embedding_dim: 16 + } + } + features { + id_feature { + feature_name: "id_3" + expression: "item:id_3" + vocab_list: ["a", "b", "c"] + embedding_dim: 8 + } + } + features { + raw_feature { + feature_name: "raw_1" + expression: "item:raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } + } + features { + raw_feature { + feature_name: "raw_2" + expression: "item:raw_2" + } + } + } +} +feature_configs { + sequence_id_feature { + feature_name: "buy_50_user_id_seq" + sequence_length: 50 + sequence_delim: "|" + expression: "item:buy_50_user_id_seq" + zch { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + embedding_dim: 16 + } +} +feature_configs { + sequence_raw_feature { + feature_name: "buy_50_raw_5_seq" + sequence_length: 50 + sequence_delim: "|" + expression: "item:buy_50_raw_5_seq" + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "user_id" + feature_names: "item_id" + feature_names: "id_3" + feature_names: "id_4" + feature_names: "id_5" + feature_names: "raw_1" + feature_names: "raw_2" + feature_names: "raw_3" + feature_names: "raw_4" + feature_names: "raw_5" + feature_names: "combo_1" + feature_names: "lookup_1" + feature_names: "lookup_2" + feature_names: "lookup_3" + feature_names: "lookup_4" + feature_names: "lookup_5" + feature_names: "match_1" + feature_names: "match_2" + feature_names: "match_3" + feature_names: "match_4" + feature_names: "expr_1" + group_type: DEEP + } + feature_groups { + group_name: "seq" + feature_names: "item_id" + feature_names: "id_3" + feature_names: "raw_1" + feature_names: "raw_2" + feature_names: "click_50_seq__item_id" + feature_names: "click_50_seq__id_3" 
+ feature_names: "click_50_seq__raw_1" + feature_names: "click_50_seq__raw_2" + group_type: SEQUENCE + } + feature_groups { + group_name: "seq_item" + feature_names: "user_id" + feature_names: "raw_5" + feature_names: "buy_50_user_id_seq" + feature_names: "buy_50_raw_5_seq" + group_type: SEQUENCE + } + multi_tower_din { + towers { + input: 'deep' + mlp { + hidden_units: [512, 256, 128] + } + } + din_towers { + input: 'seq' + attn_mlp { + hidden_units: [256, 64] + } + } + din_towers { + input: 'seq_item' + attn_mlp { + hidden_units: [256, 64] + } + } + final { + hidden_units: [64] + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config b/tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config new file mode 100644 index 0000000..b1578d7 --- /dev/null +++ b/tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config @@ -0,0 +1,358 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/multi_tower_din_mock_fg" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: ParquetDataset + fg_encoded: false + label_fields: "clk" + num_workers: 8 +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + zch { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + zch { + zch_size: 10000 + eviction_interval: 2 + lru {} + } + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "id_3" + expression: "item:id_3" + vocab_list: ["a", "b", "c"] + embedding_dim: 8 + } +} +feature_configs { + id_feature { + feature_name: "id_4" + expression: "item:id_4" + hash_bucket_size: 100 
+ embedding_dim: 16 + embedding_name: "id_4_emb" + } +} +feature_configs { + id_feature { + feature_name: "id_5" + expression: "item:id_5" + hash_bucket_size: 100 + embedding_dim: 16 + embedding_name: "id_4_emb" + } +} +feature_configs { + raw_feature { + feature_name: "raw_1" + expression: "item:raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "raw_2" + expression: "item:raw_2" + } +} +feature_configs { + raw_feature { + feature_name: "raw_3" + expression: "user:raw_3" + value_dim: 4 + } +} +feature_configs { + raw_feature { + feature_name: "raw_4" + expression: "user:raw_4" + value_dim: 4 + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "raw_5" + expression: "user:raw_5" + } +} +feature_configs { + raw_feature { + feature_name: "raw_6_id" + expression: "item:raw_6" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + combo_feature { + feature_name: "combo_1" + expression: ["user:id_1", "item:id_2"] + hash_bucket_size: 1000000 + embedding_dim: 16 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_1" + map: "user:map_1" + key: "item:id_2" + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_2" + map: "user:map_2" + key: "item:id_2" + hash_bucket_size: 10000 + embedding_dim: 8 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_3" + map: "user:map_3" + key: "item:id_2" + num_buckets: 1000 + embedding_dim: 8 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_4" + map: "user:map_4" + key: "item:id_2" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + lookup_feature { + feature_name: "lookup_5" + map: "user:map_5" + key: "feature:raw_6_id" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + match_feature { + feature_name: "match_1" + nested_map: "user:nested_map_1" + pkey: "item:id_2" + 
skey: "item:id_3" + } +} +feature_configs { + match_feature { + feature_name: "match_2" + nested_map: "user:nested_map_2" + pkey: "item:id_2" + skey: "item:id_3" + hash_bucket_size: 100000 + embedding_dim: 16 + } +} +feature_configs { + match_feature { + feature_name: "match_3" + nested_map: "user:nested_map_3" + pkey: "item:id_2" + skey: "item:id_3" + num_buckets: 10000 + embedding_dim: 8 + } +} +feature_configs { + match_feature { + feature_name: "match_4" + nested_map: "user:nested_map_4" + pkey: "item:id_2" + skey: "item:id_3" + vocab_list: ["e", "f", "g"] + embedding_dim: 16 + } +} +feature_configs { + expr_feature { + feature_name: "expr_1" + expression: "raw_1 + raw_2" + variables: ["item:raw_1", "item:raw_2"] + } +} +feature_configs { + sequence_feature { + sequence_name: "click_50_seq" + sequence_length: 50 + sequence_delim: "|" + features { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + zch { + zch_size: 10000 + eviction_interval: 2 + lru {} + } + embedding_dim: 16 + } + } + features { + id_feature { + feature_name: "id_3" + expression: "item:id_3" + vocab_list: ["a", "b", "c"] + embedding_dim: 8 + } + } + features { + raw_feature { + feature_name: "raw_1" + expression: "item:raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } + } + features { + raw_feature { + feature_name: "raw_2" + expression: "item:raw_2" + } + } + } +} +feature_configs { + sequence_id_feature { + feature_name: "buy_50_user_id_seq" + sequence_length: 50 + sequence_delim: "|" + expression: "item:buy_50_user_id_seq" + zch { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + embedding_dim: 16 + } +} +feature_configs { + sequence_raw_feature { + feature_name: "buy_50_raw_5_seq" + sequence_length: 50 + sequence_delim: "|" + expression: "item:buy_50_raw_5_seq" + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "user_id" + feature_names: "item_id" + feature_names: "id_3" + feature_names: "id_4" + feature_names: "id_5" 
+ feature_names: "raw_1" + feature_names: "raw_2" + feature_names: "raw_3" + feature_names: "raw_4" + feature_names: "raw_5" + feature_names: "combo_1" + feature_names: "lookup_1" + feature_names: "lookup_2" + feature_names: "lookup_3" + feature_names: "lookup_4" + feature_names: "lookup_5" + feature_names: "match_1" + feature_names: "match_2" + feature_names: "match_3" + feature_names: "match_4" + feature_names: "expr_1" + group_type: DEEP + } + feature_groups { + group_name: "seq" + feature_names: "item_id" + feature_names: "id_3" + feature_names: "raw_1" + feature_names: "raw_2" + feature_names: "click_50_seq__item_id" + feature_names: "click_50_seq__id_3" + feature_names: "click_50_seq__raw_1" + feature_names: "click_50_seq__raw_2" + group_type: SEQUENCE + } + feature_groups { + group_name: "seq_item" + feature_names: "user_id" + feature_names: "raw_5" + feature_names: "buy_50_user_id_seq" + feature_names: "buy_50_raw_5_seq" + group_type: SEQUENCE + } + multi_tower_din_trt { + towers { + input: 'deep' + mlp { + hidden_units: [512, 256, 128] + } + } + din_towers { + input: 'seq' + attn_mlp { + hidden_units: [256, 64] + } + } + din_towers { + input: 'seq_item' + attn_mlp { + hidden_units: [256, 64] + } + } + final { + hidden_units: [64] + } + } + metrics { + auc {} + } + losses { + binary_cross_entropy {} + } +} diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py new file mode 100644 index 0000000..442c51a --- /dev/null +++ b/tzrec/tests/match_integration_test.py @@ -0,0 +1,256 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest + +from tzrec.tests import utils + + +class MatchIntegrationTest(unittest.TestCase): + def setUp(self): + self.success = False + if not os.path.exists("./tmp"): + os.makedirs("./tmp") + self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") + os.chmod(self.test_dir, 0o755) + + def tearDown(self): + if self.success: + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def test_dssm_fg_encoded_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_mock.config", self.test_dir, item_id="item_id" + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "predict_result"), + reserved_columns="item_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) + ) + + def 
test_dssm_fg_encoded_variational_dropout(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_variational_dropout_mock.config", + self.test_dir, + item_id="item_id", + ) + if self.success: + self.success = utils.test_feature_selection( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "output_dir/pipeline.config")) + ) + + def test_dssm_with_fg_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_fg_mock.config", + self.test_dir, + user_id="user_id", + item_id="item_id", + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "item_emb"), + reserved_columns="item_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_create_faiss_index( + embedding_input_path=os.path.join( + self.test_dir, r"item_emb/\*.parquet" + ), + index_output_dir=os.path.join(self.test_dir, "export/user"), + id_field="item_id", + embedding_field="item_tower_emb", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/user"), + predict_input_path=os.path.join(self.test_dir, r"user_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "user_emb"), + reserved_columns="user_id,click_50_seq__item_id", + output_columns="user_tower_emb", + 
test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_hitrate( + user_gt_input=os.path.join(self.test_dir, r"user_emb/\*.parquet"), + item_embedding_input=os.path.join( + self.test_dir, r"item_emb/\*.parquet" + ), + total_hitrate_output=os.path.join(self.test_dir, "total_hitrate"), + hitrate_details_output=os.path.join(self.test_dir, "hitrate_details"), + request_id_field="user_id", + gt_items_field="click_50_seq__item_id", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_create_fg_json( + os.path.join(self.test_dir, "pipeline.config"), + fg_output_dir=os.path.join(self.test_dir, "fg_output"), + reserves="clk", + test_dir=self.test_dir, + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/faiss_index")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/id_mapping")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "fg_output/fg.json")) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.test_dir, "fg_output/item_title_tokenizer.json") + ) + ) + + def test_dssm_v2_with_fg_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dssm_v2_fg_mock.config", + self.test_dir, + user_id="user_id", + item_id="item_id", + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, 
"export/item/scripted_model.pt")) + ) + + def test_tdm_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/tdm_fg_mock.config", + self.test_dir, + user_id="user_id", + item_id="item_id", + cate_id="cate_id", + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + asset_files=os.path.join(self.test_dir, "init_tree/serving_tree"), + ) + if self.success: + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/embedding"), + predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), + predict_output_path=os.path.join(self.test_dir, "item_emb"), + reserved_columns="item_id,cate_id,id_4,id_5,raw_1,raw_2", + output_columns="item_emb", + test_dir=self.test_dir, + ) + if self.success: + self.success = utils.test_tdm_cluster_train_eval( + pipeline_config_path=os.path.join(self.test_dir, "pipeline.config"), + test_dir=self.test_dir, + item_input_path=os.path.join(self.test_dir, r"item_emb/\*.parquet"), + item_id="item_id", + embedding_field="item_emb", + ) + if self.success: + self.success = utils.test_tdm_retrieval( + scripted_model_path=os.path.join(self.test_dir, "export/model"), + eval_data_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), + retrieval_output_path=os.path.join(self.test_dir, "retrieval_result"), + reserved_columns="user_id,item_id", + test_dir=self.test_dir, + ) + + self.assertTrue(self.success) + self.assertTrue(os.path.exists(os.path.join(self.test_dir, "learnt_tree"))) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.test_dir, "export/embedding/scripted_model.pt") + ) + ) + self.assertTrue( + os.path.exists( + os.path.join(self.test_dir, 
"export/model/scripted_model.pt") + ) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/model/serving_tree")) + ) + self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result"))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/rank_integration_test.py similarity index 71% rename from tzrec/tests/train_eval_export_test.py rename to tzrec/tests/rank_integration_test.py index 54b8bfd..6084fe4 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/rank_integration_test.py @@ -25,7 +25,7 @@ from tzrec.utils.test_util import dfs_are_close -class TrainEvalExportTest(unittest.TestCase): +class RankIntegrationTest(unittest.TestCase): def setUp(self): self.success = False if not os.path.exists("./tmp"): @@ -41,10 +41,10 @@ def tearDown(self): os.environ.pop("INPUT_TILE", None) os.environ.pop("ENABLE_TRT", None) - def test_multi_tower_din_fg_encoded_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/multi_tower_din_mock.config", self.test_dir - ) + def _test_rank_fg_encoded( + self, pipeline_config_path, reserved_columns, output_columns + ): + self.success = utils.test_train_eval(pipeline_config_path, self.test_dir) if self.success: self.success = utils.test_eval( os.path.join(self.test_dir, "pipeline.config"), self.test_dir @@ -58,8 +58,8 @@ def test_multi_tower_din_fg_encoded_train_eval_export(self): scripted_model_path=os.path.join(self.test_dir, "export"), predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), predict_output_path=os.path.join(self.test_dir, "predict_result"), - reserved_columns="clk", - output_columns="probs", + reserved_columns=reserved_columns, + output_columns=output_columns, test_dir=self.test_dir, ) self.assertTrue(self.success) @@ -70,6 +70,20 @@ def test_multi_tower_din_fg_encoded_train_eval_export(self): os.path.exists(os.path.join(self.test_dir, 
"export/scripted_model.pt")) ) + def test_multi_tower_din_fg_encoded_train_eval_export(self): + self._test_rank_fg_encoded( + "tzrec/tests/configs/multi_tower_din_mock.config", + reserved_columns="clk", + output_columns="probs", + ) + + def test_dbmtl_has_sequence_fg_encoded_train_eval_export(self): + self._test_rank_fg_encoded( + "tzrec/tests/configs/dbmtl_has_sequence_mock.config", + reserved_columns="clk,buy", + output_columns="probs_ctr,probs_cvr", + ) + def test_multi_tower_din_fg_encoded_finetune(self): self.success = utils.test_train_eval( "tzrec/tests/configs/multi_tower_din_mock.config", @@ -83,9 +97,12 @@ def test_multi_tower_din_fg_encoded_finetune(self): ) self.assertTrue(self.success) - def test_dssm_fg_encoded_train_eval_export(self): + def _test_rank_with_fg(self, pipeline_config_path, comp_cpu_gpu_pred_result=False): self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_mock.config", self.test_dir, item_id="item_id" + pipeline_config_path, + self.test_dir, + user_id="user_id", + item_id="item_id", ) if self.success: self.success = utils.test_eval( @@ -97,65 +114,21 @@ def test_dssm_fg_encoded_train_eval_export(self): ) if self.success: self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/item"), + scripted_model_path=os.path.join(self.test_dir, "export"), predict_input_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), predict_output_path=os.path.join(self.test_dir, "predict_result"), - reserved_columns="item_id", - output_columns="item_tower_emb", + reserved_columns="user_id,item_id", + output_columns="", test_dir=self.test_dir, ) self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) - ) - - def 
test_dssm_fg_encoded_variational_dropout(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_variational_dropout_mock.config", - self.test_dir, - item_id="item_id", - ) - if self.success: - self.success = utils.test_feature_selection( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/model.ckpt-63")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "output_dir/pipeline.config")) - ) - - def test_multi_tower_din_with_fg_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/multi_tower_din_fg_mock.config", - self.test_dir, - user_id="user_id", - item_id="item_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - self.assertTrue(self.success) self.assertTrue( os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) ) self.assertTrue( os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) ) - if torch.cuda.is_available(): + if comp_cpu_gpu_pred_result and torch.cuda.is_available(): pipeline_config = config_util.load_pipeline_config( os.path.join(self.test_dir, "pipeline.config") ) @@ -181,122 +154,18 @@ def test_multi_tower_din_with_fg_train_eval_export(self): os.path.join(self.test_dir, "export/scripted_model.pt"), map_location=device, ) - result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int32), device) + result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int64), device) for k, v in result_gpu.items(): torch.testing.assert_close( result_cpu[k].to(device), v, rtol=5e-3, atol=1e-5 ) - def test_dssm_with_fg_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_fg_mock.config", - self.test_dir, - 
user_id="user_id", - item_id="item_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/item"), - predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "item_emb"), - reserved_columns="item_id", - output_columns="item_tower_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_create_faiss_index( - embedding_input_path=os.path.join( - self.test_dir, r"item_emb/\*.parquet" - ), - index_output_dir=os.path.join(self.test_dir, "export/user"), - id_field="item_id", - embedding_field="item_tower_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/user"), - predict_input_path=os.path.join(self.test_dir, r"user_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "user_emb"), - reserved_columns="user_id,click_50_seq__item_id", - output_columns="user_tower_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_hitrate( - user_gt_input=os.path.join(self.test_dir, r"user_emb/\*.parquet"), - item_embedding_input=os.path.join( - self.test_dir, r"item_emb/\*.parquet" - ), - total_hitrate_output=os.path.join(self.test_dir, "total_hitrate"), - hitrate_details_output=os.path.join(self.test_dir, "hitrate_details"), - request_id_field="user_id", - gt_items_field="click_50_seq__item_id", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_create_fg_json( - os.path.join(self.test_dir, "pipeline.config"), - fg_output_dir=os.path.join(self.test_dir, "fg_output"), - reserves="clk", - test_dir=self.test_dir, - ) - 
self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/faiss_index")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/id_mapping")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "fg_output/fg.json")) - ) - self.assertTrue( - os.path.exists( - os.path.join(self.test_dir, "fg_output/item_title_tokenizer.json") - ) - ) - - def test_dssm_v2_with_fg_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dssm_v2_fg_mock.config", - self.test_dir, - user_id="user_id", - item_id="item_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) - ) - - def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): + def _test_rank_with_fg_input_tile( + self, + pipeline_config_path, + ): self.success = utils.test_train_eval( - "tzrec/tests/configs/multi_tower_din_fg_mock.config", + pipeline_config_path, self.test_dir, user_id="user_id", item_id="item_id", @@ -457,7 +326,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): os.path.join(self.test_dir, "export/scripted_model.pt"), map_location=device, ) - result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int32), device) + result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int64), device) result_dict_json_path = os.path.join(self.test_dir, 
"result_gpu.json") utils.save_predict_result_json(result_gpu, result_dict_json_path) for k, v in result_gpu.items(): @@ -471,7 +340,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): map_location=device, ) result_gpu_no_quant = model_gpu_no_quant( - data.to_dict(sparse_dtype=torch.int32), device + data.to_dict(sparse_dtype=torch.int64), device ) result_dict_json_path = os.path.join( self.test_dir, "result_gpu_no_quant.json" @@ -493,7 +362,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): iterator_input_tile = iter(dataloader) data_input_tile = next(iterator_input_tile) result_gpu_input_tile = model_gpu_input_tile( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32), + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64), device, ) result_dict_json_path = os.path.join( @@ -507,7 +376,7 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): map_location=device, ) result_gpu_input_tile_emb = model_gpu_input_tile_emb( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32), + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64), device, ) @@ -525,133 +394,9 @@ def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): result_gpu_input_tile[k].to(device), v, rtol=1e-4, atol=1e-4 ) - def test_dbmtl_has_sequence_train_eval_export(self): + def _test_rank_with_fg_trt(self, pipeline_config_path, predict_columns): self.success = utils.test_train_eval( - "tzrec/tests/configs/dbmtl_has_sequence_mock.config", - self.test_dir, - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export"), - predict_input_path=os.path.join(self.test_dir, 
r"eval_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "predict_result"), - reserved_columns="clk,buy", - output_columns="probs_ctr,probs_cvr", - test_dir=self.test_dir, - ) - self.assertTrue(self.success) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) - ) - - def test_dbmtl_has_sequence_variational_dropout_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/dbmtl_has_sequence_variational_dropout_mock.config", - self.test_dir, - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_feature_selection( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "output_dir/pipeline.config")) - ) - - def test_tdm_train_eval_export(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/tdm_fg_mock.config", - self.test_dir, - user_id="user_id", - item_id="item_id", - cate_id="cate_id", - ) - if self.success: - self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir - ) - if self.success: - self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), - self.test_dir, - asset_files=os.path.join(self.test_dir, "init_tree/serving_tree"), - ) - if self.success: - self.success = utils.test_predict( - scripted_model_path=os.path.join(self.test_dir, "export/embedding"), - 
predict_input_path=os.path.join(self.test_dir, r"item_data/\*.parquet"), - predict_output_path=os.path.join(self.test_dir, "item_emb"), - reserved_columns="item_id,cate_id,id_4,id_5,raw_1,raw_2", - output_columns="item_emb", - test_dir=self.test_dir, - ) - if self.success: - self.success = utils.test_tdm_cluster_train_eval( - pipeline_config_path=os.path.join(self.test_dir, "pipeline.config"), - test_dir=self.test_dir, - item_input_path=os.path.join(self.test_dir, r"item_emb/\*.parquet"), - item_id="item_id", - embedding_field="item_emb", - ) - if self.success: - self.success = utils.test_tdm_retrieval( - scripted_model_path=os.path.join(self.test_dir, "export/model"), - eval_data_path=os.path.join(self.test_dir, r"eval_data/\*.parquet"), - retrieval_output_path=os.path.join(self.test_dir, "retrieval_result"), - reserved_columns="user_id,item_id", - test_dir=self.test_dir, - ) - - self.assertTrue(self.success) - self.assertTrue(os.path.exists(os.path.join(self.test_dir, "learnt_tree"))) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) - ) - self.assertTrue( - os.path.exists( - os.path.join(self.test_dir, "export/embedding/scripted_model.pt") - ) - ) - self.assertTrue( - os.path.exists( - os.path.join(self.test_dir, "export/model/scripted_model.pt") - ) - ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "export/model/serving_tree")) - ) - self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result"))) - - @unittest.skipIf(not torch.cuda.is_available(), "cuda not found") - def test_multi_tower_with_fg_train_eval_export_trt(self): - self.success = utils.test_train_eval( - "tzrec/tests/configs/multi_tower_din_trt_fg_mock.config", + pipeline_config_path, self.test_dir, user_id="user_id", item_id="item_id", @@ -675,7 +420,6 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): self.test_dir, "predict_result_tile_emb_trt" ) - predict_columns = ["user_id", "item_id", "clk", "probs"] # quant 
and no-input-tile if self.success: self.success = utils.test_export( @@ -833,7 +577,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): model_gpu = torch.jit.load( os.path.join(self.test_dir, "export/scripted_model.pt"), map_location=device ) - result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int32), device) + result_gpu = model_gpu(data.to_dict(sparse_dtype=torch.int64), device) result_dict_json_path = os.path.join(self.test_dir, "result_gpu.json") utils.save_predict_result_json(result_gpu, result_dict_json_path) @@ -842,7 +586,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): os.path.join(self.test_dir, "trt/export/scripted_model.pt"), map_location=device, ) - result_gpu_trt = model_gpu_trt(data.to_dict(sparse_dtype=torch.int32)) + result_gpu_trt = model_gpu_trt(data.to_dict(sparse_dtype=torch.int64)) result_dict_json_path = os.path.join(self.test_dir, "result_gpu_trt.json") utils.save_predict_result_json(result_gpu_trt, result_dict_json_path) @@ -861,7 +605,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): iterator_input_tile = iter(dataloader) data_input_tile = next(iterator_input_tile) result_gpu_input_tile = model_gpu_input_tile( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32) + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64) ) result_dict_json_path = os.path.join( self.test_dir, "result_gpu_input_tile.json" @@ -874,7 +618,7 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): map_location=device, ) result_gpu_input_tile_emb = model_gpu_input_tile_emb( - data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int32) + data_input_tile.to(device=device).to_dict(sparse_dtype=torch.int64) ) # trt is all same sa no-trt @@ -895,6 +639,72 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): result_gpu_input_tile_emb[k].to(device), v, rtol=1e-6, atol=1e-6 ) + def test_multi_tower_din_with_fg_train_eval_export(self): + self._test_rank_with_fg( + 
"tzrec/tests/configs/multi_tower_din_fg_mock.config", + comp_cpu_gpu_pred_result=True, + ) + + @unittest.skipIf(not torch.cuda.is_available(), "cuda not found") + def test_multi_tower_din_zch_with_fg_train_eval_export(self): + self._test_rank_with_fg( + "tzrec/tests/configs/multi_tower_din_zch_fg_mock.config", + comp_cpu_gpu_pred_result=True, + ) + + def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): + self._test_rank_with_fg_input_tile( + "tzrec/tests/configs/multi_tower_din_fg_mock.config" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "cuda not found") + def test_multi_tower_din_zch_with_fg_train_eval_export_input_tile(self): + self._test_rank_with_fg_input_tile( + "tzrec/tests/configs/multi_tower_din_zch_fg_mock.config" + ) + + def test_dbmtl_has_sequence_variational_dropout_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/dbmtl_has_sequence_variational_dropout_mock.config", + self.test_dir, + ) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_feature_selection( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train/eval_result.txt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "output_dir/pipeline.config")) + ) + + @unittest.skipIf(not torch.cuda.is_available(), "cuda not found") + def test_multi_tower_with_fg_train_eval_export_trt(self): + self._test_rank_with_fg_trt( + "tzrec/tests/configs/multi_tower_din_trt_fg_mock.config", + predict_columns=["user_id", "item_id", "clk", "probs"], + ) + + @unittest.skipIf(not torch.cuda.is_available(), "cuda not 
found") + def test_multi_tower_zch_with_fg_train_eval_export_trt(self): + self._test_rank_with_fg_trt( + "tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config", + predict_columns=["user_id", "item_id", "clk", "probs"], + ) + if __name__ == "__main__": unittest.main() diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index c8d438c..2c1623e 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -940,8 +940,9 @@ def test_predict( f"--predict_input_path {predict_input_path} " f"--predict_output_path {predict_output_path} " f"--reserved_columns {reserved_columns} " - f"--output_columns {output_columns}" ) + if output_columns: + cmd_str += f"--output_columns {output_columns}" p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_predict.txt")) p.wait(600) diff --git a/tzrec/utils/fx_util.py b/tzrec/utils/fx_util.py new file mode 100644 index 0000000..919f3ce --- /dev/null +++ b/tzrec/utils/fx_util.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from torchrec.fx import symbolic_trace as _symbolic_trace + + +def symbolic_trace( + # pyre-ignore[24] + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[Dict[str, Any]] = None, + leaf_modules: Optional[List[str]] = None, +) -> torch.fx.GraphModule: + """Symbolic tracing API. 
+
+    Given an `nn.Module` or function instance `root`, this function will return a
+    `GraphModule` constructed by recording operations seen while tracing through `root`.
+
+    `concrete_args` allows you to partially specialize your function, whether it's to
+    remove control flow or data structures.
+
+    Args:
+        root (Union[torch.nn.Module, Callable]): Module or function to be traced and
+            converted into a Graph representation.
+        concrete_args (Optional[Dict[str, Any]]): Inputs to be partially specialized.
+        leaf_modules (Optional[List[str]]): names of modules that should not be traced.
+
+    Returns:
+        GraphModule: a Module created from the recorded operations from ``root``.
+    """
+    # ComputeJTDictToKJT cannot be traced, so keep it as a leaf module by default
+    _leaf_modules = ["ComputeJTDictToKJT"]
+    if leaf_modules:
+        _leaf_modules.extend(leaf_modules)
+    return _symbolic_trace(root, concrete_args, _leaf_modules)
diff --git a/tzrec/utils/test_util.py b/tzrec/utils/test_util.py
index ba0b986..e2d576d 100644
--- a/tzrec/utils/test_util.py
+++ b/tzrec/utils/test_util.py
@@ -17,9 +17,9 @@
 import torch
 from torch import nn
 from torch.fx import GraphModule
-from torchrec.fx import symbolic_trace
 
 from tzrec.models.model import ScriptWrapper
+from tzrec.utils.fx_util import symbolic_trace
 
 
 class TestGraphType(Enum):
diff --git a/tzrec/version.py b/tzrec/version.py
index 5c07ea6..64d31d9 100644
--- a/tzrec/version.py
+++ b/tzrec/version.py
@@ -9,4 +9,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.6.6"
+__version__ = "0.6.7"