improve comment of VariableBatchPooledEmbeddingsAllToAll doc (pytorch#1423)

Summary:
Pull Request resolved: pytorch#1423

att

Reviewed By: henrylhtsang

Differential Revision: D49934622

fbshipit-source-id: e7360121ae38ed8dfdb57544669e27a51eede159
Ning Wang authored and facebook-github-bot committed Oct 5, 2023
1 parent db79614 commit 369526d
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchrec/distributed/dist_data.py
@@ -683,14 +683,20 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
     Example::
         kjt_split = [1, 2]
         emb_dim_per_rank_per_feature = [[2], [3, 3]]
         a2a = VariableBatchPooledEmbeddingsAllToAll(
             pg, emb_dim_per_rank_per_feature, device
         )
         t0 = torch.rand(6) # 2 * (2 + 1)
         t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2)
-        r0_batch_size_per_rank_per_feature = [[2, 1]]
+        #        r0_batch_size   r1_batch_size
+        #  f_0:        2               1
+        -----------------------------------------
+        #  f_1:        1               2
+        #  f_2:        3               2
+        r0_batch_size_per_rank_per_feature = [[2], [1]]
         r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]
         r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
         r1_batch_size_per_feature_pre_a2a = [1, 2, 2]
@@ -703,7 +709,7 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
         ).wait()
         # input splits:
-        #     r0: [2*2, 1*1]
+        #     r0: [2*2, 1*2]
         #     r1: [1*3 + 3*3, 2*3 + 2*3]
         # output splits:
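
To sanity-check the arithmetic this fix corrects, the small sketch below recomputes the input splits from the example's parameters. It is a minimal illustration, not torchrec's implementation; the helper name input_splits_for_rank is hypothetical and only mirrors the batch-size-times-embedding-dim arithmetic spelled out in the docstring comments.

    # Minimal sketch; input_splits_for_rank is a hypothetical helper, not a torchrec API.
    def input_splits_for_rank(emb_dims, batch_size_per_rank_per_feature):
        # Elements this rank sends to each destination rank: for each destination,
        # sum over the rank's local features of batch_size * emb_dim.
        return [
            sum(b * d for b, d in zip(batches, emb_dims))
            for batches in batch_size_per_rank_per_feature
        ]

    # Rank 0 owns f_0 (dim 2): batch 2 stays on rank 0, batch 1 goes to rank 1.
    assert input_splits_for_rank([2], [[2], [1]]) == [4, 2]             # 2*2, 1*2
    # Rank 1 owns f_1 and f_2 (dim 3 each): batches [1, 3] to rank 0, [2, 2] to rank 1.
    assert input_splits_for_rank([3, 3], [[1, 3], [2, 2]]) == [12, 12]  # 1*3+3*3, 2*3+2*3

    # The splits must sum to the input tensor sizes from the example.
    assert sum(input_splits_for_rank([2], [[2], [1]])) == 6             # t0 = torch.rand(6)
    assert sum(input_splits_for_rank([3, 3], [[1, 3], [2, 2]])) == 24   # t1 = torch.rand(24)

This also makes the corrected comment concrete: splits count tensor elements, so the slice of t0 destined for rank 1 is batch size 1 times embedding dim 2, i.e. 1*2, not 1*1 as the old comment had it.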
