table-batched inference QEC (pytorch#1574)
Summary:
Pull Request resolved: pytorch#1574

* add table batching to QEC for the inference case.

Reviewed By: dstaay-fb, 842974287

Differential Revision: D46238791

fbshipit-source-id: ecc46844e6763426b183b1977561ecf073e8728e
YazhiGao authored and facebook-github-bot committed Jan 28, 2024
1 parent d3f8be1 commit 4376dff
Showing 5 changed files with 278 additions and 158 deletions.
57 changes: 2 additions & 55 deletions torchrec/distributed/embedding.py

@@ -70,6 +70,7 @@
     EmbeddingCollection,
     EmbeddingCollectionInterface,
 )
+from torchrec.modules.utils import construct_jagged_tensors
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
@@ -223,60 +224,6 @@ def create_sharding_infos_by_sharding(
     return sharding_type_to_sharding_infos


-def _construct_jagged_tensors(
-    embeddings: torch.Tensor,
-    features: KeyedJaggedTensor,
-    embedding_names: List[str],
-    need_indices: bool = False,
-    features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
-    original_features: Optional[KeyedJaggedTensor] = None,
-    reverse_indices: Optional[torch.Tensor] = None,
-) -> Dict[str, JaggedTensor]:
-    with record_function("## _construct_jagged_tensors ##"):
-        if original_features is not None:
-            features = original_features
-        if reverse_indices is not None:
-            embeddings = torch.index_select(
-                embeddings, 0, reverse_indices.to(torch.int32)
-            )
-
-        ret: Dict[str, JaggedTensor] = {}
-        stride = features.stride()
-        length_per_key = features.length_per_key()
-        values = features.values()
-
-        lengths = features.lengths().view(-1, stride)
-        lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
-        embeddings_list = torch.split(embeddings, length_per_key, dim=0)
-        values_list = torch.split(values, length_per_key) if need_indices else None
-
-        key_indices = defaultdict(list)
-        for i, key in enumerate(embedding_names):
-            key_indices[key].append(i)
-        for key, indices in key_indices.items():
-            # combines outputs in correct order for CW sharding
-            indices = (
-                _permute_indices(indices, features_to_permute_indices[key])
-                if features_to_permute_indices and key in features_to_permute_indices
-                else indices
-            )
-            ret[key] = JaggedTensor(
-                lengths=lengths_tuple[indices[0]],
-                values=embeddings_list[indices[0]]
-                if len(indices) == 1
-                else torch.cat([embeddings_list[i] for i in indices], dim=1),
-                weights=values_list[indices[0]] if values_list else None,
-            )
-        return ret
-
-
-def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
-    permuted_indices = [0] * len(indices)
-    for i, permuted_index in enumerate(permute):
-        permuted_indices[i] = indices[permuted_index]
-    return permuted_indices
-
-
 @dataclass
 class EmbeddingCollectionContext(Multistreamable):
     sharding_contexts: List[SequenceShardingContext] = field(default_factory=list)
@@ -330,7 +277,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
                 else self._ctx.reverse_indices[i]
             )
             jt_dict.update(
-                _construct_jagged_tensors(
+                construct_jagged_tensors(
                     embeddings=w.wait(),
                     features=f,
                     embedding_names=e,
8 changes: 3 additions & 5 deletions torchrec/distributed/mc_modules.py

@@ -16,10 +16,7 @@
 from torch import nn
 from torch.distributed._shard.sharded_tensor import Shard
 from torchrec.distributed.comm import get_local_rank
-from torchrec.distributed.embedding import (
-    _construct_jagged_tensors,
-    EmbeddingCollectionContext,
-)
+from torchrec.distributed.embedding import EmbeddingCollectionContext
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
     EmbeddingShardingContext,
@@ -56,6 +53,7 @@
     apply_mc_method_to_jt_dict,
     ManagedCollisionCollection,
 )
+from torchrec.modules.utils import construct_jagged_tensors
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


@@ -83,7 +81,7 @@ def _wait_impl(self) -> KeyedJaggedTensor:
             self._embedding_names_per_sharding,
         ):
             jt_dict.update(
-                _construct_jagged_tensors(
+                construct_jagged_tensors(
                     embeddings=w.wait(),
                     features=f,
                     embedding_names=e,
60 changes: 59 additions & 1 deletion torchrec/modules/utils.py

@@ -6,9 +6,12 @@
 # LICENSE file in the root directory of this source tree.

 import copy
-from typing import Callable, Iterable, Tuple, Union
+from collections import defaultdict
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
+from torch.profiler import record_function
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


 def extract_module_or_tensor_callable(
@@ -112,3 +115,58 @@ def convert_list_of_modules_to_modulelist(
     return torch.nn.ModuleList(
         convert_list_of_modules_to_modulelist(m, sizes[1:]) for m in modules
     )
+
+
+def construct_jagged_tensors(
+    embeddings: torch.Tensor,
+    features: KeyedJaggedTensor,
+    embedding_names: List[str],
+    need_indices: bool = False,
+    features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
+    original_features: Optional[KeyedJaggedTensor] = None,
+    reverse_indices: Optional[torch.Tensor] = None,
+) -> Dict[str, JaggedTensor]:
+    with record_function("## construct_jagged_tensors ##"):
+        if original_features is not None:
+            features = original_features
+        if reverse_indices is not None:
+            embeddings = torch.index_select(
+                embeddings, 0, reverse_indices.to(torch.int32)
+            )
+
+        ret: Dict[str, JaggedTensor] = {}
+        stride = features.stride()
+        length_per_key = features.length_per_key()
+        values = features.values()
+
+        lengths = features.lengths().view(-1, stride)
+        lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
+        embeddings_list = torch.split(embeddings, length_per_key, dim=0)
+        values_list = torch.split(values, length_per_key) if need_indices else None
+
+        key_indices = defaultdict(list)
+        for i, key in enumerate(embedding_names):
+            key_indices[key].append(i)
+        for key, indices in key_indices.items():
+            # combines outputs in correct order for CW sharding
+            indices = (
+                _permute_indices(indices, features_to_permute_indices[key])
+                if features_to_permute_indices and key in features_to_permute_indices
+                else indices
+            )
+            ret[key] = JaggedTensor(
+                lengths=lengths_tuple[indices[0]],
+                values=embeddings_list[indices[0]]
+                if len(indices) == 1
+                else torch.cat([embeddings_list[i] for i in indices], dim=1),
+                # pyre-ignore
+                weights=values_list[indices[0]] if need_indices else None,
+            )
+        return ret
+
+
+def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
+    permuted_indices = [0] * len(indices)
+    for i, permuted_index in enumerate(permute):
+        permuted_indices[i] = indices[permuted_index]
+    return permuted_indices
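
For reference, a minimal usage sketch of the relocated helper, not part of this commit: the feature names, id values, and embedding dimension below are invented for illustration. It builds a toy KeyedJaggedTensor, pairs it with one embedding row per looked-up id, and regroups the rows into per-feature JaggedTensors via construct_jagged_tensors.

# Illustrative sketch only; assumes feature names "f1"/"f2" and toy ids.
import torch

from torchrec.modules.utils import construct_jagged_tensors
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Batch of 2 samples: "f1" has 2 ids then 1 id, "f2" has 1 id per sample.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 10, 11]),
    lengths=torch.tensor([2, 1, 1, 1]),
)

# One embedding row per looked-up id (5 ids total), embedding dim 4.
embeddings = torch.randn(5, 4)

jt_dict = construct_jagged_tensors(
    embeddings=embeddings,
    features=features,
    embedding_names=["f1", "f2"],
)

print(jt_dict["f1"].lengths())       # tensor([2, 1])
print(jt_dict["f1"].values().shape)  # torch.Size([3, 4])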

(Diffs for the remaining two changed files were not loaded.)