Permute export fix (pytorch#1985)
Summary:
Pull Request resolved: pytorch#1985

Make KJT permute work in edge cases for torch.export: https://fb.workplace.com/groups/6829516587176185/posts/7206415192819654/?comment_id=7207715762689597&reply_comment_id=7210658785728628.

This is also a cleaner solution than the previous one, and is necessary for enabling PT2 eager model processing on AIMP.

Reviewed By: IvanKobzarev

Differential Revision: D57162331

fbshipit-source-id: 01bf47328d1e31a7544d0aefab40f66ce7435c65
PaulZhang12 authored and facebook-github-bot committed May 13, 2024
1 parent b32dc7a commit b97efd5
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions torchrec/sparse/jagged_tensor.py
@@ -1952,8 +1952,13 @@ def permute(
                 self.stride_per_key_per_rank()[index]
             )
             permuted_length_per_key.append(length_per_key[index])
-            if not is_non_strict_exporting():
-                permuted_length_per_key_sum += length_per_key[index]
+
+        permuted_length_per_key_sum = sum(permuted_length_per_key)
+        if not torch.jit.is_scripting() and is_non_strict_exporting():
+            torch._check_is_size(permuted_length_per_key_sum)
+            torch._check(permuted_length_per_key_sum != -1)
+            torch._check(permuted_length_per_key_sum != 0)
+
         if self.variable_stride_per_key():
             length_per_key_tensor = _pin_and_move(
                 torch.tensor(self.length_per_key()), self.device()
@@ -1974,18 +1979,6 @@ def permute(
                 self.weights_or_none(),
             )
         else:
-            if not torch.jit.is_scripting() and is_non_strict_exporting():
-                permuted_length_per_key_sum = torch.sum(
-                    torch._refs.tensor(
-                        permuted_length_per_key,
-                        dtype=torch.int32,
-                        device=torch.device("cpu"),
-                        pin_memory=False,
-                        requires_grad=False,
-                    )
-                ).item()
-
-                torch._check(permuted_length_per_key_sum > 0)
-
             (
                 permuted_lengths,
