diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 0109d368d..388526226 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -59,7 +59,9 @@ def _fx_wrap_tensor_to_device_dtype(
 @torch.fx.wrap
 def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]:
     return (
-        torch.tensor(kjt.stride_per_key(), device=kjt.device())
+        torch.tensor(
+            kjt.stride_per_key(), device=kjt.device(), dtype=kjt.lengths().dtype
+        )
         if kjt.variable_stride_per_key()
         else None
     )
@@ -136,7 +138,7 @@ def bucketize_kjt_before_all2all(
         weights=kjt.weights_or_none(),
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
-        block_bucketize_pos=block_bucketize_row_pos,
+        block_bucketize_pos=block_bucketize_row_pos,  # each tensor should have the same dtype as kjt.lengths()
     )

     return (
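For context on the new inline comment: below is a minimal sketch (not part of this change) of how a caller might construct `block_bucketize_row_pos` so that each per-feature tensor matches `kjt.lengths()` in dtype and device. The keys, values, lengths, and row-position boundaries are made up for illustration only.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Toy KJT with two features; values and lengths are arbitrary example data.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 5, 7, 2, 9, 11]),
    lengths=torch.tensor([2, 1, 1, 2]),
)

# One row-boundary tensor per feature, created with the same dtype and
# device as kjt.lengths() so the fbgemm bucketization kernel sees
# consistent integer types (the point of the comment in the diff).
block_bucketize_row_pos = [
    torch.tensor([0, 4, 8], dtype=kjt.lengths().dtype, device=kjt.device()),
    torch.tensor([0, 6, 12], dtype=kjt.lengths().dtype, device=kjt.device()),
]
```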