Commit

[feat] add zero collision hash embedding module (#60)
tiankongdeguiji authored Dec 12, 2024
1 parent 8caf811 commit 80d22e1
Showing 28 changed files with 1,853 additions and 431 deletions.
15 changes: 15 additions & 0 deletions docs/source/feature/feature.md
@@ -58,6 +58,17 @@ feature_configs {
embedding_dim: 32
vocab_dict: [{key:"a" value:2}, {key:"b" value:3}, {key:"c" value:2}]
}
feature_configs {
    id_feature {
        feature_name: "cate"
        expression: "item:cate"
        embedding_dim: 32
        zch: {
            zch_size: 1000000
            eviction_interval: 2
            lfu {}
        }
    }
}
```

@@ -75,6 +86,8 @@ feature_configs {

- **vocab_dict**: specifies the vocabulary in dict form, suitable when multiple words need to be encoded to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary words

- **zch**: zero collision hash; id admission and eviction policies can be configured, see the [documentation](../zch.md) for details

- **weighted**: whether this is a weighted id feature; the input takes the form `k1:v1\x1dk2:v2`

- **value_dim**: defaults to 0 and can be set to 1; when value_dim=0, multi-value id output is supported
@@ -207,6 +220,7 @@ feature_configs: {
- **num_buckets**: the number of buckets; num_buckets can only be used when the input is an integer type
- **vocab_list**: specifies a vocabulary list; suitable for features whose values are few and enumerable
- **vocab_dict**: specifies the vocabulary in dict form, suitable when multiple words need to be encoded to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary words
- **zch**: zero collision hash; id admission and eviction policies can be configured, see the [documentation](../zch.md) for details
- **value_dim**: defaults to 0 and can be set to 1; when value_dim=0, multi-value id output is supported

If the values of the Map are continuous, you can set:
@@ -247,6 +261,7 @@ feature_configs: {
- **num_buckets**: the number of buckets; num_buckets can only be used when the input is an integer type
- **vocab_list**: specifies a vocabulary list; suitable for features whose values are few and enumerable
- **vocab_dict**: specifies the vocabulary in dict form, suitable when multiple words need to be encoded to the same id. **Ids must start from 2**: id 0 is reserved for the default value and id 1 for out-of-vocabulary words
- **zch**: zero collision hash; id admission and eviction policies can be configured, see the [documentation](../zch.md) for details
- **value_dim**: defaults to 0 and can be set to 1; when value_dim=0, multi-value id output is supported

If the values of the Map are continuous, you can set:
151 changes: 151 additions & 0 deletions docs/source/feature/zch.md
@@ -0,0 +1,151 @@
# Zero Collision Hash Embedding

Zero Collision Hash (zch) is one way of turning feature values into ids. Compared with setting `hash_bucket_size`, it produces fewer hash collisions; compared with setting `vocab_dict`/`vocab_list`, it admits and evicts ids more flexibly and dynamically. Zero collision hash is commonly configured for features with a very large number of distinct ids, such as user id, item id, and combo features.

Taking the id_feature configuration as an example, enabling zero collision hash only requires adding a zch field to the id_feature:

```
feature_configs {
    id_feature {
        feature_name: "cate"
        expression: "item:cate"
        embedding_dim: 32
        zch: {
            zch_size: 1000000
            eviction_interval: 2
            lfu {}
        }
    }
}
```

- **zch_size**: the bucket size of the zero collision hash; once the number of ids exceeds it, ids are evicted according to the eviction policy

- **eviction_interval**: how often (in training steps) the id admission and eviction policies are executed

- **eviction_policy**: the eviction policy, one of `lfu`, `lru`, or `distance_lfu`; see Eviction Policies below

- **threshold_filtering_func**: a lambda function for the admission policy, defaulting to admitting all ids; see Admission Policy below

## Eviction Policies

### LFU_EvictionPolicy

Evicts the ids with the lowest access counts.
`id_score = access_cnt`

```
lfu {}
```

### LRU_EvictionPolicy

Evicts the least recently accessed ids.
`id_score = 1 / pow((current_iter - last_access_iter), decay_exponent)`

```
lru {
    decay_exponent: 1.0
}
```

### DistanceLFU_EvictionPolicy

Combines access count and access recency, evicting the ids with the smallest id_score.
`id_score = access_cnt / pow((current_iter - last_access_iter), decay_exponent)`

```
distance_lfu {
    decay_exponent: 1.0
}
```
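To make the three formulas concrete, here is a small illustrative sketch (plain torch operations, not tzrec API) that scores three ids under each policy; in every case the id with the lowest score is the first candidate for eviction:

```python
import torch

# Toy comparison of the three eviction scores defined above.
access_cnt = torch.tensor([100.0, 3.0, 50.0])  # per-id access counts
age = torch.tensor([9.0, 1.0, 5.0])            # current_iter - last_access_iter
decay_exponent = 1.0

lfu = access_cnt                                     # LFU: frequency only
lru = 1.0 / age.pow(decay_exponent)                  # LRU: recency only
distance_lfu = access_cnt / age.pow(decay_exponent)  # DistanceLFU: both

# The lowest-scored id is evicted first under each policy.
print(lfu.argmin().item())           # 1 -> least frequent id
print(lru.argmin().item())           # 0 -> least recently seen id
print(distance_lfu.argmin().item())  # 1
```

Note how LFU and LRU disagree on which id to evict first, while DistanceLFU trades the two signals off against each other.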

## Admission Policy

The admission policy is specified as a lambda function expression whose input and output must follow this format:

- Input: a 1-D IntTensor with each id's occurrence count over the last `eviction_interval` batches
- Output: a 1-D BoolTensor marking the positions of ids to keep, and a float threshold on id occurrence counts

The function can be written directly with torch tensor operations, for example:

```
zch: {
    zch_size: 1000000
    eviction_interval: 2
    lfu {}
    threshold_filtering_func: "lambda x: (x > 10, 10)"
}
```
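As an illustration of that contract (not tzrec internals), the configured lambda receives the per-id count tensor and returns the keep mask together with the threshold:

```python
import torch

# The expression from the config above, evaluated by hand on a toy count tensor.
threshold_filtering_func = lambda x: (x > 10, 10)

counts = torch.tensor([5, 11, 30])  # per-id counts over the last window
keep_mask, threshold = threshold_filtering_func(counts)
print(keep_mask.tolist())  # [False, True, True]
print(threshold)           # 10
```

Only ids whose count is strictly greater than the threshold are admitted; the rest stay out until they accumulate enough occurrences.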

The function can also call the built-in functions `dynamic_threshold_filter`, `average_threshold_filter`, and `probabilistic_threshold_filter`, for example:

```
zch: {
    zch_size: 1000000
    eviction_interval: 2
    lfu {}
    threshold_filtering_func: "lambda x: dynamic_threshold_filter(x, 1.0)"
}
```

The implementations of these built-in functions are as follows:

```python
from typing import Tuple

import torch


@torch.no_grad()
def dynamic_threshold_filter(
    id_counts: torch.Tensor,
    threshold_skew_multiplier: float = 10.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Threshold is total_count / num_ids * threshold_skew_multiplier. An id is
    added if its count is strictly greater than the threshold.
    """
    num_ids = id_counts.numel()
    total_count = id_counts.sum()

    BASE_THRESHOLD = 1 / num_ids
    threshold_mass = BASE_THRESHOLD * threshold_skew_multiplier

    threshold = threshold_mass * total_count
    threshold_mask = id_counts > threshold

    return threshold_mask, threshold


@torch.no_grad()
def average_threshold_filter(
    id_counts: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Threshold is average of id_counts. An id is added if its count is strictly
    greater than the mean.
    """
    if id_counts.dtype != torch.float:
        id_counts = id_counts.float()
    threshold = id_counts.mean()
    threshold_mask = id_counts > threshold

    return threshold_mask, threshold


@torch.no_grad()
def probabilistic_threshold_filter(
    id_counts: torch.Tensor,
    per_id_probability: float = 0.01,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Each id has probability per_id_probability of being added. For example,
    if per_id_probability is 0.01 and an id appears 100 times, then it has
    about a 63% chance of being added. More precisely, the id score is
    1 - (1 - per_id_probability) ^ id_count, and for a randomly generated
    threshold, the id score is the chance of the id being added.
    """
    probability = torch.full_like(id_counts, 1 - per_id_probability, dtype=torch.float)
    id_scores = 1 - torch.pow(probability, id_counts)

    threshold: torch.Tensor = torch.rand(id_counts.size(), device=id_counts.device)
    threshold_mask = id_scores > threshold

    return threshold_mask, threshold
```
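As a quick sanity check, the two deterministic filters can be exercised on a toy count tensor; the snippet below restates them locally (rather than importing from any particular module, since the import path is not given above) so it is self-contained:

```python
import torch

# Local restatements of the two deterministic filters documented above,
# used only to demonstrate the (mask, threshold) contract.
def average_threshold_filter(id_counts):
    if id_counts.dtype != torch.float:
        id_counts = id_counts.float()
    threshold = id_counts.mean()
    return id_counts > threshold, threshold

def dynamic_threshold_filter(id_counts, threshold_skew_multiplier=10.0):
    threshold = (threshold_skew_multiplier / id_counts.numel()) * id_counts.sum()
    return id_counts > threshold, threshold

counts = torch.tensor([1, 2, 3, 10])
mask, thr = average_threshold_filter(counts)
print(mask.tolist(), thr.item())  # [False, False, False, True] 4.0

# With multiplier 1.0, the dynamic threshold equals the mean, so the
# two filters agree on this input.
mask, thr = dynamic_threshold_filter(counts, threshold_skew_multiplier=1.0)
print(mask.tolist(), thr.item())  # [False, False, False, True] 4.0
```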
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -14,6 +14,7 @@ Welcome to TorchEasyRec's documentation!

feature/data
feature/feature
feature/zch

.. toctree::
:maxdepth: 2
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
-r requirements/runtime.txt
-r requirements/test.txt
-r requirements/docs.txt
-r requirements/gpu.txt
2 changes: 2 additions & 0 deletions requirements/gpu.txt
@@ -0,0 +1,2 @@
torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
6 changes: 2 additions & 4 deletions requirements/runtime.txt
@@ -7,13 +7,11 @@ graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.
 graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.1-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
 grpcio-tools<1.63.0
 pandas
-pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
-pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
+pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
+pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
 pyodps>=0.12.0
 scikit-learn
 tensorboard
 torch==2.5.0
-torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
-torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/trt/torch_tensorrt-2.5.0a0-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
 torchmetrics==1.0.3
 torchrec==1.0.0
1 change: 1 addition & 0 deletions setup.py
@@ -78,5 +78,6 @@ def parse_require_file(fpath):
extras_require={
"all": parse_requirements("requirements.txt"),
"tests": parse_requirements("requirements/test.txt"),
"gpu": parse_requirements("requirements/gpu.txt"),
},
)
2 changes: 1 addition & 1 deletion tzrec/acc/trt_utils.py
@@ -21,10 +21,10 @@
     pass
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
-from torchrec.fx import symbolic_trace

 from tzrec.acc.utils import is_debug_trt
 from tzrec.models.model import ScriptWrapper
+from tzrec.utils.fx_util import symbolic_trace
 from tzrec.utils.logging_util import logger


Expand Down
34 changes: 18 additions & 16 deletions tzrec/acc/utils.py
@@ -115,25 +115,26 @@ def write_mapping_file_for_input_tile(
         state_dict (Dict[str, torch.Tensor]): model state_dict
         remap_file_path (str) : store new_params_name\told_params_name\n
     """
-    input_tile_keys = [
-        ".ebc_user.embedding_bags.",
-        ".ebc_item.embedding_bags.",
-    ]
-    input_tile_keys_ec = [
-        ".ec_list_user.",
-        ".ec_list_item.",
-    ]
+    input_tile_mapping = {
+        ".ebc_user.embedding_bags.": ".ebc.embedding_bags.",
+        ".ebc_item.embedding_bags.": ".ebc.embedding_bags.",
+        ".mc_ebc_user._embedding_module.": ".mc_ebc._embedding_module.",
+        ".mc_ebc_item._embedding_module.": ".mc_ebc._embedding_module.",
+        ".mc_ebc_user._managed_collision_collection.": ".mc_ebc._managed_collision_collection.",  # NOQA
+        ".mc_ebc_item._managed_collision_collection.": ".mc_ebc._managed_collision_collection.",  # NOQA
+        ".ec_list_user.": ".ec_list.",
+        ".ec_list_item.": ".ec_list.",
+        ".mc_ec_list_user.": ".mc_ec_list.",
+        ".mc_ec_list_item.": ".mc_ec_list.",
+    }

     remap_str = ""
     for key, _ in state_dict.items():
-        for input_tile_key in input_tile_keys:
+        for input_tile_key in input_tile_mapping:
             if input_tile_key in key:
-                src_key = key.replace(input_tile_key, ".ebc.embedding_bags.")
-                remap_str += key + "\t" + src_key + "\n"
-
-        for input_tile_key in input_tile_keys_ec:
-            if input_tile_key in key:
-                src_key = key.replace(input_tile_key, ".ec_list.")
+                src_key = key.replace(
+                    input_tile_key, input_tile_mapping[input_tile_key]
+                )
                 remap_str += key + "\t" + src_key + "\n"

     with open(remap_file_path, "w") as f:
@@ -142,7 +143,8 @@ def write_mapping_file_for_input_tile(

 def export_acc_config() -> Dict[str, str]:
     """Export acc config for model online inference."""
-    acc_config = dict()
+    # use int64 sparse id as input
+    acc_config = {"SPARSE_INT64": "1"}
     if "INPUT_TILE" in os.environ:
         acc_config["INPUT_TILE"] = os.environ["INPUT_TILE"]
     if "QUANT_EMB" in os.environ:
16 changes: 12 additions & 4 deletions tzrec/features/combo_feature.py
@@ -20,7 +20,11 @@
     ParsedData,
     SparseData,
 )
-from tzrec.features.feature import FgMode, _parse_fg_encoded_sparse_feature_impl
+from tzrec.features.feature import (
+    MAX_HASH_BUCKET_SIZE,
+    FgMode,
+    _parse_fg_encoded_sparse_feature_impl,
+)
 from tzrec.features.id_feature import IdFeature
 from tzrec.protos.feature_pb2 import FeatureConfig
 from tzrec.utils.logging_util import logger
@@ -53,7 +57,9 @@ def is_neg(self, value: bool) -> None:
     @property
     def num_embeddings(self) -> int:
         """Get embedding row count."""
-        if self.config.HasField("hash_bucket_size"):
+        if self.config.HasField("zch"):
+            num_embeddings = self.config.zch.zch_size
+        elif self.config.HasField("hash_bucket_size"):
             num_embeddings = self.config.hash_bucket_size
         elif len(self.config.vocab_list) > 0:
             num_embeddings = len(self.config.vocab_list) + 2
@@ -69,7 +75,7 @@ def num_embeddings(self) -> int:
         else:
             raise ValueError(
                 f"{self.__class__.__name__}[{self.name}] must set hash_bucket_size"
-                " or vocab_list or vocab_dict"
+                " or vocab_list or vocab_dict or zch.zch_size"
             )
         return num_embeddings

@@ -116,7 +122,9 @@ def fg_json(self) -> List[Dict[str, Any]]:
         }
         if self.config.separator != "\x1d":
             fg_cfg["separator"] = self.config.separator
-        if self.config.HasField("hash_bucket_size"):
+        if self.config.HasField("zch"):
+            fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
+        elif self.config.HasField("hash_bucket_size"):
             fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
         elif len(self.config.vocab_list) > 0:
             fg_cfg["vocab_list"] = [self.config.default_value, "<OOV>"] + list(
