From cec440b5172c36ba7a3b5c16dacc1d14d26d2e10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 10 Dec 2024 20:33:49 +0800 Subject: [PATCH] add doc --- docs/source/feature/feature.md | 15 +++++++++++++++ docs/source/index.rst | 1 + tzrec/datasets/data_parser.py | 1 - tzrec/features/combo_feature.py | 6 ++++-- tzrec/features/lookup_feature.py | 4 +++- tzrec/features/match_feature.py | 4 +++- 6 files changed, 26 insertions(+), 5 deletions(-) diff --git a/docs/source/feature/feature.md b/docs/source/feature/feature.md index d50481c..cf5ad03 100644 --- a/docs/source/feature/feature.md +++ b/docs/source/feature/feature.md @@ -58,6 +58,17 @@ feature_configs { embedding_dim: 32 vocab_dict: [{key:"a" value:2}, {key:"b" value:3}, {key:"c" value:2}] } +feature_configs { + id_feature { + feature_name: "cate" + expression: "item:cate" + embedding_dim: 32 + zch: { + zch_size: 1000000 + eviction_interval: 2 + lfu {} + } + } } ``` @@ -75,6 +86,8 @@ feature_configs { - **vocab_dict**: 指定字典形式词表,适合多个词需要编码到同一个编号情况,**编号需要从2开始**,编码0预留给默认值,编码1预留给超出词表的词 +- **zch**: 零冲突hash,可设置Id的准入和驱逐策略,详见[文档](./zch.md) + - **weighted**: 是否为带权重的Id特征,输入形式为`k1:v1\x1dk2:v2` - **value_dim**: 默认值是0,可以设置1,value_dim=0时支持多值ID输出 @@ -207,6 +220,7 @@ feature_configs: { - **num_buckets**: buckets数量, 仅仅当输入是integer类型时,可以使用num_buckets - **vocab_list**: 指定词表,适合取值比较少可以枚举的特征。 - **vocab_dict**: 指定字典形式词表,适合多个词需要编码到同一个编号情况,**编号需要从2开始**,编码0预留给默认值,编码1预留给超出词表的词 +- **zch**: 零冲突hash,可设置Id的准入和驱逐策略,详见[文档](./zch.md) - **value_dim**: 默认值是0,可以设置1,value_dim=0时支持多值ID输出 如果Map的值为连续值,可设置: @@ -247,6 +261,7 @@ feature_configs: { - **num_buckets**: buckets数量, 仅仅当输入是integer类型时,可以使用num_buckets - **vocab_list**: 指定词表,适合取值比较少可以枚举的特征。 - **vocab_dict**: 指定字典形式词表,适合多个词需要编码到同一个编号情况,**编号需要从2开始**,编码0预留给默认值,编码1预留给超出词表的词 +- **zch**: 零冲突hash,可设置Id的准入和驱逐策略,详见[文档](./zch.md) - **value_dim**: 默认值是0,可以设置1,value_dim=0时支持多值ID输出 如果Map的值为连续值,可设置: diff --git a/docs/source/index.rst b/docs/source/index.rst index a643e2f..3c82600 100644 --- 
a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,6 +14,7 @@ Welcome to TorchEasyRec's documentation! feature/data feature/feature + feature/zch .. toctree:: :maxdepth: 2 diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py index 2d1267d..b7c38f9 100644 --- a/tzrec/datasets/data_parser.py +++ b/tzrec/datasets/data_parser.py @@ -119,7 +119,6 @@ def _init_fg_hander(self) -> None: if not self._fg_handler: fg_json = create_fg_json(self._features) # pyre-ignore [16] - print(fg_json) self._fg_handler = pyfg.FgArrowHandler(fg_json, self._fg_threads) def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: diff --git a/tzrec/features/combo_feature.py b/tzrec/features/combo_feature.py index f5cd55b..c608c1d 100644 --- a/tzrec/features/combo_feature.py +++ b/tzrec/features/combo_feature.py @@ -57,7 +57,9 @@ def is_neg(self, value: bool) -> None: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif len(self.config.vocab_list) > 0: num_embeddings = len(self.config.vocab_list) + 2 @@ -73,7 +75,7 @@ def num_embeddings(self) -> int: else: raise ValueError( f"{self.__class__.__name__}[{self.name}] must set hash_bucket_size" - " or vocab_list or vocab_dict" + " or vocab_list or vocab_dict or zch.zch_size" ) return num_embeddings diff --git a/tzrec/features/lookup_feature.py b/tzrec/features/lookup_feature.py index 2899c4a..9d4aac9 100644 --- a/tzrec/features/lookup_feature.py +++ b/tzrec/features/lookup_feature.py @@ -85,7 +85,9 @@ def is_sparse(self) -> bool: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif 
self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): num_embeddings = self.config.num_buckets diff --git a/tzrec/features/match_feature.py b/tzrec/features/match_feature.py index fcd8136..4fbc557 100644 --- a/tzrec/features/match_feature.py +++ b/tzrec/features/match_feature.py @@ -87,7 +87,9 @@ def is_sparse(self) -> bool: @property def num_embeddings(self) -> int: """Get embedding row count.""" - if self.config.HasField("hash_bucket_size"): + if self.config.HasField("zch"): + num_embeddings = self.config.zch.zch_size + elif self.config.HasField("hash_bucket_size"): num_embeddings = self.config.hash_bucket_size elif self.config.HasField("num_buckets"): num_embeddings = self.config.num_buckets