
Unshard tensor sizes before binding. #3444

Merged: 18 commits merged into main from wjy/forward on Nov 30, 2024
Conversation

@wujingyue (Collaborator) commented on Nov 19, 2024:

Fixes #3282

With this PR, we'll still try to bind tensors to logical domains. However, tensor sizes are "unsharded" before binding.
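To illustrate the idea, here is a minimal sketch only, not the PR's actual implementation; the helper name, the per-axis flag, and num_devices below are assumptions made for the example. "Unsharding" amounts to multiplying the extent of every DIDx-sharded logical axis by the mesh size before that extent is bound:

#include <cstdint>
#include <vector>

// Hypothetical sketch: recover the unsharded ("global") sizes from the
// per-device sizes before binding them to the logical domain.
// is_didx_sharded[i] is assumed to say whether logical axis i is sharded
// across DIDx, and num_devices is the mesh size D.
std::vector<int64_t> unshardSizesSketch(
    const std::vector<int64_t>& local_sizes,
    const std::vector<bool>& is_didx_sharded,
    int64_t num_devices) {
  std::vector<int64_t> unsharded = local_sizes;
  for (size_t i = 0; i < unsharded.size(); ++i) {
    if (is_didx_sharded[i]) {
      unsharded[i] *= num_devices; // undo the DIDx split on this axis
    }
  }
  return unsharded;
}

The PR itself derives which axes are sharded from the split between the logical and allocation domains, as discussed later in this thread; the per-axis flag only keeps the sketch short.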

@wujingyue force-pushed the wjy/forward branch 2 times, most recently from 2111355 to 176ed8e on November 19, 2024 06:47
csrc/expr_evaluator.cpp: outdated review thread (resolved)
@wujingyue marked this pull request as ready for review on November 25, 2024 23:26
@wujingyue (Collaborator, Author) commented:

!test

TensorView* b = makeSymbolicTensor(3);
b->split(1, 4);
b->axis(1)->parallelize(ParallelType::DIDx);
EXPECT_TRUE(isSharded(b)) << "DIDx on loop domain";
@wujingyue (Collaborator, Author) commented on the diff above:

Due to the change in isSharded, this test now looks at the allocation domain. I also split the test into three and expanded the error messages.
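As a rough sketch of what one of the reworked checks could look like after the change (assumed, not necessarily the exact test that landed): since isSharded now inspects the allocation domain, the tensor's allocation domain has to contain the DIDx axis before the assertion.

TensorView* b = makeSymbolicTensor(3);
b->split(1, 4);
b->axis(1)->parallelize(ParallelType::DIDx);
// isSharded now looks at the allocation domain, so expose the DIDx axis there.
b->setAllocationDomain(b->getLoopDomain(), true);
EXPECT_TRUE(isSharded(b)) << "expected b to be sharded: DIDx on allocation domain";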

csrc/fusion_segmenter.cpp: outdated review thread (resolved)
csrc/transform_replay.cpp: outdated review thread (resolved)
@naoyam (Collaborator) commented on Nov 26, 2024:

Can you also run !test --diff just in case?

wujingyue added a commit that referenced this pull request Nov 26, 2024
…transforms (#3458)

This is a spin-off from #3444.

The current code assumes that logical-to-allocation has to be a permutation. That assumption won't hold anymore with #2563, so this PR extends eraseInputDistinctRootDomains to support more general transforms.

This can also happen on single-GPU, although it's less common. The tests added in this PR are single-GPU because #3444 hasn't landed; #3444 will add some multi-GPU tests.
Base automatically changed from wjy/root to main November 26, 2024 03:49
@wujingyue (Collaborator, Author) commented:

!test --diff

@wujingyue (Collaborator, Author) commented:

> Can you also run !test --diff just in case?

All passing.

@wujingyue requested a review from naoyam on November 26, 2024 07:24
@wujingyue (Collaborator, Author) commented:

!test

@wujingyue (Collaborator, Author) commented:

!test

@wujingyue (Collaborator, Author) commented:

!test

@samnordmann (Collaborator) left a review:

LGTM! I only left some minor comments.

I agree with the approach taken and how it's implemented; however, let me nit-pick the wording: I would say that the implemented approach binds the at::Tensor to the allocation domain, not the logical domain as written in the comments. Indeed, the shape of the tensor matches the shape of the allocation domain, even though we still need to revert some divisions to "solve the equation" and obtain the size to which the symbolic extent must be bound. I guess it's OK if we don't use the same wording; I just wanted to express how I interpret what is implemented here.

tests/cpp/test_sharding.cpp: review threads (resolved)
tests/cpp/test_multidevice_sharding.cpp: review thread (resolved)
for (auto* tv : {in, out}) {
  tv->split(0, num_devices, /*inner_split=*/false);
  tv->axis(0)->parallelize(ParallelType::DIDx);
  tv->setAllocationDomain(tv->getLoopDomain(), true);
}
A reviewer (Collaborator) commented on the diff above:
Are there some cases where we want Allocation and Loop domains to be different?

@wujingyue (Collaborator, Author) replied:

Yes, and it's a current limitation that they have to be the same in certain cases; see #3479.

The reviewer replied:

In the linked issue I read:

> The loop domain, ideally, shouldn't be set or used because a fusion/segment input comes from outside and is not generated by a loop.

That makes sense. But I am still curious whether we know of cases where we want loop and allocation to be different.

@wujingyue (Collaborator, Author) replied:

I can't think of a case for DID, but there are certainly cases for the host-loop parallel type, as we discussed before: each iteration reads/writes a slice of the fully allocated input/output tensor.

tests/cpp/test_multidevice_sharding.cpp: review thread (resolved)
csrc/multidevice/utils.h: outdated review thread (resolved)
// For example, when `tv` is
// logical: iM, iN
// allocation: iDIDx{D}, iN/D, iM
// and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because,
A reviewer (Collaborator) commented on the diff above:

I think there is a mistake here: if we bind {2, 3} to {N/D, M}, then M = 3 and N = 2D, so according to the comment it should return the shape corresponding to the logical domain, i.e., [3, 2D]. Am I missing something?

The reviewer added:

Moreover, do we support transposition?

@wujingyue (Collaborator, Author) replied:

The comment is correct and is consistent with the code.

ExpressionEvaluator::bindTensorDomain basically does the following:

unsharded_sizes = unshardedSizes(t.sizes());
for (i : range(t.dim())) {
  bind(logical_domain[i], unsharded_sizes[i]);
}

That's also why I prefer to say we bind the unsharded sizes to the logical domain instead of allocation.
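To tie this back to the code comment discussed above, here is a worked version of that example (values taken from the comment; the bind calls only spell out the loop shown above):

// logical:    iM, iN
// allocation: iDIDx{D}, iN/D, iM
//
// t.sizes() is given in logical order, so sizes == [2, 3] means M == 2 and the
// local slice of N has extent 3. iN was split by D and the outer output is
// DIDx-parallelized, so unshardedSizes() returns [2, 3 * D], and the loop binds
//
//   bind(iM, 2);      // axis 0 was never sharded
//   bind(iN, 3 * D);  // axis 1 gets its full, unsharded extent back
//
// Binding [2, 3] positionally to the allocation order (iN/D, iM), which would
// give [3, 2D], is not what the code does: the sizes stay in logical order.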

The reviewer replied:

> The comment is correct and is consistent with the code.

I see. IMO it is error-prone to silently discard transposition. We should either assert that only splits have been applied, or support transposition, which shouldn't be too hard...

> ExpressionEvaluator::bindTensorDomain basically does the following:
>
>     unsharded_sizes = unshardedSizes(t.sizes());
>     for (i : range(t.dim())) {
>       bind(logical_domain[i], unsharded_sizes[i]);
>     }
>
> That's also why I prefer to say we bind the unsharded sizes to the logical domain instead of allocation.

I would say in this case that we bind to neither the logical nor the allocation domain, but to some hybrid domain obtained from the logical domain by applying only the splits. This is a bit counter-intuitive to me.

In your snippet above, everything is contained in unsharded_sizes, which basically embeds a mapping from the allocation domain (or, more precisely, the hybrid domain I described earlier) to the logical domain.

@wujingyue (Collaborator, Author) replied:

> it is error prone to silently discard transposition.

I believe the code as-is supports transposition. (I assume by transposition you mean TensorView::reorder.) To assure you of that, I added a test in the latest commit.
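For reference, a sketch of what such a reorder test might set up (the exact reorder call and shapes are assumed here; the test actually added in the commit may differ):

TensorView* in = makeSymbolicTensor(2);
in->split(0, num_devices, /*inner_split=*/false);
in->axis(0)->parallelize(ParallelType::DIDx);
// Swap the two non-DID axes so the allocation order differs from the logical order.
in->reorder({{1, 2}, {2, 1}});
in->setAllocationDomain(in->getLoopDomain(), true);

Unsharding the per-device sizes should still recover the full logical extents even with the reorder in place.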

csrc/multidevice/utils.h: review threads (one outdated; resolved)
@wujingyue (Collaborator, Author) commented:

!test


int64_t inner_size;
int64_t outer_size;
if (split->innerSplit()) {
@wujingyue (Collaborator, Author) commented on the diff above:

@jjsjann123 this check was missing. I believe there's a similar problem in BackwardTraverseFromLogicalToAlloc, which I didn't fix in this PR.
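For readers outside the thread, a self-contained illustration of the inner/outer distinction being checked here (illustrative names and values only; this is not the PR's code):

#include <cassert>
#include <cstdint>

// A split of extent d by factor f produces (outer, inner):
//   inner split: split(d, f)                  -> outer = d / f, inner = f
//   outer split: split(d, f, /*inner*/ false) -> outer = f,     inner = d / f
// Assigning the factor to the wrong output would attribute the device-mesh
// extent to the wrong axis when unsharding, which is the kind of check the
// comment above says was missing in the forward traversal.
void splitOutputExtents(
    bool inner_split,
    int64_t in_extent,
    int64_t factor,
    int64_t& outer_size,
    int64_t& inner_size) {
  if (inner_split) {
    inner_size = factor;
    outer_size = in_extent / factor; // assume divisibility for the sketch
  } else {
    outer_size = factor;
    inner_size = in_extent / factor;
  }
}

int main() {
  int64_t outer = 0;
  int64_t inner = 0;
  // N = 12, outer split by D = 4: the DIDx (outer) side has extent 4 and each
  // device holds 12 / 4 = 3 elements.
  splitOutputExtents(/*inner_split=*/false, /*in_extent=*/12, /*factor=*/4, outer, inner);
  assert(outer == 4 && inner == 3);
  return 0;
}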

A collaborator replied:
I think we need to have a test exposing that first before fixing anything.

@wujingyue (Collaborator, Author) replied:
Yes, that's why I didn't fix backward in this PR. The problem in forward was indeed exposed by LoopSplitWithReorder.

The collaborator replied:
Yeah no worries. We'll patch it when it shows up.

@wujingyue (Collaborator, Author) commented:

!test

@wujingyue (Collaborator, Author) commented:

!test

@wujingyue merged commit c154e90 into main on Nov 30, 2024
39 of 47 checks passed
@wujingyue deleted the wjy/forward branch on November 30, 2024 05:28
Successfully merging this pull request may close these issues:

Bind sharded input/output tensors with DID-parallelized allocation domains.