[feat] add zero collision hash embedding module #60

Merged: 13 commits, Dec 12, 2024
15 changes: 15 additions & 0 deletions docs/source/feature/feature.md
@@ -58,6 +58,17 @@ feature_configs {
embedding_dim: 32
vocab_dict: [{key:"a" value:2}, {key:"b" value:3}, {key:"c" value:2}]
}
feature_configs {
id_feature {
feature_name: "cate"
expression: "item:cate"
embedding_dim: 32
zch: {
zch_size: 1000000
eviction_interval: 2
lfu {}
}
}
}
```

@@ -75,6 +86,8 @@ feature_configs {

- **vocab_dict**: specifies the vocabulary as a dict, which is useful when several terms should map to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary terms.

- **zch**: zero-collision hash, with configurable id admission and eviction policies. See the [documentation](../zch.md) for details.

- **weighted**: whether this is a weighted id feature; the input takes the form `k1:v1\x1dk2:v2`

- **value_dim**: defaults to 0 and can be set to 1; with value_dim=0, multi-valued id output is supported
@@ -207,6 +220,7 @@ feature_configs: {
- **num_buckets**: the number of buckets; only usable when the input is of integer type
- **vocab_list**: specifies the vocabulary as a list; suitable for features whose values are few enough to enumerate.
- **vocab_dict**: specifies the vocabulary as a dict, which is useful when several terms should map to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary terms.
- **zch**: zero-collision hash, with configurable id admission and eviction policies. See the [documentation](../zch.md) for details.
- **value_dim**: defaults to 0 and can be set to 1; with value_dim=0, multi-valued id output is supported

If the values of the Map are continuous, you can set:
@@ -247,6 +261,7 @@ feature_configs: {
- **num_buckets**: the number of buckets; only usable when the input is of integer type
- **vocab_list**: specifies the vocabulary as a list; suitable for features whose values are few enough to enumerate.
- **vocab_dict**: specifies the vocabulary as a dict, which is useful when several terms should map to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary terms.
- **zch**: zero-collision hash, with configurable id admission and eviction policies. See the [documentation](../zch.md) for details.
- **value_dim**: defaults to 0 and can be set to 1; with value_dim=0, multi-valued id output is supported

If the values of the Map are continuous, you can set:
151 changes: 151 additions & 0 deletions docs/source/feature/zch.md
@@ -0,0 +1,151 @@
# Zero-Collision Hash Embedding

Zero-collision hash (zch) is one way to map feature values to ids. Compared with setting `hash_bucket_size`, it reduces hash collisions; compared with setting `vocab_dict` or `vocab_list`, it supports more flexible, dynamic id admission and eviction. Zero-collision hashing is commonly used for features with a very large id cardinality, such as user id, item id, and combo features.

Taking the id_feature configuration as an example, enabling zero-collision hash only requires adding a zch field to the id_feature:

```
feature_configs {
  id_feature {
    feature_name: "cate"
    expression: "item:cate"
    embedding_dim: 32
    zch: {
      zch_size: 1000000
      eviction_interval: 2
      lfu {}
    }
  }
}
```

- **zch_size**: the bucket size of the zero-collision hash; once the number of ids exceeds it, ids are evicted according to the eviction policy

- **eviction_interval**: how often the id admission and eviction policies run, as an interval in training steps

- **eviction_policy**: the eviction policy; one of `lfu`, `lru`, or `distance_lfu`. See Eviction Policies below

- **threshold_filtering_func**: a lambda function implementing the admission policy; by default all ids are admitted. See Admission Policy below

## Eviction Policies

### LFU_EvictionPolicy

Evicts the ids with the lowest access counts.
`id_score = access_cnt`

```
lfu {}
```

### LRU_EvictionPolicy

Evicts the ids that were accessed least recently.
`id_score = 1 / pow((current_iter - last_access_iter), decay_exponent)`

```
lru {
  decay_exponent: 1.0
}
```

### DistanceLFU_EvictionPolicy

Combines access count and recency, evicting the ids with the smallest combined id_score.
`id_score = access_cnt / pow((current_iter - last_access_iter), decay_exponent)`

```
distance_lfu {
  decay_exponent: 1.0
}
```
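The three score formulas above can be restated in a small plain-Python sketch (a hypothetical helper, not part of TorchEasyRec; it uses scalars where the framework operates on per-id statistics). Lower scores are evicted first once `zch_size` is exceeded:

```python
def id_score(access_cnt, current_iter, last_access_iter,
             policy="lfu", decay_exponent=1.0):
    """Restate the three eviction score formulas; lower scores evict first."""
    distance = current_iter - last_access_iter
    if policy == "lfu":
        # Frequency only: rarely seen ids are evicted first.
        return float(access_cnt)
    if policy == "lru":
        # Recency only: stale ids score low regardless of frequency.
        return 1.0 / (distance ** decay_exponent)
    if policy == "distance_lfu":
        # Frequency discounted by staleness.
        return access_cnt / (distance ** decay_exponent)
    raise ValueError(f"unknown policy: {policy}")

# An id seen 100 times but last accessed 50 iterations ago:
print(id_score(100, current_iter=100, last_access_iter=50,
               policy="distance_lfu"))  # 100 / 50 = 2.0
```

Under `distance_lfu`, a frequently seen but stale id can thus score lower than a rarer but recently seen one.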

## Admission Policy

The admission policy is specified as a lambda function expression whose input and output must follow this format:

- Input: a 1-D IntTensor giving the occurrence count of each id over the most recent `eviction_interval` batches
- Output: a 1-D BoolTensor marking the positions of the ids to keep, and a float giving the threshold on id occurrence counts

The function can be written directly with torch tensor operations, for example:

```
zch: {
  zch_size: 1000000
  eviction_interval: 2
  lfu {}
  threshold_filtering_func: "lambda x: (x > 10, 10)"
}
```
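Since the config value is an ordinary Python lambda in string form, its contract can be checked stand-alone. A minimal sketch (using a scalar stand-in for the 1-D count tensor, which the framework actually passes):

```python
# The threshold_filtering_func string is evaluated into a callable that
# returns (keep_mask, threshold). Here we exercise it with scalar counts.
threshold_filtering_func = "lambda x: (x > 10, 10)"
admit = eval(threshold_filtering_func)  # stand-alone evaluation for illustration

keep, threshold = admit(25)  # an id seen 25 times
# keep is True (25 > 10), threshold is 10
```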

The function can also call the built-in helpers `dynamic_threshold_filter`, `average_threshold_filter`, and `probabilistic_threshold_filter`, for example:

```
zch: {
  zch_size: 1000000
  eviction_interval: 2
  lfu {}
  threshold_filtering_func: "lambda x: dynamic_threshold_filter(x, 1.0)"
}
```

The built-in helpers are implemented as follows:

```python
from typing import Tuple

import torch


@torch.no_grad()
def dynamic_threshold_filter(
    id_counts: torch.Tensor,
    threshold_skew_multiplier: float = 10.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Threshold is total_count / num_ids * threshold_skew_multiplier. An id is
    added if its count is strictly greater than the threshold.
    """
    num_ids = id_counts.numel()
    total_count = id_counts.sum()

    BASE_THRESHOLD = 1 / num_ids
    threshold_mass = BASE_THRESHOLD * threshold_skew_multiplier

    threshold = threshold_mass * total_count
    threshold_mask = id_counts > threshold

    return threshold_mask, threshold


@torch.no_grad()
def average_threshold_filter(
    id_counts: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Threshold is the average of id_counts. An id is added if its count is
    strictly greater than the mean.
    """
    if id_counts.dtype != torch.float:
        id_counts = id_counts.float()
    threshold = id_counts.mean()
    threshold_mask = id_counts > threshold

    return threshold_mask, threshold


@torch.no_grad()
def probabilistic_threshold_filter(
    id_counts: torch.Tensor,
    per_id_probability: float = 0.01,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Each occurrence gives an id a probability per_id_probability of being
    added. For example, if per_id_probability is 0.01 and an id appears 100
    times, it has roughly a 63% chance of being added (1 - 0.99 ** 100). More
    precisely, the id score is 1 - (1 - per_id_probability) ** id_count, and
    for a randomly generated threshold, the id score is the chance of the id
    being added.
    """
    probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float)
    id_scores = 1 - torch.pow(probability, id_counts)

    threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device)
    threshold_mask = id_scores > threshold

    return threshold_mask, threshold
```
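As a quick sanity check of the semantics, `average_threshold_filter` keeps exactly the ids whose count is strictly above the mean. A plain-Python mirror of that logic (a hypothetical helper, kept torch-free for brevity):

```python
def average_threshold_filter_py(id_counts):
    """Plain-list mirror of average_threshold_filter: keep counts above the mean."""
    threshold = sum(id_counts) / len(id_counts)
    # Strictly greater than the mean, matching the torch version's `>` comparison.
    return [c > threshold for c in id_counts], threshold

mask, thr = average_threshold_filter_py([5, 1, 12])
# mean is 6.0, so only the id with count 12 is kept: ([False, False, True], 6.0)
```

Note the strict inequality: with counts all equal, nothing passes, since no count exceeds the mean.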
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -14,6 +14,7 @@ Welcome to TorchEasyRec's documentation!

feature/data
feature/feature
feature/zch

.. toctree::
:maxdepth: 2
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
-r requirements/runtime.txt
-r requirements/test.txt
-r requirements/docs.txt
-r requirements/gpu.txt
2 changes: 2 additions & 0 deletions requirements/gpu.txt
@@ -0,0 +1,2 @@
torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
6 changes: 2 additions & 4 deletions requirements/runtime.txt
@@ -7,13 +7,11 @@ graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.
graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.1-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
grpcio-tools<1.63.0
pandas
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
pyodps>=0.12.0
scikit-learn
tensorboard
torch==2.5.0
torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
torchmetrics==1.0.3
torchrec==1.0.0
1 change: 1 addition & 0 deletions setup.py
@@ -78,5 +78,6 @@ def parse_require_file(fpath):
extras_require={
"all": parse_requirements("requirements.txt"),
"tests": parse_requirements("requirements/test.txt"),
"gpu": parse_requirements("requirements/gpu.txt"),
},
)
2 changes: 1 addition & 1 deletion tzrec/acc/trt_utils.py
@@ -21,10 +21,10 @@
pass
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function
from torchrec.fx import symbolic_trace

from tzrec.acc.utils import is_debug_trt
from tzrec.models.model import ScriptWrapper
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import logger


34 changes: 18 additions & 16 deletions tzrec/acc/utils.py
@@ -115,25 +115,26 @@ def write_mapping_file_for_input_tile(
state_dict (Dict[str, torch.Tensor]): model state_dict
remap_file_path (str) : store new_params_name\told_params_name\n
"""
input_tile_keys = [
".ebc_user.embedding_bags.",
".ebc_item.embedding_bags.",
]
input_tile_keys_ec = [
".ec_list_user.",
".ec_list_item.",
]
input_tile_mapping = {
".ebc_user.embedding_bags.": ".ebc.embedding_bags.",
".ebc_item.embedding_bags.": ".ebc.embedding_bags.",
".mc_ebc_user._embedding_module.": ".mc_ebc._embedding_module.",
".mc_ebc_item._embedding_module.": ".mc_ebc._embedding_module.",
".mc_ebc_user._managed_collision_collection.": ".mc_ebc._managed_collision_collection.", # NOQA
".mc_ebc_item._managed_collision_collection.": ".mc_ebc._managed_collision_collection.", # NOQA
".ec_list_user.": ".ec_list.",
".ec_list_item.": ".ec_list.",
".mc_ec_list_user.": ".mc_ec_list.",
".mc_ec_list_item.": ".mc_ec_list.",
}

remap_str = ""
for key, _ in state_dict.items():
for input_tile_key in input_tile_keys:
for input_tile_key in input_tile_mapping:
if input_tile_key in key:
src_key = key.replace(input_tile_key, ".ebc.embedding_bags.")
remap_str += key + "\t" + src_key + "\n"

for input_tile_key in input_tile_keys_ec:
if input_tile_key in key:
src_key = key.replace(input_tile_key, ".ec_list.")
src_key = key.replace(
input_tile_key, input_tile_mapping[input_tile_key]
)
remap_str += key + "\t" + src_key + "\n"

with open(remap_file_path, "w") as f:
@@ -142,7 +143,8 @@ def write_mapping_file_for_input_tile(

def export_acc_config() -> Dict[str, str]:
"""Export acc config for model online inference."""
acc_config = dict()
# use int64 sparse id as input
acc_config = {"SPARSE_INT64": "1"}
if "INPUT_TILE" in os.environ:
acc_config["INPUT_TILE"] = os.environ["INPUT_TILE"]
if "QUANT_EMB" in os.environ:
16 changes: 12 additions & 4 deletions tzrec/features/combo_feature.py
@@ -20,7 +20,11 @@
ParsedData,
SparseData,
)
from tzrec.features.feature import FgMode, _parse_fg_encoded_sparse_feature_impl
from tzrec.features.feature import (
MAX_HASH_BUCKET_SIZE,
FgMode,
_parse_fg_encoded_sparse_feature_impl,
)
from tzrec.features.id_feature import IdFeature
from tzrec.protos.feature_pb2 import FeatureConfig
from tzrec.utils.logging_util import logger
@@ -53,7 +57,9 @@ def is_neg(self, value: bool) -> None:
@property
def num_embeddings(self) -> int:
"""Get embedding row count."""
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
num_embeddings = self.config.zch.zch_size
elif self.config.HasField("hash_bucket_size"):
num_embeddings = self.config.hash_bucket_size
elif len(self.config.vocab_list) > 0:
num_embeddings = len(self.config.vocab_list) + 2
@@ -69,7 +75,7 @@ def num_embeddings(self) -> int:
else:
raise ValueError(
f"{self.__class__.__name__}[{self.name}] must set hash_bucket_size"
" or vocab_list or vocab_dict"
" or vocab_list or vocab_dict or zch.zch_size"
)
return num_embeddings

@@ -116,7 +122,9 @@ def fg_json(self) -> List[Dict[str, Any]]:
}
if self.config.separator != "\x1d":
fg_cfg["separator"] = self.config.separator
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
elif self.config.HasField("hash_bucket_size"):
fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
elif len(self.config.vocab_list) > 0:
fg_cfg["vocab_list"] = [self.config.default_value, "<OOV>"] + list(