
Support two outputs in split_reduce #3097

Merged: 49 commits into develop from split-reduce2 on Jul 31, 2024

Conversation

pfultz2
Collaborator

@pfultz2 pfultz2 commented May 16, 2024

No description provided.

@pfultz2 pfultz2 requested a review from causten as a code owner May 16, 2024 18:26
@TedThemistokleous TedThemistokleous added enhancement New feature or request roadmap Tasks to finish for a release labels May 16, 2024
test/split_reduce.cpp, comment on lines +299 to +304:
return {rsum2, rsum1};
});
auto rsum2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsum);
auto rsum1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsum);
auto mul =
add_pointwise(p2, mm, "main:pointwise1", {rsum1, rsum2}, single_pointwise("mul"));
Member

rsum2 is at tuple element index 0 and rsum1 is at tuple element index 1, which can be error-prone.

test/split_reduce.cpp: two further review threads (outdated, resolved)
@umangyadav
Member

This PR fails to compile if I run:
MIGRAPHX_DISABLE_LAYERNORM_FUSION=1 MIGRAPHX_ENABLE_SPLIT_REDUCE=1 ./bin/test_verify "test_layernorm_large"

@umangyadav
Member

This currently will not fuse a pointwise op preceding the reduce ops if the pointwise result is still live after the reduce ops.

This is what happens with #3212

@6 = convolution[padding={1, 1, 1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](sample,@0) -> float_type, {2, 320, 64, 64}, {1310720, 4096, 64, 1}
@7 = broadcast[axis=1,out_lens={2, 320, 64, 64}](@1) -> float_type, {2, 320, 64, 64}, {0, 1, 0, 0}
@8 = contiguous(@7) -> float_type, {2, 320, 64, 64}, {1310720, 4096, 64, 1}
@9 = reshape[dims={2, 32, 10, 64, 64}](@6) -> float_type, {2, 32, 10, 64, 64}, {1310720, 40960, 4096, 64, 1}
@10 = reshape[dims={2, 32, 10, 64, 64}](@8) -> float_type, {2, 32, 10, 64, 64}, {1310720, 40960, 4096, 64, 1}
@11 = multibroadcast[out_lens={2, 320, 64, 64},out_dyn_dims={}](@3) -> float_type, {2, 320, 64, 64}, {0, 1, 0, 0}
@12 = contiguous(@11) -> float_type, {2, 320, 64, 64}, {1310720, 4096, 64, 1}
@13 = multibroadcast[out_lens={2, 320, 64, 64},out_dyn_dims={}](@4) -> float_type, {2, 320, 64, 64}, {0, 1, 0, 0}
@14 = contiguous(@13) -> float_type, {2, 320, 64, 64}, {1310720, 4096, 64, 1}
@15 = reshape[dims={2, 32, 10, 64, 64}](@12) -> float_type, {2, 32, 10, 64, 64}, {1310720, 40960, 4096, 64, 1}
@16 = reshape[dims={2, 32, 10, 64, 64}](@14) -> float_type, {2, 32, 10, 64, 64}, {1310720, 40960, 4096, 64, 1}
@17 = pointwise(@9,@10), [main:pointwise0] -> float_type, {2, 32, 10, 64, 64}, {1310720, 40960, 4096, 64, 1}
@18 = split_fused_reduce[axes={2, 3, 4},assign=assign_add](@17), [main:pointwise0:main:pointwise2:main:reduce_sum1:main:pointwise4:main:pointwise6:main:pointwise1:main:reduce_sum0_reshape_reshape:main:pointwise10_split] -> [float_type, {2, 32, 1, 1, 1}, {32, 1, 1, 1, 1}, float_type, {2, 32, 1, 1, 1}, {32, 1, 1, 1, 1}]
@19 = get_tuple_elem[index=0](@18) -> float_type, {2, 32, 1, 1, 1}, {32, 1, 1, 1, 1}
@20 = get_tuple_elem[index=1](@18) -> float_type, {2, 32, 1, 1, 1}, {32, 1, 1, 1, 1}
@21 = multibroadcast[out_lens={2, 32, 10, 64, 64},out_dyn_dims={}](@19) -> float_type, {2, 32, 10, 64, 64}, {32, 1, 0, 0, 0}
@22 = multibroadcast[out_lens={2, 32, 10, 64, 64},out_dyn_dims={}](@20) -> float_type, {2, 32, 10, 64, 64}, {32, 1, 0, 0, 0}
@23 = pointwise(@21,@22,@17,@21,@15,@16), [main:pointwise4] -> float_type, {2, 32, 10, 64, 64}, {1310720, 40960, 4096, 64, 1}
@24 = reshape[dims={2, 320, 64, 64}](@23) -> float_type, {2, 320, 64, 64}, {1310720, 4096, 64, 1}
@25 = convolution[padding={1, 1, 1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](@24,@2) -> float_type, {2, 320, 64, 64}, {1310720, 4096, 64, 1}

@17 is not fused with @18 because @17 is used later at @23.

@pfultz2
Collaborator Author

pfultz2 commented Jul 24, 2024

@17 is not fused with @18 because @17 is used later at @23.

Yeah, we need to support multi-output fusion in the future. We can probably just do the fusion initially with MLIR.

@pfultz2
Collaborator Author

pfultz2 commented Jul 24, 2024

This PR fails to compile if i run MIGRAPHX_DISABLE_LAYERNORM_FUSION=1 MIGRAPHX_ENABLE_SPLIT_REDUCE=1 ./bin/test_verify "test_layernorm_large"

This is fixed.

@causten causten merged commit 403ee86 into develop Jul 31, 2024
34 of 40 checks passed
@causten causten deleted the split-reduce2 branch July 31, 2024 13:15
5 participants