From 27d22d57608ef20f9390296bd54693c20de6dac2 Mon Sep 17 00:00:00 2001 From: tianzhou Date: Fri, 6 Dec 2024 14:36:39 +0800 Subject: [PATCH 1/3] add match model proto --- tzrec/protos/models/match_model.proto | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tzrec/protos/models/match_model.proto b/tzrec/protos/models/match_model.proto index c648872..b1e09e5 100644 --- a/tzrec/protos/models/match_model.proto +++ b/tzrec/protos/models/match_model.proto @@ -22,6 +22,19 @@ message DSSM { optional bool in_batch_negative = 6 [default = false]; } +message HSTU { + required Tower user_tower = 1; + required Tower item_tower = 2; + // user and item tower output dimension + required int32 output_dim = 3; + // similarity method + optional Similarity similarity = 4 [default=INNER_PRODUCT]; + // similarity scaling factor + optional float temperature = 5 [default = 1.0]; + // use in batch items as negative items. + optional bool in_batch_negative = 6 [default = false]; +} + message DSSMV2 { required Tower user_tower = 1; required Tower item_tower = 2; From d71db9648d18e77f0100258fef68e5b3279ec8db Mon Sep 17 00:00:00 2001 From: tianzhou Date: Fri, 6 Dec 2024 14:44:48 +0800 Subject: [PATCH 2/3] add hstu model content, initialization --- tzrec/models/hstu.py | 165 ++++++++++ tzrec/modules/hstu.py | 407 ++++++++++++++++++++++++ tzrec/modules/sequence.py | 214 ++++++++++++- tzrec/modules/sequence_test.py | 41 +++ tzrec/protos/model.proto | 1 + tzrec/protos/seq_encoder.proto | 32 ++ tzrec/tests/configs/hstu_fg_mock.config | 259 +++++++++++++++ tzrec/tests/train_eval_export_test.py | 24 ++ 8 files changed, 1142 insertions(+), 1 deletion(-) create mode 100644 tzrec/models/hstu.py create mode 100644 tzrec/modules/hstu.py create mode 100644 tzrec/tests/configs/hstu_fg_mock.config diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py new file mode 100644 index 0000000..154b7a1 --- /dev/null +++ b/tzrec/models/hstu.py @@ -0,0 +1,165 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torch._tensor import Tensor + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.match_model import MatchModel, MatchTower +from tzrec.modules.mlp import MLP +from tzrec.protos import model_pb2, tower_pb2 +from tzrec.protos.models import match_model_pb2 +from tzrec.utils.config_util import config_to_kwargs + + +@torch.fx.wrap +def _update_dict_tensor( + tensor_dict: Dict[str, torch.Tensor], + new_tensor_dict: Optional[Dict[str, Optional[torch.Tensor]]], +) -> None: + if new_tensor_dict: + for k, v in new_tensor_dict.items(): + if v is not None: + tensor_dict[k] = v + + +class HSTUTower(MatchTower): + """HSTU user/item tower. + + Args: + tower_config (Tower): user/item tower config. + output_dim (int): user/item output embedding dimension. 
+ similarity (Similarity): when use COSINE similarity, + will norm the output embedding. + feature_group (FeatureGroupConfig): feature group config. + features (list): list of features. + """ + + def __init__( + self, + tower_config: tower_pb2.Tower, + output_dim: int, + similarity: match_model_pb2.Similarity, + feature_group: model_pb2.FeatureGroupConfig, + features: List[BaseFeature], + model_config: model_pb2.ModelConfig, + ) -> None: + super().__init__( + tower_config, output_dim, similarity, feature_group, features, model_config + ) + self.init_input() + tower_feature_in = self.embedding_group.group_total_dim(self._group_name) + self.mlp = MLP(tower_feature_in, **config_to_kwargs(tower_config.mlp)) + if self._output_dim > 0: + self.output = nn.Linear(self.mlp.output_dim(), output_dim) + + def forward(self, batch: Batch) -> torch.Tensor: + """Forward the tower. + + Args: + batch (Batch): input batch data. + + Return: + embedding (dict): tower output embedding. + """ + grouped_features = self.build_input(batch) + output = self.mlp(grouped_features[self._group_name]) + # TODO: this will cause dimension unmatch in self.output, considering resolutions + # if self._tower_config.input == 'user': + # output = grouped_features[self._group_name] + # else: + # output = self.mlp(grouped_features[self._group_name]) + + if self._output_dim > 0: + output = self.output(output) + if self._similarity == match_model_pb2.Similarity.COSINE: + output = F.normalize(output, p=2.0, dim=1) + return output + + +class HSTU(MatchModel): + """HSTU model. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + """ + + def __init__( + self, + model_config: model_pb2.ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} + + user_group = name_to_feature_group[self._model_config.user_tower.input] + item_group = name_to_feature_group[self._model_config.item_tower.input] + + name_to_feature = {x.name: x for x in features} + user_features = OrderedDict( + [(x, name_to_feature[x]) for x in user_group.feature_names] + ) + for sequence_group in user_group.sequence_groups: + for x in sequence_group.feature_names: + user_features[x] = name_to_feature[x] + item_features = [name_to_feature[x] for x in item_group.feature_names] + + self.user_tower = HSTUTower( + self._model_config.user_tower, + self._model_config.output_dim, + self._model_config.similarity, + user_group, + list(user_features.values()), + model_config, + ) + + self.item_tower = HSTUTower( + self._model_config.item_tower, + self._model_config.output_dim, + self._model_config.similarity, + item_group, + item_features, + model_config, + ) + + def predict(self, batch: Batch) -> Dict[str, Tensor]: + """Forward the model. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. 
+ """ + user_tower_emb = self.user_tower(batch) + item_tower_emb = self.item_tower(batch) + _update_dict_tensor( + self._loss_collection, self.user_tower.group_variational_dropout_loss + ) + _update_dict_tensor( + self._loss_collection, self.item_tower.group_variational_dropout_loss + ) + + ui_sim = ( + self.sim(user_tower_emb, item_tower_emb) / self._model_config.temperature + ) + return {"similarity": ui_sim} diff --git a/tzrec/modules/hstu.py b/tzrec/modules/hstu.py new file mode 100644 index 0000000..86b845a --- /dev/null +++ b/tzrec/modules/hstu.py @@ -0,0 +1,407 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import abc + +from typing import Callable, Optional, Tuple + +TIMESTAMPS_KEY = "timestamps" + + +class RelativeAttentionBiasModule(torch.nn.Module): + + @abc.abstractmethod + def forward( + self, + all_timestamps: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + all_timestamps: [B, N] x int64 + Returns: + torch.float tensor broadcastable to [B, N, N] + """ + pass + + +class RelativePositionalBias(RelativeAttentionBiasModule): + + def __init__(self, max_seq_len: int) -> None: + super().__init__() + + self._max_seq_len: int = max_seq_len + self._w = torch.nn.Parameter( + torch.empty(2 * max_seq_len - 1).normal_(mean=0, std=0.02), + ) + + def forward( + self, + all_timestamps: torch.Tensor, + ) -> torch.Tensor: + del all_timestamps + n: int = self._max_seq_len + t = F.pad(self._w[: 2 * n - 1], [0, n]).repeat(n) + t = t[..., :-n].reshape(1, n, 3 * n - 2) + r = (2 * n - 1) // 2 + return t[..., r:-r] + + +class RelativeBucketedTimeAndPositionBasedBias(RelativeAttentionBiasModule): + """ + Bucketizes timespans based on ts(next-item) - ts(current-item). + """ + + def __init__( + self, + max_seq_len: int, + num_buckets: int, + bucketization_fn: Callable[[torch.Tensor], torch.Tensor], + ) -> None: + super().__init__() + + self._max_seq_len: int = max_seq_len + self._ts_w = torch.nn.Parameter( + torch.empty(num_buckets + 1).normal_(mean=0, std=0.02), + ) + self._pos_w = torch.nn.Parameter( + torch.empty(2 * max_seq_len - 1).normal_(mean=0, std=0.02), + ) + self._num_buckets: int = num_buckets + self._bucketization_fn: Callable[[torch.Tensor], torch.Tensor] = ( + bucketization_fn + ) + + def forward( + self, + all_timestamps: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + all_timestamps: (B, N). + Returns: + (B, N, N). + """ + B = all_timestamps.size(0) + N = self._max_seq_len + t = F.pad(self._pos_w[: 2 * N - 1], [0, N]).repeat(N) + t = t[..., :-N].reshape(1, N, 3 * N - 2) + r = (2 * N - 1) // 2 + + # [B, N + 1] to simplify tensor manipulations. + ext_timestamps = torch.cat( + [all_timestamps, all_timestamps[:, N - 1 : N]], dim=1 + ) + # causal masking. 
Otherwise [:, :-1] - [:, 1:] works + bucketed_timestamps = torch.clamp( + self._bucketization_fn( + ext_timestamps[:, 1:].unsqueeze(2) - ext_timestamps[:, :-1].unsqueeze(1) + ), + min=0, + max=self._num_buckets, + ).detach() + rel_pos_bias = t[:, :, r:-r] + rel_ts_bias = torch.index_select( + self._ts_w, dim=0, index=bucketed_timestamps.view(-1) + ).view(B, N, N) + return rel_pos_bias + rel_ts_bias + + +HSTUCacheState = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + + +def _hstu_attention_maybe_from_cache( + num_heads: int, + attention_dim: int, + linear_dim: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cached_q: Optional[torch.Tensor], + cached_k: Optional[torch.Tensor], + delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]], + x_offsets: torch.Tensor, + all_timestamps: Optional[torch.Tensor], + invalid_attn_mask: torch.Tensor, + rel_attn_bias: RelativeAttentionBiasModule, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B: int = x_offsets.size(0) - 1 + n: int = invalid_attn_mask.size(-1) + if delta_x_offsets is not None: + padded_q, padded_k = cached_q, cached_k + flattened_offsets = delta_x_offsets[1] + torch.arange( + start=0, + end=B * n, + step=n, + device=delta_x_offsets[1].device, + dtype=delta_x_offsets[1].dtype, + ) + assert isinstance(padded_q, torch.Tensor) + assert isinstance(padded_k, torch.Tensor) + padded_q = ( + padded_q.view(B * n, -1) + .index_copy_( + dim=0, + index=flattened_offsets, + source=q, + ) + .view(B, n, -1) + ) + padded_k = ( + padded_k.view(B * n, -1) + .index_copy_( + dim=0, + index=flattened_offsets, + source=k, + ) + .view(B, n, -1) + ) + else: + padded_q = torch.ops.fbgemm.jagged_to_padded_dense( + values=q, offsets=[x_offsets], max_lengths=[n], padding_value=0.0 + ) + padded_k = torch.ops.fbgemm.jagged_to_padded_dense( + values=k, offsets=[x_offsets], max_lengths=[n], padding_value=0.0 + ) + + qk_attn = torch.einsum( + "bnhd,bmhd->bhnm", + padded_q.view(B, n, num_heads, attention_dim), + padded_k.view(B, n, num_heads, attention_dim), + ) + if all_timestamps is not None: + qk_attn = qk_attn + rel_attn_bias(all_timestamps).unsqueeze(1) + qk_attn = F.silu(qk_attn) / n + qk_attn = qk_attn * invalid_attn_mask.unsqueeze(0).unsqueeze(0) + attn_output = torch.ops.fbgemm.dense_to_jagged( + torch.einsum( + "bhnm,bmhd->bnhd", + qk_attn, + torch.ops.fbgemm.jagged_to_padded_dense(v, [x_offsets], [n]).reshape( + B, n, num_heads, linear_dim + ), + ).reshape(B, n, num_heads * linear_dim), + [x_offsets], + )[0] + return attn_output, padded_q, padded_k + + +class SequentialTransductionUnitJagged(torch.nn.Module): + def __init__( + self, + embedding_dim: int, + linear_hidden_dim: int, + attention_dim: int, + dropout_ratio: float, + attn_dropout_ratio: float, + num_heads: int, + linear_activation: str, + relative_attention_bias_module: Optional[RelativeAttentionBiasModule] = None, + normalization: str = "rel_bias", + linear_config: str = "uvqk", + concat_ua: bool = False, + epsilon: float = 1e-6, + max_length: Optional[int] = None, + ) -> None: + super().__init__() + self._embedding_dim: int = embedding_dim + self._linear_dim: int = linear_hidden_dim + self._attention_dim: int = attention_dim + self._dropout_ratio: float = dropout_ratio + self._attn_dropout_ratio: float = attn_dropout_ratio + self._num_heads: int = num_heads + self._rel_attn_bias: Optional[RelativeAttentionBiasModule] = ( + relative_attention_bias_module + ) + self._normalization: str = normalization + self._linear_config: str = linear_config + if 
self._linear_config == "uvqk": + self._uvqk: torch.nn.Parameter = torch.nn.Parameter( + torch.empty( + ( + embedding_dim, + linear_hidden_dim * 2 * num_heads + + attention_dim * num_heads * 2, + ) + ).normal_(mean=0, std=0.02), + ) + else: + raise ValueError(f"Unknown linear_config {self._linear_config}") + self._linear_activation: str = linear_activation + self._concat_ua: bool = concat_ua + self._o = torch.nn.Linear( + in_features=linear_hidden_dim * num_heads * (3 if concat_ua else 1), + out_features=embedding_dim, + ) + torch.nn.init.xavier_uniform_(self._o.weight) + self._eps: float = epsilon + + def _norm_input(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm(x, normalized_shape=[self._embedding_dim], eps=self._eps) + + def _norm_attn_output(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm( + x, normalized_shape=[self._linear_dim * self._num_heads], eps=self._eps + ) + + def forward( # pyre-ignore [3] + self, + x: torch.Tensor, + x_offsets: torch.Tensor, + all_timestamps: Optional[torch.Tensor], + invalid_attn_mask: torch.Tensor, + delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[HSTUCacheState] = None, + return_cache_states: bool = False, + ): + """ + Args: + x: (\sum_i N_i, D) x float. + x_offsets: (B + 1) x int32. + all_timestamps: optional (B, N) x int64. + invalid_attn_mask: (B, N, N) x float, each element in {0, 1}. + delta_x_offsets: optional 2-tuple ((B,) x int32, (B,) x int32). + For the 1st element in the tuple, each element is in [0, x_offsets[-1]). For the + 2nd element in the tuple, each element is in [0, N). + cache: Optional 4-tuple of (v, padded_q, padded_k, output) from prior runs, + where all except padded_q, padded_k are jagged. + Returns: + x' = f(x), (\sum_i N_i, D) x float. + """ + n: int = invalid_attn_mask.size(-1) + cached_q = None + cached_k = None + if delta_x_offsets is not None: + # In this case, for all the following code, x, u, v, q, k become restricted to + # [delta_x_offsets[0], :]. 
+ assert cache is not None + x = x[delta_x_offsets[0], :] + cached_v, cached_q, cached_k, cached_outputs = cache + + normed_x = self._norm_input(x) + + if self._linear_config == "uvqk": + batched_mm_output = torch.mm(normed_x, self._uvqk) + if self._linear_activation == "silu": + batched_mm_output = F.silu(batched_mm_output) + elif self._linear_activation == "none": + batched_mm_output = batched_mm_output + u, v, q, k = torch.split( + batched_mm_output, + [ + self._linear_dim * self._num_heads, + self._linear_dim * self._num_heads, + self._attention_dim * self._num_heads, + self._attention_dim * self._num_heads, + ], + dim=1, + ) + else: + raise ValueError(f"Unknown self._linear_config {self._linear_config}") + + if delta_x_offsets is not None: + v = cached_v.index_copy_(dim=0, index=delta_x_offsets[0], source=v) + + B: int = x_offsets.size(0) - 1 + if self._normalization == "rel_bias" or self._normalization == "hstu_rel_bias": + assert self._rel_attn_bias is not None + attn_output, padded_q, padded_k = _hstu_attention_maybe_from_cache( + num_heads=self._num_heads, + attention_dim=self._attention_dim, + linear_dim=self._linear_dim, + q=q, + k=k, + v=v, + cached_q=cached_q, + cached_k=cached_k, + delta_x_offsets=delta_x_offsets, + x_offsets=x_offsets, + all_timestamps=all_timestamps, + invalid_attn_mask=invalid_attn_mask, + rel_attn_bias=self._rel_attn_bias, + ) + elif self._normalization == "softmax_rel_bias": + if delta_x_offsets is not None: + B = x_offsets.size(0) - 1 + padded_q, padded_k = cached_q, cached_k + flattened_offsets = delta_x_offsets[1] + torch.arange( + start=0, + end=B * n, + step=n, + device=delta_x_offsets[1].device, + dtype=delta_x_offsets[1].dtype, + ) + assert padded_q is not None + assert padded_k is not None + padded_q = ( + padded_q.view(B * n, -1) + .index_copy_( + dim=0, + index=flattened_offsets, + source=q, + ) + .view(B, n, -1) + ) + padded_k = ( + padded_k.view(B * n, -1) + .index_copy_( + dim=0, + index=flattened_offsets, + source=k, + ) + .view(B, n, -1) + ) + else: + padded_q = torch.ops.fbgemm.jagged_to_padded_dense( + values=q, offsets=[x_offsets], max_lengths=[n], padding_value=0.0 + ) + padded_k = torch.ops.fbgemm.jagged_to_padded_dense( + values=k, offsets=[x_offsets], max_lengths=[n], padding_value=0.0 + ) + + qk_attn = torch.einsum("bnd,bmd->bnm", padded_q, padded_k) + if self._rel_attn_bias is not None: + qk_attn = qk_attn + self._rel_attn_bias(all_timestamps) + qk_attn = F.softmax(qk_attn / math.sqrt(self._attention_dim), dim=-1) + qk_attn = qk_attn * invalid_attn_mask + attn_output = torch.ops.fbgemm.dense_to_jagged( + torch.bmm( + qk_attn, + torch.ops.fbgemm.jagged_to_padded_dense(v, [x_offsets], [n]), + ), + [x_offsets], + )[0] + else: + raise ValueError(f"Unknown normalization method {self._normalization}") + + attn_output = ( + attn_output + if delta_x_offsets is None + else attn_output[delta_x_offsets[0], :] + ) + if self._concat_ua: + a = self._norm_attn_output(attn_output) + o_input = torch.cat([u, a, u * a], dim=-1) + else: + o_input = u * self._norm_attn_output(attn_output) + + new_outputs = ( + self._o( + F.dropout( + o_input, + p=self._dropout_ratio, + training=self.training, + ) + ) + + x + ) + + if delta_x_offsets is not None: + new_outputs = cached_outputs.index_copy_( + dim=0, index=delta_x_offsets[0], source=new_outputs + ) + + if return_cache_states and delta_x_offsets is None: + v = v.contiguous() + + return new_outputs, (v, padded_q, padded_k, new_outputs) \ No newline at end of file diff --git a/tzrec/modules/sequence.py 
b/tzrec/modules/sequence.py index 3d4e44b..4b73245 100644 --- a/tzrec/modules/sequence.py +++ b/tzrec/modules/sequence.py @@ -9,14 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Tuple, Optional +import fbgemm_gpu import numpy as np import torch from torch import nn from torch.nn import functional as F from tzrec.modules.mlp import MLP +from tzrec.modules.hstu import RelativeBucketedTimeAndPositionBasedBias, SequentialTransductionUnitJagged +from tzrec.modules.hstu import HSTUCacheState from tzrec.protos.seq_encoder_pb2 import SeqEncoderConfig from tzrec.utils import config_util from tzrec.utils.load_class import get_register_class_meta @@ -295,3 +298,212 @@ def create_seq_encoder( seq_config_dict["query_dim"] = query_dim seq_encoder = model_cls(**seq_config_dict) return seq_encoder + + +class HSTUEncoder(SequenceEncoder): + """HSTU sequence encoder. + + Args: + sequence_dim (int): sequence tensor channel dimension. + query_dim (int): query tensor channel dimension. + input(str): input feature group name. + attn_mlp (dict): target attention MLP module parameters. + """ + + def __init__( + self, + sequence_dim: int, + input: str, + max_seq_length: int, + pos_dropout_rate: float = 0.5, + linear_dropout_rate: float = 0.2, + attn_dropout_rate: float = 0.0, + normalization: str = 'rel_bias', + linear_activation: str = 'silu', + linear_config: str = 'uvqk', + num_heads: int = 4, + num_blocks: int = 4, + max_output_len: int = 2, + time_bucket_size: int = 128, + **kwargs: Optional[Dict[str, Any]] + ) -> None: + super().__init__(input) + self._sequence_dim = sequence_dim + self._max_seq_length = max_seq_length + self._query_name = f"{input}.query" + self._sequence_name = f"{input}.sequence" + self._sequence_length_name = f"{input}.sequence_length" + self.position_embed = nn.Embedding(self._max_seq_length + max_output_len + 1, self._sequence_dim, padding_idx=0) + self.dropout_rate = pos_dropout_rate + self.enable_relative_attention_bias = True + self.autocast_dtype = None + self._attention_layers: nn.ModuleList = nn.ModuleList( + modules=[ + SequentialTransductionUnitJagged( + embedding_dim=self._sequence_dim, + linear_hidden_dim=self._sequence_dim, + attention_dim=self._sequence_dim, + normalization=normalization, + linear_config=linear_config, + linear_activation=linear_activation, + num_heads=num_heads, + relative_attention_bias_module=( + RelativeBucketedTimeAndPositionBasedBias( + max_seq_len=max_seq_length + max_output_len, + num_buckets=time_bucket_size, + bucketization_fn=lambda x: ( + torch.log(torch.abs(x).clamp(min=1)) / 0.301 + ).long(), + ) + if self.enable_relative_attention_bias + else None + ), + dropout_ratio=linear_dropout_rate, + attn_dropout_ratio=attn_dropout_rate, + concat_ua=False, + ) + for _ in range(num_blocks) + ] + ) + self.register_buffer( + "_attn_mask", + torch.triu( + torch.ones( + ( + self._max_seq_length + max_output_len, + self._max_seq_length + max_output_len, + ), + dtype=torch.bool, + ), + diagonal=1, + ), + ) + self._autocast_dtype = None + + def output_dim(self) -> int: + """Output dimension of the module.""" + return self._sequence_dim + + def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor: + """Forward the module.""" + sequence = sequence_embedded[self._sequence_name] # B, N, E + sequence_length = sequence_embedded[self._sequence_length_name] # N + # max_seq_length = sequence.size(1) + 
float_dtype = sequence.dtype + + # Add positional embeddings and apply dropout + positions = _arange(sequence.size(1), device=sequence.device).unsqueeze(0).expand(sequence.size(0), -1) + sequence = sequence * (self._sequence_dim**0.5) + self.position_embed(positions) + sequence = F.dropout(sequence, p=self.dropout_rate, training=self.training) + sequence_mask = _arange(sequence.size(1), device=sequence_length.device).unsqueeze(0) < sequence_length.unsqueeze(1) + sequence = sequence * sequence_mask.unsqueeze(-1).to(float_dtype) + + invalid_attn_mask = 1.0 - self._attn_mask.to(float_dtype) + sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(sequence_length) + sequence = torch.ops.fbgemm.dense_to_jagged(sequence, [sequence_offsets])[0] + + all_timestamps = None + jagged_x, cache_states = self.jagged_forward( + x=sequence, + x_offsets=sequence_offsets, + all_timestamps=all_timestamps, + invalid_attn_mask=invalid_attn_mask, + delta_x_offsets=None, + cache=None, + return_cache_states=False, + ) + output_embeddings = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged_x, + offsets=[sequence_offsets], + max_lengths=[invalid_attn_mask.size(1)], + padding_value=0.0, + ) + # post processing + output_embeddings = output_embeddings[..., : self._sequence_dim] + output_embeddings = output_embeddings / torch.clamp( + torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True), + min=1e-6, + ) + output_embeddings = self.get_current_embeddings(sequence_length, output_embeddings) + return output_embeddings + + def jagged_forward( + self, + x: torch.Tensor, + x_offsets: torch.Tensor, + all_timestamps: Optional[torch.Tensor], + invalid_attn_mask: torch.Tensor, + delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[List[HSTUCacheState]] = None, + return_cache_states: bool = False, + ) -> Tuple[torch.Tensor, List[HSTUCacheState]]: + """ + Args: + x: (\sum_i N_i, D) x float + x_offsets: (B + 1) x int32 + all_timestamps: (B, 1 + N) x int64 + invalid_attn_mask: (B, N, N) x float, each element in {0, 1} + return_cache_states: bool. True if we should return cache states. 
+ + Returns: + x' = f(x), (\sum_i N_i, D) x float + """ + cache_states: List[HSTUCacheState] = [] + + with torch.autocast( + "cuda", + enabled=self._autocast_dtype is not None, + dtype=self._autocast_dtype or torch.float16, + ): + for i, layer in enumerate(self._attention_layers): + x, cache_states_i = layer( + x=x, + x_offsets=x_offsets, + all_timestamps=all_timestamps, + invalid_attn_mask=invalid_attn_mask, + delta_x_offsets=delta_x_offsets, + cache=cache[i] if cache is not None else None, + return_cache_states=return_cache_states, + ) + print(f"\n--- Layer {i} ---") + print(f"Layer type: {type(layer).__name__}") + important_attrs = [ + 'embedding_dim', + 'linear_dim', + 'attention_dim', + 'dropout_ratio', + 'attn_dropout_ratio', + 'num_heads', + 'linear_activation', + 'normalization', + 'linear_config', + 'concat_ua' + ] + for attr_name in important_attrs: + attr_value = getattr(layer, f"_{attr_name}", None) + if attr_value is not None: + print(f"{attr_name}: {attr_value}") + print(f"\nAfter Layer {i}:") + print(f"sequence.size: {x.size()}\nsequence: {x}") + print(f"x_offsets: {x_offsets}") + if return_cache_states: + cache_states.append(cache_states_i) + + return x, cache_states + + def get_current_embeddings( + self, + lengths: torch.Tensor, + encoded_embeddings: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + lengths: (B,) x int + seq_embeddings: (B, N, D,) x float + + Returns: + (B, D,) x float, where [i, :] == encoded_embeddings[i, lengths[i] - 1, :] + """ + B, N, D = encoded_embeddings.size() + flattened_offsets = (lengths - 1) + _arange(B, device=lengths.device) * N + return encoded_embeddings.reshape(-1, D)[flattened_offsets, :].reshape(B, D) \ No newline at end of file diff --git a/tzrec/modules/sequence_test.py b/tzrec/modules/sequence_test.py index 57e20fb..fac839e 100644 --- a/tzrec/modules/sequence_test.py +++ b/tzrec/modules/sequence_test.py @@ -19,6 +19,7 @@ MultiWindowDINEncoder, PoolingEncoder, SimpleAttention, + HSTUEncoder, create_seq_encoder, ) from tzrec.protos import module_pb2, seq_encoder_pb2 @@ -75,6 +76,46 @@ def test_din_encoder_padding(self, graph_type) -> None: } result = din(embedded) self.assertEqual(result.size(), (4, 16)) + + +class HSTUEncoderTest(unittest.TestCase): + @parameterized.expand( + [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] + ) + def test_hstu_encoder(self, graph_type) -> None: + din = HSTUEncoder( + sequence_dim=16, + input="click_seq", + max_seq_length=10, + ) + self.assertEqual(din.output_dim(), 16) + din = create_test_module(din, graph_type) + embedded = { + "click_seq.query": torch.randn(4, 16), + "click_seq.sequence": torch.randn(4, 10, 16), + "click_seq.sequence_length": torch.tensor([2, 3, 4, 5]), + } + result = din(embedded) + self.assertEqual(result.size(), (4, 16)) + + @parameterized.expand( + [[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]] + ) + def test_hstu_encoder_padding(self, graph_type) -> None: + din = HSTUEncoder( + sequence_dim=16, + input="click_seq", + max_seq_length=10, + ) + self.assertEqual(din.output_dim(), 16) + din = create_test_module(din, graph_type) + embedded = { + "click_seq.query": torch.randn(4, 12), + "click_seq.sequence": torch.randn(4, 10, 16), + "click_seq.sequence_length": torch.tensor([2, 3, 4, 5]), + } + result = din(embedded) + self.assertEqual(result.size(), (4, 16)) class SimpleAttentionTest(unittest.TestCase): diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto index b88acba..08f64e9 100644 --- 
a/tzrec/protos/model.proto +++ b/tzrec/protos/model.proto @@ -46,6 +46,7 @@ message ModelConfig { DSSM dssm = 301; DSSMV2 dssm_v2 = 302; + HSTU hstu = 303; TDM tdm = 400; diff --git a/tzrec/protos/seq_encoder.proto b/tzrec/protos/seq_encoder.proto index 5f3933b..9dc8627 100644 --- a/tzrec/protos/seq_encoder.proto +++ b/tzrec/protos/seq_encoder.proto @@ -39,11 +39,43 @@ message MultiWindowDINEncoder { repeated uint32 windows_len = 4; } +message HSTUEncoder { + // seq encoder name + optional string name = 1; + // sequence feature name + required string input = 2; + // sequence dimension + optional int32 sequence_dim = 3; + // maximum sequence length + optional int32 max_seq_length = 4; + // dropout rate for positional embeddings + optional float pos_dropout_rate = 5 [default = 0.5]; + // dropout rate for linear layers + optional float linear_dropout_rate = 6 [default = 0.2]; + // dropout rate for attention + optional float attn_dropout_rate = 7 [default = 0.0]; + // normalization type + optional string normalization = 8 [default = "rel_bias"]; + // activation function for linear layers + optional string linear_activation = 9 [default = "silu"]; + // linear configuration type + optional string linear_config = 10 [default = "uvqk"]; + // number of attention heads + optional int32 num_heads = 11 [default = 4]; + // number of transformer blocks + optional int32 num_blocks = 12 [default = 4]; + // maximum output sequence length + optional int32 max_output_len = 13 [default = 2]; + // size of time buckets for relative attention + optional int32 time_bucket_size = 14 [default = 128]; +} + message SeqEncoderConfig { oneof seq_module { DINEncoder din_encoder = 1; SimpleAttention simple_attention = 2; PoolingEncoder pooling_encoder = 3; MultiWindowDINEncoder multi_window_din_encoder = 4; + HSTUEncoder hstu_encoder = 5; } } diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config new file mode 100644 index 0000000..0146598 --- /dev/null +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -0,0 +1,259 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/dssm_fg_mock" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 8 +} +eval_config { +} +data_config { + batch_size: 8192 + dataset_type: ParquetDataset + fg_encoded: false + label_fields: "clk" + num_workers: 8 + force_base_data_group: true + negative_sampler { + input_path: "odps://{PROJECT}/tables/taobao_ad_feature_gl_bucketized_v1" + num_sample: 1024 + attr_fields: "item_id" + attr_fields: "item_id_1" + attr_fields: "item_id_2" + attr_fields: "item_raw_1" + attr_fields: "item_raw_2" + attr_fields: "item_raw_3" + attr_fields: "title" + item_id_field: "item_id" + attr_delimiter: "\x02" + } +} +feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + hash_bucket_size: 1000000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "user_id_1" + expression: "user:user_id_1" + num_buckets: 10000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "user_id_2" + expression: "user:user_id_2" + vocab_list: ["a", "b", "c"] + embedding_dim: 8 + } +} +feature_configs { + id_feature { + feature_name: "user_id_3" + expression: "user:user_id_3" + num_buckets: 100 + embedding_dim: 16 + embedding_name: "user_id_3_emb" + } +} +feature_configs { + id_feature { + feature_name: 
"user_id_4" + expression: "user:user_id_4" + num_buckets: 100 + embedding_dim: 16 + embedding_name: "user_id_3_emb" + } +} +feature_configs { + raw_feature { + feature_name: "user_raw_1" + expression: "user:user_raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "user_raw_2" + expression: "user:user_raw_2" + } +} +feature_configs { + raw_feature { + feature_name: "user_raw_3" + expression: "user:user_raw_3" + value_dim: 4 + } +} +feature_configs { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + num_buckets: 1000000 + embedding_dim: 16 + embedding_name: "item_id" + } +} +feature_configs { + id_feature { + feature_name: "item_id_1" + expression: "item:item_id_1" + num_buckets: 10000 + embedding_dim: 16 + } +} +feature_configs { + id_feature { + feature_name: "item_id_2" + expression: "item:item_id_2" + num_buckets: 1000 + embedding_dim: 8 + } +} +feature_configs { + raw_feature { + feature_name: "item_raw_1" + expression: "item:item_raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "item_raw_2" + expression: "item:item_raw_2" + } +} +feature_configs { + raw_feature { + feature_name: "item_raw_3" + expression: "item:item_raw_3" + value_dim: 4 + } +} +feature_configs { + sequence_feature { + sequence_name: "click_50_seq" + sequence_length: 50 + sequence_delim: "|" + features { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + num_buckets: 1000000 + embedding_dim: 16 + embedding_name: "item_id" + } + } + features { + id_feature { + feature_name: "item_id_1" + expression: "item:item_id_1" + num_buckets: 10000 + embedding_dim: 16 + embedding_name: "item_id_1" + } + } + features { + raw_feature { + feature_name: "item_raw_1" + expression: "item:item_raw_1" + boundaries: [0.1, 0.2, 0.3, 0.4] + embedding_dim: 16 + } + } + features { + raw_feature { + feature_name: "item_raw_2" + expression: "item:item_raw_2" + } + } + } +} +feature_configs { + tokenize_feature { + feature_name: "item_title" + expression: "item:title" + vocab_file: "./data/test/tokenizer.json" + text_normalizer {} + embedding_dim: 8 + } +} +model_config { + feature_groups { + group_name: "user" + feature_names: "user_id" + feature_names: "user_id_1" + feature_names: "user_raw_1" + sequence_groups { + group_name: "click_50_seq" + feature_names: "click_50_seq__item_id" + feature_names: "click_50_seq__item_id_1" + feature_names: "click_50_seq__item_raw_1" + feature_names: "click_50_seq__item_raw_2" + } + sequence_encoders { + hstu_encoder: { + sequence_dim: 16 + input: "click_50_seq" + max_seq_length: 50 + } + } + group_type: DEEP + } + feature_groups { + group_name: "item" + feature_names: "item_id" + feature_names: "item_id_1" + feature_names: "item_raw_1" + feature_names: "item_title" + group_type: DEEP + } + hstu { + user_tower { + input: 'user' + mlp { + hidden_units: [512, 256, 128] + } + } + item_tower { + input: 'item' + mlp { + hidden_units: [512, 256, 128] + } + } + output_dim: 64 + } + metrics { + recall_at_k { + top_k: 1 + } + } + metrics { + recall_at_k { + top_k: 5 + } + } + losses { + softmax_cross_entropy {} + } +} diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/train_eval_export_test.py index 5e0ea15..0cd6dc5 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/train_eval_export_test.py @@ -299,6 +299,30 @@ def test_dssm_v2_with_fg_train_eval_export(self): self.assertTrue( 
os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) ) + + def test_hstu_with_fg_train_eval_export(self): + self.success = utils.test_train_eval( + "tzrec/tests/configs/hstu_fg_mock.config", + self.test_dir, + user_id="user_id", + item_id="item_id", + ) + if self.success: + print("First test_eval") + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/user/scripted_model.pt")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) + ) def test_multi_tower_din_with_fg_train_eval_export_input_tile(self): self.success = utils.test_train_eval( From 769cc0a29a775f6ad51de8bf5043e7b1a4435e42 Mon Sep 17 00:00:00 2001 From: tianzhou Date: Tue, 24 Dec 2024 10:52:16 +0800 Subject: [PATCH 3/3] [bugfix] fix bugs in hstu --- tzrec/models/hstu.py | 28 ++---- tzrec/modules/embedding.py | 23 +++-- tzrec/modules/hstu.py | 128 +++++++++++++++++++++--- tzrec/modules/sequence.py | 107 ++++++++++---------- tzrec/modules/sequence_test.py | 4 +- tzrec/tests/configs/hstu_fg_mock.config | 2 +- 6 files changed, 190 insertions(+), 102 deletions(-) diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index 154b7a1..3979c28 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -14,16 +14,13 @@ import torch import torch.nn.functional as F -from torch import nn from torch._tensor import Tensor from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.match_model import MatchModel, MatchTower -from tzrec.modules.mlp import MLP from tzrec.protos import model_pb2, tower_pb2 from tzrec.protos.models import match_model_pb2 -from tzrec.utils.config_util import config_to_kwargs @torch.fx.wrap @@ -62,10 +59,7 @@ def __init__( tower_config, output_dim, similarity, feature_group, features, model_config ) self.init_input() - tower_feature_in = self.embedding_group.group_total_dim(self._group_name) - self.mlp = MLP(tower_feature_in, **config_to_kwargs(tower_config.mlp)) - if self._output_dim > 0: - self.output = nn.Linear(self.mlp.output_dim(), output_dim) + self.tower_config = tower_config def forward(self, batch: Batch) -> torch.Tensor: """Forward the tower. @@ -76,18 +70,13 @@ def forward(self, batch: Batch) -> torch.Tensor: Return: embedding (dict): tower output embedding. 
""" + # print(batch) grouped_features = self.build_input(batch) - output = self.mlp(grouped_features[self._group_name]) - # TODO: this will cause dimension unmatch in self.output, considering resolutions - # if self._tower_config.input == 'user': - # output = grouped_features[self._group_name] - # else: - # output = self.mlp(grouped_features[self._group_name]) - - if self._output_dim > 0: - output = self.output(output) - if self._similarity == match_model_pb2.Similarity.COSINE: - output = F.normalize(output, p=2.0, dim=1) + output = grouped_features[self._group_name] + + if self.tower_config.input == "item": + if self._similarity == match_model_pb2.Similarity.COSINE: + output = F.normalize(output, p=2.0, dim=1, eps=1e-6) return output @@ -108,7 +97,7 @@ def __init__( sample_weights: Optional[List[str]] = None, **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels, sample_weights, **kwargs) + super().__init__(model_config, features, labels, sample_weights, **kwargs) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} user_group = name_to_feature_group[self._model_config.user_tower.input] @@ -158,7 +147,6 @@ def predict(self, batch: Batch) -> Dict[str, Tensor]: _update_dict_tensor( self._loss_collection, self.item_tower.group_variational_dropout_loss ) - ui_sim = ( self.sim(user_tower_emb, item_tower_emb) / self._model_config.temperature ) diff --git a/tzrec/modules/embedding.py b/tzrec/modules/embedding.py index 6f4f7cc..d498b22 100644 --- a/tzrec/modules/embedding.py +++ b/tzrec/modules/embedding.py @@ -399,15 +399,22 @@ def forward( if emb_impl.has_sparse_user: sparse_feat_kjt_user = batch.sparse_features[key + "_user"] - result_dicts.append( - emb_impl( - sparse_feat_kjt, - dense_feat_kt, - sparse_feat_kjt_user, - dense_feat_kt_user, - batch.batch_size, + + if ( + emb_impl.has_dense + or emb_impl.has_dense_user + or emb_impl.has_sparse + or emb_impl.has_sparse_user + ): + result_dicts.append( + emb_impl( + sparse_feat_kjt, + dense_feat_kt, + sparse_feat_kjt_user, + dense_feat_kt_user, + batch.batch_size, + ) ) - ) for key, seq_emb_impl in self.seq_emb_impls.items(): sparse_feat_kjt = None diff --git a/tzrec/modules/hstu.py b/tzrec/modules/hstu.py index 86b845a..179db7b 100644 --- a/tzrec/modules/hstu.py +++ b/tzrec/modules/hstu.py @@ -1,22 +1,47 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -import abc +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import math from typing import Callable, Optional, Tuple +import torch +import torch.nn.functional as F + TIMESTAMPS_KEY = "timestamps" class RelativeAttentionBiasModule(torch.nn.Module): + """Relative Attention Bias Module for transformer-based architectures. + + This module computes relative positional biases for attention mechanisms, + allowing the model to consider relative positions between tokens in the sequence. + Implements learnable relative position embeddings that can be + added to attention scores. 
+ + Inherits from: + torch.nn.Module: Base PyTorch Module class + + Note: + The relative attention bias is typically added to the attention scores + before the softmax operation in the attention mechanism. + """ @abc.abstractmethod def forward( self, all_timestamps: torch.Tensor, ) -> torch.Tensor: - """ + """Calculate bias with timestamps. + Args: all_timestamps: [B, N] x int64 Returns: @@ -26,6 +51,14 @@ def forward( class RelativePositionalBias(RelativeAttentionBiasModule): + """Implements relative positional bias for attention mechanisms. + + This class provides learnable position-based attention biases based on the relative + positions of elements in a sequence, up to a maximum sequence length. + + Args: + max_seq_len (int): Maximum sequence length supported by this bias module. + """ def __init__(self, max_seq_len: int) -> None: super().__init__() @@ -39,6 +72,20 @@ def forward( self, all_timestamps: torch.Tensor, ) -> torch.Tensor: + """Computes relative positional biases for attention. + + This method generates position-based attention biases based on + relative positions, ignoring the actual timestamps provided + (as this implementation only cares about + relative positions, not temporal information). + + Args: + all_timestamps: Tensor of shape [B, N] containing int64 timestamps + (unused in this implementation) + + Returns: + torch.Tensor: Attention bias tensor broadcastable to shape [B, N, N] + """ del all_timestamps n: int = self._max_seq_len t = F.pad(self._w[: 2 * n - 1], [0, n]).repeat(n) @@ -48,9 +95,7 @@ def forward( class RelativeBucketedTimeAndPositionBasedBias(RelativeAttentionBiasModule): - """ - Bucketizes timespans based on ts(next-item) - ts(current-item). - """ + """Bucketizes timespans based on ts(next-item) - ts(current-item).""" def __init__( self, @@ -76,9 +121,11 @@ def forward( self, all_timestamps: torch.Tensor, ) -> torch.Tensor: - """ + """Forward function. + Args: all_timestamps: (B, N). + Returns: (B, N, N). """ @@ -187,6 +234,52 @@ def _hstu_attention_maybe_from_cache( class SequentialTransductionUnitJagged(torch.nn.Module): + """A jagged sequential transduction unit for variable-length sequences. + + This module processes jagged (variable-length) sequences using a + combination of attention mechanisms and linear transformations. + It supports various normalization strategies and attention bias configurations. + + Args: + embedding_dim (int): Dimension of input embeddings + linear_hidden_dim (int): Dimension of hidden linear layers + attention_dim (int): Dimension of attention mechanism + dropout_ratio (float): Dropout probability for linear layers + attn_dropout_ratio (float): Dropout probability for attention + num_heads (int): Number of attention heads + linear_activation (str): + Activation function for linear layers ('silu' or 'none') + relative_attention_bias_module (Optional[RelativeAttentionBiasModule]): + Module for relative position biases + normalization (str, optional): + Normalization strategy. Defaults to "rel_bias". + Options: "rel_bias", "hstu_rel_bias", "softmax_rel_bias" + linear_config (str, optional): + Linear layer configuration. Defaults to "uvqk". + concat_ua (bool, optional): + Whether to concatenate u and a in output. Defaults to False. + epsilon (float, optional): + Small constant for numerical stability. Defaults to 1e-6. + max_length (Optional[int], optional): + Maximum sequence length. Defaults to None. 
+ + Attributes: + _embedding_dim (int): Dimension of input embeddings + _linear_dim (int): Dimension of hidden linear layers + _attention_dim (int): Dimension of attention mechanism + _num_heads (int): Number of attention heads + _rel_attn_bias (Optional[RelativeAttentionBiasModule]): + Module for relative position biases + _normalization (str): Normalization strategy + _linear_config (str): Linear layer configuration + _concat_ua (bool): Whether to concatenate u and a in output + _eps (float): Small constant for numerical stability + + Note: + This implementation supports caching for efficient sequential processing and + handles jagged sequences through FBGEMM operations for dense-jagged conversions. + """ + def __init__( self, embedding_dim: int, @@ -254,17 +347,20 @@ def forward( # pyre-ignore [3] cache: Optional[HSTUCacheState] = None, return_cache_states: bool = False, ): - """ + r"""Forward function. + Args: x: (\sum_i N_i, D) x float. x_offsets: (B + 1) x int32. all_timestamps: optional (B, N) x int64. invalid_attn_mask: (B, N, N) x float, each element in {0, 1}. delta_x_offsets: optional 2-tuple ((B,) x int32, (B,) x int32). - For the 1st element in the tuple, each element is in [0, x_offsets[-1]). For the - 2nd element in the tuple, each element is in [0, N). + For the 1st element in the tuple, each element is in [0, x_offsets[-1]). + For the 2nd element in the tuple, each element is in [0, N). cache: Optional 4-tuple of (v, padded_q, padded_k, output) from prior runs, where all except padded_q, padded_k are jagged. + return_cache_states: Return cache status or not. + Returns: x' = f(x), (\sum_i N_i, D) x float. """ @@ -272,8 +368,8 @@ def forward( # pyre-ignore [3] cached_q = None cached_k = None if delta_x_offsets is not None: - # In this case, for all the following code, x, u, v, q, k become restricted to - # [delta_x_offsets[0], :]. + # In this case, for all the following code, x, u, v, q, k + # become restricted to [delta_x_offsets[0], :]. assert cache is not None x = x[delta_x_offsets[0], :] cached_v, cached_q, cached_k, cached_outputs = cache @@ -404,4 +500,4 @@ def forward( # pyre-ignore [3] if return_cache_states and delta_x_offsets is None: v = v.contiguous() - return new_outputs, (v, padded_q, padded_k, new_outputs) \ No newline at end of file + return new_outputs, (v, padded_q, padded_k, new_outputs) diff --git a/tzrec/modules/sequence.py b/tzrec/modules/sequence.py index 4b73245..e94cb8f 100644 --- a/tzrec/modules/sequence.py +++ b/tzrec/modules/sequence.py @@ -9,17 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple -import fbgemm_gpu import numpy as np import torch from torch import nn from torch.nn import functional as F +from tzrec.modules.hstu import ( + HSTUCacheState, + RelativeBucketedTimeAndPositionBasedBias, + SequentialTransductionUnitJagged, +) from tzrec.modules.mlp import MLP -from tzrec.modules.hstu import RelativeBucketedTimeAndPositionBasedBias, SequentialTransductionUnitJagged -from tzrec.modules.hstu import HSTUCacheState from tzrec.protos.seq_encoder_pb2 import SeqEncoderConfig from tzrec.utils import config_util from tzrec.utils.load_class import get_register_class_meta @@ -315,17 +317,17 @@ def __init__( sequence_dim: int, input: str, max_seq_length: int, - pos_dropout_rate: float = 0.5, + pos_dropout_rate: float = 0.2, linear_dropout_rate: float = 0.2, attn_dropout_rate: float = 0.0, - normalization: str = 'rel_bias', - linear_activation: str = 'silu', - linear_config: str = 'uvqk', - num_heads: int = 4, - num_blocks: int = 4, - max_output_len: int = 2, + normalization: str = "rel_bias", + linear_activation: str = "silu", + linear_config: str = "uvqk", + num_heads: int = 1, + num_blocks: int = 2, + max_output_len: int = 10, time_bucket_size: int = 128, - **kwargs: Optional[Dict[str, Any]] + **kwargs: Optional[Dict[str, Any]], ) -> None: super().__init__(input) self._sequence_dim = sequence_dim @@ -333,7 +335,9 @@ def __init__( self._query_name = f"{input}.query" self._sequence_name = f"{input}.sequence" self._sequence_length_name = f"{input}.sequence_length" - self.position_embed = nn.Embedding(self._max_seq_length + max_output_len + 1, self._sequence_dim, padding_idx=0) + self.position_embed = nn.Embedding( + self._max_seq_length + max_output_len + 1, self._sequence_dim, padding_idx=0 + ) self.dropout_rate = pos_dropout_rate self.enable_relative_attention_bias = True self.autocast_dtype = None @@ -379,7 +383,7 @@ def __init__( ), ) self._autocast_dtype = None - + def output_dim(self) -> int: """Output dimension of the module.""" return self._sequence_dim @@ -390,18 +394,26 @@ def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor: sequence_length = sequence_embedded[self._sequence_length_name] # N # max_seq_length = sequence.size(1) float_dtype = sequence.dtype - + # Add positional embeddings and apply dropout - positions = _arange(sequence.size(1), device=sequence.device).unsqueeze(0).expand(sequence.size(0), -1) + positions = ( + _arange(sequence.size(1), device=sequence.device) + .unsqueeze(0) + .expand(sequence.size(0), -1) + ) sequence = sequence * (self._sequence_dim**0.5) + self.position_embed(positions) sequence = F.dropout(sequence, p=self.dropout_rate, training=self.training) - sequence_mask = _arange(sequence.size(1), device=sequence_length.device).unsqueeze(0) < sequence_length.unsqueeze(1) + sequence_mask = _arange( + sequence.size(1), device=sequence_length.device + ).unsqueeze(0) < sequence_length.unsqueeze(1) sequence = sequence * sequence_mask.unsqueeze(-1).to(float_dtype) - + invalid_attn_mask = 1.0 - self._attn_mask.to(float_dtype) - sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(sequence_length) + sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + sequence_length + ) sequence = torch.ops.fbgemm.dense_to_jagged(sequence, [sequence_offsets])[0] - + all_timestamps = None jagged_x, cache_states = self.jagged_forward( x=sequence, @@ -413,18 +425,20 @@ def forward(self, sequence_embedded: Dict[str, 
torch.Tensor]) -> torch.Tensor: return_cache_states=False, ) output_embeddings = torch.ops.fbgemm.jagged_to_padded_dense( - values=jagged_x, - offsets=[sequence_offsets], - max_lengths=[invalid_attn_mask.size(1)], - padding_value=0.0, - ) - # post processing + values=jagged_x, + offsets=[sequence_offsets], + max_lengths=[invalid_attn_mask.size(1)], + padding_value=0.0, + ) + # post processing: L2 Normalization output_embeddings = output_embeddings[..., : self._sequence_dim] output_embeddings = output_embeddings / torch.clamp( - torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True), - min=1e-6, - ) - output_embeddings = self.get_current_embeddings(sequence_length, output_embeddings) + torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True), + min=1e-6, + ) + output_embeddings = self.get_current_embeddings( + sequence_length, output_embeddings + ) return output_embeddings def jagged_forward( @@ -437,12 +451,15 @@ def jagged_forward( cache: Optional[List[HSTUCacheState]] = None, return_cache_states: bool = False, ) -> Tuple[torch.Tensor, List[HSTUCacheState]]: - """ + r"""Jagged forward. + Args: x: (\sum_i N_i, D) x float x_offsets: (B + 1) x int32 all_timestamps: (B, 1 + N) x int64 invalid_attn_mask: (B, N, N) x float, each element in {0, 1} + delta_x_offsets: offsets for x + cache: cache contents return_cache_states: bool. True if we should return cache states. Returns: @@ -465,27 +482,6 @@ def jagged_forward( cache=cache[i] if cache is not None else None, return_cache_states=return_cache_states, ) - print(f"\n--- Layer {i} ---") - print(f"Layer type: {type(layer).__name__}") - important_attrs = [ - 'embedding_dim', - 'linear_dim', - 'attention_dim', - 'dropout_ratio', - 'attn_dropout_ratio', - 'num_heads', - 'linear_activation', - 'normalization', - 'linear_config', - 'concat_ua' - ] - for attr_name in important_attrs: - attr_value = getattr(layer, f"_{attr_name}", None) - if attr_value is not None: - print(f"{attr_name}: {attr_value}") - print(f"\nAfter Layer {i}:") - print(f"sequence.size: {x.size()}\nsequence: {x}") - print(f"x_offsets: {x_offsets}") if return_cache_states: cache_states.append(cache_states_i) @@ -496,14 +492,15 @@ def get_current_embeddings( lengths: torch.Tensor, encoded_embeddings: torch.Tensor, ) -> torch.Tensor: - """ + """Get the embeddings of the last past_id as the current embeds. 
+ Args: lengths: (B,) x int - seq_embeddings: (B, N, D,) x float + encoded_embeddings: (B, N, D,) x float Returns: (B, D,) x float, where [i, :] == encoded_embeddings[i, lengths[i] - 1, :] """ B, N, D = encoded_embeddings.size() flattened_offsets = (lengths - 1) + _arange(B, device=lengths.device) * N - return encoded_embeddings.reshape(-1, D)[flattened_offsets, :].reshape(B, D) \ No newline at end of file + return encoded_embeddings.reshape(-1, D)[flattened_offsets, :].reshape(B, D) diff --git a/tzrec/modules/sequence_test.py b/tzrec/modules/sequence_test.py index fac839e..e373c43 100644 --- a/tzrec/modules/sequence_test.py +++ b/tzrec/modules/sequence_test.py @@ -16,10 +16,10 @@ from tzrec.modules.sequence import ( DINEncoder, + HSTUEncoder, MultiWindowDINEncoder, PoolingEncoder, SimpleAttention, - HSTUEncoder, create_seq_encoder, ) from tzrec.protos import module_pb2, seq_encoder_pb2 @@ -76,7 +76,7 @@ def test_din_encoder_padding(self, graph_type) -> None: } result = din(embedded) self.assertEqual(result.size(), (4, 16)) - + class HSTUEncoderTest(unittest.TestCase): @parameterized.expand( diff --git a/tzrec/tests/configs/hstu_fg_mock.config b/tzrec/tests/configs/hstu_fg_mock.config index 0146598..7b13c5c 100644 --- a/tzrec/tests/configs/hstu_fg_mock.config +++ b/tzrec/tests/configs/hstu_fg_mock.config @@ -1,6 +1,6 @@ train_input_path: "" eval_input_path: "" -model_dir: "experiments/dssm_fg_mock" +model_dir: "experiments/hstu_fg_mock" train_config { sparse_optimizer { adagrad_optimizer {
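
For reference, a minimal usage sketch of the new HSTUEncoder, adapted from the unit test added in tzrec/modules/sequence_test.py (PATCH 2/3). It assumes fbgemm_gpu is installed, since the encoder calls the torch.ops.fbgemm jagged-tensor ops (asynchronous_complete_cumsum, dense_to_jagged, jagged_to_padded_dense), and it uses the "<group>.sequence" / "<group>.sequence_length" key convention of the embedding-group outputs.

    import torch
    from tzrec.modules.sequence import HSTUEncoder

    # HSTU encoder over a "click_seq" feature group: 16-dim item embeddings,
    # padded history length 10; post-series defaults are num_heads=1, num_blocks=2.
    encoder = HSTUEncoder(
        sequence_dim=16,
        input="click_seq",
        max_seq_length=10,
    )

    # Toy batch of 4 users; sequence_length holds the number of valid
    # (non-padding) steps in each padded sequence.
    embedded = {
        "click_seq.query": torch.randn(4, 16),
        "click_seq.sequence": torch.randn(4, 10, 16),
        "click_seq.sequence_length": torch.tensor([2, 3, 4, 5]),
    }

    out = encoder(embedded)
    print(out.size())  # torch.Size([4, 16]): L2-normalized embedding of the
                       # last valid step per user (get_current_embeddings).

In the HSTU match model (after PATCH 3/3), the user tower returns this grouped-feature output directly, and predict() scores it against the item tower embedding with the configured similarity and temperature.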