From 369526d017194434a129ec64457aba19c59b2fae Mon Sep 17 00:00:00 2001
From: Ning Wang
Date: Wed, 4 Oct 2023 20:39:02 -0700
Subject: [PATCH] improve comment of VariableBatchPooledEmbeddingsAllToAll doc
 (#1423)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1423

att

Reviewed By: henrylhtsang

Differential Revision: D49934622

fbshipit-source-id: e7360121ae38ed8dfdb57544669e27a51eede159
---
 torchrec/distributed/dist_data.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 5170882b8..81ecf3a2e 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -683,6 +683,7 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
 
     Example::
 
+        kjt_split = [1, 2]
         emb_dim_per_rank_per_feature = [[2], [3, 3]]
         a2a = VariableBatchPooledEmbeddingsAllToAll(
             pg, emb_dim_per_rank_per_feature, device
@@ -690,7 +691,12 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
 
         t0 = torch.rand(6) # 2 * (2 + 1)
         t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2)
-        r0_batch_size_per_rank_per_feature = [[2, 1]]
+        #        r0_batch_size   r1_batch_size
+        #  f_0:        2               1
+        # -----------------------------------------
+        #  f_1:        1               2
+        #  f_2:        3               2
+        r0_batch_size_per_rank_per_feature = [[2], [1]]
         r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
         r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
         r1_batch_size_per_feature_pre_a2a = [1, 2, 2]
@@ -703,7 +709,7 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
         ).wait()
 
         # input splits:
-        #     r0: [2*2, 1*1]
+        #     r0: [2*2, 1*2]
         #     r1: [1*3 + 3*3, 2*3 + 2*3]
 
         # output splits:
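
The corrected batch sizes and splits in the docstring example can be sanity-checked numerically. Below is a minimal, self-contained Python sketch (not part of the patch): the helper name `input_splits_per_rank` is hypothetical, not a torchrec API, and only reproduces how the input splits `[2*2, 1*2]` and `[1*3 + 3*3, 2*3 + 2*3]` arise from the per-rank batch sizes and per-feature embedding dims in the example.

    from typing import List

    def input_splits_per_rank(
        batch_size_per_rank_per_feature: List[List[int]],
        emb_dim_per_feature: List[int],
    ) -> List[int]:
        # Hypothetical helper, for illustration only: for each destination
        # rank, the split is the sum over this rank's local features of
        # (batch size sent to that rank) * (embedding dim of the feature).
        return [
            sum(b * d for b, d in zip(batches_per_feature, emb_dim_per_feature))
            for batches_per_feature in batch_size_per_rank_per_feature
        ]

    # Rank 0 hosts f_0 (dim 2) and sends batches 2 and 1 to ranks 0 and 1:
    assert input_splits_per_rank([[2], [1]], [2]) == [4, 2]
    # i.e. [2*2, 1*2], summing to 6 = numel(t0)

    # Rank 1 hosts f_1 and f_2 (dim 3 each), with batches [[1, 3], [2, 2]]:
    assert input_splits_per_rank([[1, 3], [2, 2]], [3, 3]) == [12, 12]
    # i.e. [1*3 + 3*3, 2*3 + 2*3], summing to 24 = numel(t1)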