From befa94b8591706256a63a8a89f6e3a979965c65f Mon Sep 17 00:00:00 2001
From: Qiang Zhang
Date: Mon, 11 Dec 2023 16:27:56 -0800
Subject: [PATCH] Add comments to guide usage of batch_size_per_feature and block_bucketize_pos (#1575)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1575

Currently, block_bucketize_sparse_features requires both batch_size_per_feature and the tensors in block_bucketize_pos to have the same dtype as the KJT lengths tensor. Add comments to guard against this misuse.

Reviewed By: joshuadeng

Differential Revision: D51954185

fbshipit-source-id: 643b797ee196782800dbd384fd970e63b366c475
---
 torchrec/distributed/embedding_sharding.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 0109d368d..388526226 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -59,7 +59,9 @@ def _fx_wrap_tensor_to_device_dtype(
 @torch.fx.wrap
 def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]:
     return (
-        torch.tensor(kjt.stride_per_key(), device=kjt.device())
+        torch.tensor(
+            kjt.stride_per_key(), device=kjt.device(), dtype=kjt.lengths().dtype
+        )
         if kjt.variable_stride_per_key()
         else None
     )
@@ -136,7 +138,7 @@ def bucketize_kjt_before_all2all(
         weights=kjt.weights_or_none(),
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
-        block_bucketize_pos=block_bucketize_row_pos,
+        block_bucketize_pos=block_bucketize_row_pos,  # each tensor should have the same dtype as kjt.lengths()
     )
 
     return (
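
Usage note (not part of the patch): a minimal sketch of what the dtype constraint above means for callers that build block_bucketize_row_pos themselves. It assumes an existing KeyedJaggedTensor kjt; the helper name make_block_bucketize_row_pos and the per-feature boundary list row_pos_per_feature are hypothetical, introduced only for illustration.

    from typing import List

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


    def make_block_bucketize_row_pos(
        kjt: KeyedJaggedTensor,
        row_pos_per_feature: List[List[int]],
    ) -> List[torch.Tensor]:
        # Build one boundary tensor per feature, forcing the dtype (and device)
        # to match kjt.lengths() so that block_bucketize_sparse_features does
        # not receive mixed integer dtypes.
        lengths = kjt.lengths()
        return [
            torch.tensor(pos, dtype=lengths.dtype, device=lengths.device)
            for pos in row_pos_per_feature
        ]

The same rule applies to batch_size_per_feature, which the patched _fx_wrap_batch_size_per_feature now satisfies by passing dtype=kjt.lengths().dtype explicitly.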