Add comments to guide usage of batch_size_per_feature and block_bucketize_pos (pytorch#1575)

Summary:
Pull Request resolved: pytorch#1575

Currently block_bucketize_sparse_features requires that both batch_size_per_feature and the tensors in block_bucketize_pos have the same dtype as the KJT lengths tensor. Add comments to guard against misuse.

Reviewed By: joshuadeng

Differential Revision: D51954185

fbshipit-source-id: 643b797ee196782800dbd384fd970e63b366c475
gnahzg authored and facebook-github-bot committed Dec 12, 2023
1 parent 37f8243 commit befa94b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchrec/distributed/embedding_sharding.py
@@ -59,7 +59,9 @@ def _fx_wrap_tensor_to_device_dtype(
 @torch.fx.wrap
 def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]:
     return (
-        torch.tensor(kjt.stride_per_key(), device=kjt.device())
+        torch.tensor(
+            kjt.stride_per_key(), device=kjt.device(), dtype=kjt.lengths().dtype
+        )
         if kjt.variable_stride_per_key()
         else None
     )
@@ -136,7 +138,7 @@ def bucketize_kjt_before_all2all(
         weights=kjt.weights_or_none(),
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
-        block_bucketize_pos=block_bucketize_row_pos,
+        block_bucketize_pos=block_bucketize_row_pos,  # each tensor should have the same dtype as kjt.lengths()
     )
 
     return (
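For illustration, a minimal sketch of the dtype alignment the new comments ask for: each block_bucketize_pos tensor built with the same dtype as kjt.lengths(). The feature keys, values, lengths, and bucket boundaries below are made-up illustration values, not from this commit.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Hypothetical two-feature KJT; values and lengths are illustration data.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64),
    lengths=torch.tensor([2, 1, 1, 1], dtype=torch.int32),
)

# One row-position tensor per feature; constructing them with
# dtype=kjt.lengths().dtype keeps them consistent with the lengths tensor,
# which is what the comment on block_bucketize_pos asks callers to do.
block_bucketize_row_pos = [
    torch.tensor([0, 2, 5], dtype=kjt.lengths().dtype),
    torch.tensor([0, 3, 5], dtype=kjt.lengths().dtype),
]

assert all(t.dtype == kjt.lengths().dtype for t in block_bucketize_row_pos)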
