Allgather with DID loop split #3284

ok, however if the sharded dimension is split, then `scattered_axis` is not valid here, right?

I can't think of an immediate problem and #3504 apparently works fine. Could be incidental and I'm happy to hear what you think is problematic.

Correct me if I'm wrong, but I think this is an example where we see the problem (a hypothetical sketch follows below): in this case, the scattered axis is 2 but `getShardedAxis` returns 1.

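A hypothetical sketch of such a case (tensor names, extents, and the mesh size are assumptions, not the original inline snippet):

```cpp
// Hypothetical sketch: after a DID loop split, an axis's position in the
// loop domain no longer matches its logical position, so an index computed
// against one domain is not valid in the other.
const int64_t d = 4;                      // assumed number of devices
TensorView* tv = makeContigTensor(2);     // logical: [i0, i1]
tv->setDeviceMesh(DeviceMesh::createForNumDevices(d));
tv->split(1, d, /*inner_split=*/false);   // loop: [i0, d, i1/d]
tv->axis(1)->parallelize(ParallelType::DIDx);
// Loop domain: [i0, DIDx{d}, i1/d]. The scattered data axis now sits at
// loop position 2, while the sharded logical axis is 1.
```
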
In your case, `getShardedLogicalAxis` will return 0, the tensor axis being sharded. This is correct because the output at::Tensor for `tv1` will be of shape [i1/d], and indeed axis 0 is the sharded dimension. Then, `scattered_axis=0` will be used to compute which input tensor axis will be sharded (which will be 1). Finally, that input scattered axis (1) will be used to split the input tensor of shape [1, i1].

Caveat: with 7cf2384, DID'ing an inner split is disallowed by code, so the above case will actually throw an exception. But what I said should hold once we lift that limitation.

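A small ATen illustration of that shape walk, with d = 2 and i1 = 8 as assumed values:

```cpp
// Assumed extents for illustration only: d = 2 devices, i1 = 8.
at::Tensor in = at::zeros({1, 8});                // unsharded input: [1, i1]
auto shards = in.chunk(/*chunks=*/2, /*dim=*/1);  // scatter along input axis 1
// Each shards[r] has shape [1, 4], i.e. [1, i1/d]. Squeezing the size-1
// axis gives the per-device output of shape [i1/d], whose sharded axis is
// 0 -- consistent with getShardedLogicalAxis returning 0 for tv1.
at::Tensor out_shard = shards[0].squeeze(0);      // shape [4] == [i1/d]
```
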
I am not sure I understand why this check is needed. Isn't it true that, by assumption, whatever `getInputsTo` returns is an element of `tv->getLogicalDomain()`?

Also, I am not sure what is meant by "dominate" in the error message.

Re dominate: https://en.wikipedia.org/wiki/Dominator_(graph_theory); I extended the concept to a set of nodes dominating another set: a set A dominates a set B if every path from the entry to any node in B passes through some node in A.

Re the check: I heard from @naoyam that logical won't always dominate allocation with "the new indexing system".

I don't understand how `shardTensor` can be correct here if it never replays the split backwards... but I might be missing something.

Thanks for the review! I think there are two problems with the PR as is: `shardTensor` may slice the wrong numbers. For example, if an inner split is DID'ed, the slicing needs to be strided per the outer split (Fuser/csrc/multidevice/utils.cpp, line 77 in 67127c9); see the sketch below.

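A minimal ATen sketch of the striding (extent, rank, and split factor are assumed):

```cpp
// Assumed for illustration: extent 8, d = 2 devices, this rank r = 1.
// split(0, d, /*inner_split=*/true) factors extent 8 as
// index = outer * d + inner. If the inner IterDomain is DID'ed, rank r
// owns every element with index % d == r -- a strided slice, not a
// contiguous chunk.
at::Tensor t = at::arange(8);             // [0, 1, ..., 7]
const int64_t d = 2, r = 1;
at::Tensor inner_shard = t.slice(/*dim=*/0, /*start=*/r, /*end=*/8, /*step=*/d);
// inner_shard == [1, 3, 5, 7]; an outer DID split would instead own the
// contiguous t.slice(0, r * 4, (r + 1) * 4) == [4, 5, 6, 7].
```
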
Re the suggested change: I manually checked that the shape is as expected. I added some extra unit tests for `shardTensor` alone, so we don't have to verify it here.

I made a couple of changes to address the problems I mentioned in #3284 (comment).

In fact, there's Fuser/tests/cpp/test_multidevice_overlap.cpp, line 681 in 64bc560.

Why not use `validate` here?

I noticed allgather's lowering was not changed... I'm a bit surprised it didn't need any modifications for inputs with DID loop split! I might have missed a few earlier PRs, though.

Since `validate` allows for (small) differences, if two tensors are supposed to be exactly the same, the simpler validation method, i.e. `at::equal`, would be preferable.

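For instance, a minimal GTest/ATen contrast (the tensor names are placeholders):

```cpp
// For a pure data-movement op like allgather, the result should be
// bit-exact, so the strict check applies:
EXPECT_TRUE(at::equal(output, expected));        // exact, element-wise
// as opposed to a tolerance-based check such as:
// EXPECT_TRUE(at::allclose(output, expected));
```
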
Whether we call `lowerToAllGather` depends on the I/O meshes and on whether the I/O is sharded (Fuser/csrc/multidevice/lower_communication.cpp, line 285 in 67127c9); a rough sketch of the condition follows below.

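A hedged paraphrase of that dispatch (simplified; not the actual code in lower_communication.cpp):

```cpp
// Simplified paraphrase, not the real lower_communication.cpp logic:
// a set() lowers to an allgather roughly when input and output live on
// the same mesh, the input is DID-sharded, and the output is not.
bool lowersToAllgather(TensorView* in, TensorView* out) {
  return in->getDeviceMesh() == out->getDeviceMesh() &&
         isSharded(in) && !isSharded(out);
}
```
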
That being said, I think this PR as is is a bit too permissive and may lower a `set` to `Allgather` without properly checking its allocation domain. For example: Fuser/csrc/multidevice/utils.cpp, line 77 in 67127c9.

I tried to address this in #3284 (comment).