
Commit

add zero collision hash embedding
tiankongdeguiji committed Dec 10, 2024
1 parent 1c82c42 commit 7efc68d
Showing 19 changed files with 1,228 additions and 363 deletions.
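At a glance: the new zch (zero-collision hash) option on sparse feature configs replaces a user-chosen hash_bucket_size with a torchrec managed-collision table. Below is a minimal sketch of such a config, mirroring the proto fields exercised by the new test in tzrec/features/id_feature_test.py; the feature name, embedding_dim, and sizes are placeholder values, not part of this commit.

# Sketch only: a zch-enabled id feature config, following the fields used in the new test.
from tzrec.protos import feature_pb2

id_feat_cfg = feature_pb2.FeatureConfig(
    id_feature=feature_pb2.IdFeature(
        feature_name="id_feat",  # placeholder name
        embedding_dim=16,
        zch=feature_pb2.ZeroCollisionHash(
            zch_size=100,  # rows in the managed-collision embedding table
            eviction_interval=5,  # evict ids every 5 training iterations
            distance_lfu=feature_pb2.DistanceLFU_EvictionPolicy(
                decay_exponent=1.0,
                threshold_filtering_func="lambda x: probabilistic_threshold_filter(x, 0.05)",
            ),
        ),
    )
)

On the feature-generation side, the diffs below substitute MAX_HASH_BUCKET_SIZE = 2**31 - 1 for the hash bucket size, so raw ids are hashed into the full signed-int32 space and collision handling is deferred to the managed-collision module; num_embeddings then comes from zch_size rather than hash_bucket_size.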
4 changes: 2 additions & 2 deletions requirements/runtime.txt
@@ -7,8 +7,8 @@ graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.
graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.1-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
grpcio-tools<1.63.0
pandas
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.9-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
pyodps>=0.12.0
scikit-learn
tensorboard
1 change: 1 addition & 0 deletions tzrec/datasets/data_parser.py
@@ -119,6 +119,7 @@ def _init_fg_hander(self) -> None:
if not self._fg_handler:
fg_json = create_fg_json(self._features)
# pyre-ignore [16]
print(fg_json)
self._fg_handler = pyfg.FgArrowHandler(fg_json, self._fg_threads)

def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:
10 changes: 8 additions & 2 deletions tzrec/features/combo_feature.py
@@ -20,7 +20,11 @@
ParsedData,
SparseData,
)
from tzrec.features.feature import FgMode, _parse_fg_encoded_sparse_feature_impl
from tzrec.features.feature import (
MAX_HASH_BUCKET_SIZE,
FgMode,
_parse_fg_encoded_sparse_feature_impl,
)
from tzrec.features.id_feature import IdFeature
from tzrec.protos.feature_pb2 import FeatureConfig
from tzrec.utils.logging_util import logger
@@ -116,7 +120,9 @@ def fg_json(self) -> List[Dict[str, Any]]:
}
if self.config.separator != "\x1d":
fg_cfg["separator"] = self.config.separator
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
elif self.config.HasField("hash_bucket_size"):
fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
elif len(self.config.vocab_list) > 0:
fg_cfg["vocab_list"] = [self.config.default_value, "<OOV>"] + list(
49 changes: 49 additions & 0 deletions tzrec/features/feature.py
@@ -20,12 +20,23 @@
import numpy as np
import pyarrow as pa
import pyfg
import torch
from torch import nn # NOQA
from torchrec.modules.embedding_configs import (
EmbeddingBagConfig,
EmbeddingConfig,
PoolingType,
)
from torchrec.modules.mc_modules import (
DistanceLFU_EvictionPolicy,
LFU_EvictionPolicy,
LRU_EvictionPolicy,
ManagedCollisionModule,
MCHManagedCollisionModule,
average_threshold_filter, # NOQA
dynamic_threshold_filter, # NOQA
probabilistic_threshold_filter, # NOQA
)

from tzrec.datasets.utils import (
BASE_DATA_GROUP,
@@ -52,6 +63,9 @@ class FgMode(Enum):
DAG = 3


MAX_HASH_BUCKET_SIZE = 2**31 - 1


def _parse_fg_encoded_sparse_feature_impl(
name: str,
feat: pa.Array,
@@ -386,6 +400,41 @@ def emb_config(self) -> Optional[EmbeddingConfig]:
else:
return None

def mc_module(self, device: torch.device) -> Optional[ManagedCollisionModule]:
"""Get ManagedCollisionModule."""
if self.is_sparse:
if hasattr(self.config, "zch") and self.config.HasField("zch"):
evict_type = self.config.zch.WhichOneof("eviction_policy")
evict_config = getattr(self.config.zch, evict_type)
threshold_filtering_func = None
if evict_config.HasField("threshold_filtering_func"):
threshold_filtering_func = eval(
evict_config.threshold_filtering_func
)
if evict_type == "lfu":
eviction_policy = LFU_EvictionPolicy(
threshold_filtering_func=threshold_filtering_func
)
elif evict_type == "lru":
eviction_policy = LRU_EvictionPolicy(
decay_exponent=evict_config.decay_exponent,
threshold_filtering_func=threshold_filtering_func,
)
elif evict_type == "distance_lfu":
eviction_policy = DistanceLFU_EvictionPolicy(
decay_exponent=evict_config.decay_exponent,
threshold_filtering_func=threshold_filtering_func,
)
else:
raise ValueError(f"Unknown evict policy type: {evict_type}")
return MCHManagedCollisionModule(
zch_size=self.config.zch.zch_size,
device=device,
eviction_interval=self.config.zch.eviction_interval,
eviction_policy=eviction_policy,
)
return None

@property
def inputs(self) -> List[str]:
"""Input field names."""
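The mc_module hook above only constructs the MCHManagedCollisionModule; wiring it into an embedding module is outside this hunk. The following is a hedged sketch of how such a module is typically paired with its EmbeddingConfig in torchrec (standard torchrec API as understood, not code from this commit; table and feature names copy the values used in the new test).

# Assumed wiring sketch, not part of this commit: pair the managed-collision module
# with its embedding table via torchrec's ManagedCollisionEmbeddingCollection.
import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
    probabilistic_threshold_filter,
)

device = torch.device("cpu")
emb_config = EmbeddingConfig(
    num_embeddings=100,  # zch_size: the table only needs zch_size rows
    embedding_dim=16,
    name="id_feat_emb",
    feature_names=["id_feat"],
)
mc_module = MCHManagedCollisionModule(  # what mc_module() returns for a distance_lfu policy
    zch_size=100,
    device=device,
    eviction_interval=5,
    eviction_policy=DistanceLFU_EvictionPolicy(
        decay_exponent=1.0,
        threshold_filtering_func=lambda x: probabilistic_threshold_filter(x, 0.05),
    ),
)
mc_ec = ManagedCollisionEmbeddingCollection(
    EmbeddingCollection(tables=[emb_config], device=device),
    ManagedCollisionCollection(
        managed_collision_modules={"id_feat_emb": mc_module},  # keyed by table name
        embedding_configs=[emb_config],
    ),
    return_remapped_features=False,
)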
12 changes: 9 additions & 3 deletions tzrec/features/id_feature.py
@@ -20,6 +20,7 @@
SparseData,
)
from tzrec.features.feature import (
MAX_HASH_BUCKET_SIZE,
BaseFeature,
FgMode,
_parse_fg_encoded_sparse_feature_impl,
@@ -71,7 +72,9 @@ def is_sparse(self) -> bool:
@property
def num_embeddings(self) -> int:
"""Get embedding row count."""
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
num_embeddings = self.config.zch.zch_size
elif self.config.HasField("hash_bucket_size"):
num_embeddings = self.config.hash_bucket_size
elif self.config.HasField("num_buckets"):
num_embeddings = self.config.num_buckets
@@ -89,7 +92,7 @@ def num_embeddings(self) -> int:
else:
raise ValueError(
f"{self.__class__.__name__}[{self.name}] must set hash_bucket_size"
" or num_buckets or vocab_list or vocab_dict"
" or num_buckets or vocab_list or vocab_dict or zch.zch_size"
)
return num_embeddings

@@ -156,7 +159,10 @@ def fg_json(self) -> List[Dict[str, Any]]:
}
if self.config.separator != "\x1d":
fg_cfg["separator"] = self.config.separator
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
fg_cfg["value_type"] = "string"
elif self.config.HasField("hash_bucket_size"):
fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
fg_cfg["value_type"] = "string"
elif len(self.config.vocab_list) > 0:
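For reference, a hypothetical fg_json entry produced by an IdFeature with zch set. Only the hash_bucket_size and value_type keys are taken from the hunk above; the remaining keys and all values are illustrative placeholders.

# Hypothetical fg_json entry for a zch id feature; the name/expression keys are assumed.
fg_cfg = {
    "feature_type": "id_feature",  # assumed key, not shown in the hunk above
    "feature_name": "id_feat",  # placeholder
    "expression": "user:id_feat",  # placeholder
    "hash_bucket_size": 2**31 - 1,  # MAX_HASH_BUCKET_SIZE replaces the user-facing bucket size
    "value_type": "string",
}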
40 changes: 40 additions & 0 deletions tzrec/features/id_feature_test.py
@@ -15,6 +15,7 @@

import numpy as np
import pyarrow as pa
import torch
from parameterized import parameterized
from torch import nn
from torchrec.modules.embedding_configs import (
@@ -88,6 +89,45 @@ def test_init_fn_id_feature(self):
)
self.assertEqual(repr(id_feat.emb_config), repr(expected_emb_config))

def test_zch_id_feature(self):
id_feat_cfg = feature_pb2.FeatureConfig(
id_feature=feature_pb2.IdFeature(
feature_name="id_feat",
embedding_dim=16,
zch=feature_pb2.ZeroCollisionHash(
zch_size=100,
eviction_interval=5,
distance_lfu=feature_pb2.DistanceLFU_EvictionPolicy(
decay_exponent=1.0,
threshold_filtering_func="lambda x:"
" probabilistic_threshold_filter(x,0.05)",
),
),
)
)
id_feat = id_feature_lib.IdFeature(id_feat_cfg)
expected_emb_bag_config = EmbeddingBagConfig(
num_embeddings=100,
embedding_dim=16,
name="id_feat_emb",
feature_names=["id_feat"],
pooling=PoolingType.SUM,
)
self.assertEqual(repr(id_feat.emb_bag_config), repr(expected_emb_bag_config))
expected_emb_config = EmbeddingConfig(
num_embeddings=100,
embedding_dim=16,
name="id_feat_emb",
feature_names=["id_feat"],
)
self.assertEqual(repr(id_feat.emb_config), repr(expected_emb_config))
mc_module = id_feat.mc_module(torch.device("meta"))
self.assertEqual(mc_module._zch_size, 100)
self.assertEqual(mc_module._eviction_interval, 5)
self.assertTrue(
mc_module._eviction_policy._threshold_filtering_func is not None
)

def test_fg_encoded_with_weighted(self):
id_feat_cfg = feature_pb2.FeatureConfig(
id_feature=feature_pb2.IdFeature(
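A note on the threshold_filtering_func string used in the test above: BaseFeature.mc_module resolves it with eval, which is why feature.py imports the torchrec filter helpers with # NOQA even though they are never referenced directly. A minimal sketch of that resolution, with a plain module-level import standing in for feature.py's namespace:

# Sketch: turning the configured string into a callable, as mc_module does via eval().
from torchrec.modules.mc_modules import (
    LFU_EvictionPolicy,
    probabilistic_threshold_filter,  # noqa: F401 -- must be importable so eval() can resolve it
)

threshold_filtering_func = eval("lambda x: probabilistic_threshold_filter(x, 0.05)")
policy = LFU_EvictionPolicy(threshold_filtering_func=threshold_filtering_func)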
7 changes: 6 additions & 1 deletion tzrec/features/lookup_feature.py
@@ -23,6 +23,7 @@
SparseData,
)
from tzrec.features.feature import (
MAX_HASH_BUCKET_SIZE,
BaseFeature,
FgMode,
_parse_fg_encoded_dense_feature_impl,
@@ -201,7 +202,11 @@ def fg_json(self) -> List[Dict[str, Any]]:
fg_cfg["separator"] = self.config.separator
if self.config.HasField("normalizer"):
fg_cfg["normalizer"] = self.config.normalizer
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
fg_cfg["value_type"] = "string"
fg_cfg["needDiscrete"] = True
elif self.config.HasField("hash_bucket_size"):
fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
fg_cfg["value_type"] = "string"
fg_cfg["needDiscrete"] = True
7 changes: 6 additions & 1 deletion tzrec/features/match_feature.py
@@ -23,6 +23,7 @@
SparseData,
)
from tzrec.features.feature import (
MAX_HASH_BUCKET_SIZE,
BaseFeature,
FgMode,
_parse_fg_encoded_dense_feature_impl,
@@ -174,7 +175,11 @@ def fg_json(self) -> List[Dict[str, Any]]:
fg_cfg["separator"] = self.config.separator
if self.config.HasField("normalizer"):
fg_cfg["normalizer"] = self.config.normalizer
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
fg_cfg["value_type"] = "string"
fg_cfg["needDiscrete"] = True
elif self.config.HasField("hash_bucket_size"):
fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
fg_cfg["value_type"] = "string"
fg_cfg["needDiscrete"] = True
5 changes: 4 additions & 1 deletion tzrec/features/sequence_feature.py
@@ -22,6 +22,7 @@
SequenceDenseData,
SequenceSparseData,
)
from tzrec.features.feature import MAX_HASH_BUCKET_SIZE
from tzrec.features.id_feature import FgMode, IdFeature
from tzrec.features.raw_feature import RawFeature
from tzrec.protos import feature_pb2
@@ -303,7 +304,9 @@ def fg_json(self) -> List[Dict[str, Any]]:
fg_cfg["sequence_length"] = self.config.sequence_length
if self.config.separator != "\x1d":
fg_cfg["separator"] = self.config.separator
if self.config.HasField("hash_bucket_size"):
if self.config.HasField("zch"):
fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
elif self.config.HasField("hash_bucket_size"):
fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
elif self.config.HasField("num_buckets"):
fg_cfg["num_buckets"] = self.config.num_buckets
7 changes: 3 additions & 4 deletions tzrec/main.py
@@ -35,7 +35,6 @@

# NOQA
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist
from torchrec.fx import symbolic_trace
from torchrec.inference.modules import quantize_embeddings
from torchrec.inference.state_dict_transform import (
state_dict_gather,
@@ -85,6 +84,7 @@
from tzrec.protos.pipeline_pb2 import EasyRecConfig
from tzrec.protos.train_pb2 import TrainConfig
from tzrec.utils import checkpoint_util, config_util
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import ProgressLogger, logger
from tzrec.utils.plan_util import create_planner, get_default_sharders
from tzrec.version import __version__ as tzrec_version
@@ -751,7 +751,7 @@ def _script_model(
model.eval()

if is_trt_convert:
data_cuda = batch.to_dict(sparse_dtype=torch.int32)
data_cuda = batch.to_dict(sparse_dtype=torch.int64)
result = model(data_cuda, "cuda:0")
result_info = {k: (v.size(), v.dtype) for k, v in result.items()}
logger.info(f"Model Outputs: {result_info}")
@@ -1008,7 +1008,6 @@ def predict(

device_and_backend = init_process_group()
device: torch.device = device_and_backend[0]
sparse_dtype: torch.dtype = torch.int32 if device.type == "cuda" else torch.int64

is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
@@ -1077,7 +1076,7 @@

def _forward(batch: Batch) -> Tuple[Dict[str, torch.Tensor], RecordBatchTensor]:
with torch.no_grad():
parsed_inputs = batch.to_dict(sparse_dtype=sparse_dtype)
parsed_inputs = batch.to_dict(sparse_dtype=torch.int64)
# when predicting with a model exported using INPUT_TILE,
# we set the batch size tensor to 1 to disable tiling.
parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64)