diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 1c02c54b1..311336594 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -189,6 +189,8 @@ def _construct_jagged_tensors_cw( for i in range(len(features)): embedding = embeddings[i] feature = features[i] + # pyre-fixme[6]: For 1st argument expected `List[Tensor]` but got + # `Tuple[Tensor, ...]`. lengths_lists.append(torch.unbind(feature.lengths().view(-1, stride), dim=0)) embeddings_lists.append( list(torch.split(embedding, feature.length_per_key(), dim=0))