Minimize d2h syncs in calculating length_per_key from `stride_per_key` (pytorch#1485)

Summary:
Pull Request resolved: pytorch#1485

For large numbers of features, we call .item() once per feature, causing a large number of d2h (device-to-host) syncs. This diff combines the list of per-feature sum tensors into a single tensor and calls .tolist() once.
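A minimal sketch of the two approaches side by side, for illustration only. The function names `length_per_key_itemwise` and `length_per_key_batched` and the sample values are hypothetical and not part of torchrec. The example uses CPU tensors so it runs anywhere; the sync cost only shows up when `lengths` lives on a CUDA device.

```python
from typing import List

import torch


def length_per_key_itemwise(
    lengths: torch.Tensor, stride_per_key: List[int]
) -> List[int]:
    # Previous approach: one .item() call per key. On a CUDA tensor each
    # call forces its own device-to-host sync.
    return [
        int(torch.sum(chunk).item()) for chunk in torch.split(lengths, stride_per_key)
    ]


def length_per_key_batched(
    lengths: torch.Tensor, stride_per_key: List[int]
) -> List[int]:
    # Approach in this diff: keep each per-key sum as a 1-element tensor,
    # concatenate them, and copy to the host once with .tolist().
    return torch.cat(
        [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)]
    ).tolist()


# Hypothetical input: 3 keys with strides 2, 1 and 3 (6 length entries total).
lengths = torch.tensor([1, 2, 0, 3, 4, 1])
stride_per_key = [2, 1, 3]

old = length_per_key_itemwise(lengths, stride_per_key)
new = length_per_key_batched(lengths, stride_per_key)
print(old, new)  # both print [3, 0, 8]
```

Both versions return the same list; the batched version replaces one host transfer per key with a single transfer for all keys.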

Reviewed By: bigning

Differential Revision: D51046476

fbshipit-source-id: 26fd38767d1d48dade24057cd2136b15ea29c16c
joshuadeng authored and facebook-github-bot committed Nov 13, 2023
1 parent 9db35bb commit e2cc13a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchrec/sparse/jagged_tensor.py
@@ -631,9 +631,9 @@ def _maybe_compute_stride_kjt_scripted(
 def _length_per_key_from_stride_per_key(
     lengths: torch.Tensor, stride_per_key: List[int]
 ) -> List[int]:
-    return [
-        int(torch.sum(chunk).item()) for chunk in torch.split(lengths, stride_per_key)
-    ]
+    return torch.cat(
+        [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)]
+    ).tolist()


 def _maybe_compute_length_per_key(