Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

match gemm_softmax_gemm when there is no scale #2748

Merged
merged 13 commits into from
Feb 16, 2024
Merged

match gemm_softmax_gemm when there is no scale #2748

merged 13 commits into from
Feb 16, 2024

Conversation

shivadbhavsar
Copy link
Contributor

Make the scale mul op optional when looking for attention. Required to get mlir attention for for SDXL workflow.

Copy link

codecov bot commented Feb 10, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (8e17050) 91.48% compared to head (19cc624) 91.48%.
Report is 3 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #2748   +/-   ##
========================================
  Coverage    91.48%   91.48%           
========================================
  Files          468      468           
  Lines        17539    17539           
========================================
  Hits         16045    16045           
  Misses        1494     1494           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

else
return;
});
if(r.instructions.find("scale") != r.instructions.end())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use contains(r.instructions, "scale") instead.

@migraphx-bot
Copy link
Collaborator

migraphx-bot commented Feb 10, 2024

Test Batch Rate new
917139
Rate old
f41336
Diff Compare
torchvision-resnet50 64 2,837.63 2,838.50 -0.03%
torchvision-resnet50_fp16 64 6,511.66 6,512.19 -0.01%
torchvision-densenet121 32 2,093.02 2,091.56 0.07%
torchvision-densenet121_fp16 32 3,693.96 3,690.69 0.09%
torchvision-inceptionv3 32 1,599.72 1,600.00 -0.02%
torchvision-inceptionv3_fp16 32 2,575.27 2,572.00 0.13%
cadene-inceptionv4 16 724.03 723.88 0.02%
cadene-resnext64x4 16 690.38 693.14 -0.40%
slim-mobilenet 64 6,889.67 6,899.88 -0.15%
slim-nasnetalarge 64 177.09 177.17 -0.04%
slim-resnet50v2 64 2,667.62 2,665.56 0.08%
bert-mrpc-onnx 8 826.84 827.18 -0.04%
bert-mrpc-tf 1 379.73 383.31 -0.93%
pytorch-examples-wlang-gru 1 238.16 240.12 -0.82%
pytorch-examples-wlang-lstm 1 242.93 241.56 0.57%
torchvision-resnet50_1 1 611.83 602.40 1.57%
cadene-dpn92_1 1 393.03 391.77 0.32%
cadene-resnext101_1 1 333.28 333.33 -0.02%
onnx-taau-downsample 1 305.93 306.07 -0.05%
dlrm-criteoterabyte 1 21.57 21.57 -0.02%
dlrm-criteoterabyte_fp16 1 40.70 40.70 0.02%
agentmodel 1 4,888.61 4,768.49 2.52%
unet_fp16 2 55.95 55.92 0.06%
resnet50v1_fp16 1 877.45 877.69 -0.03%
resnet50v1_int8 1 800.84 802.61 -0.22%
bert_base_cased_fp16 64 936.40 936.71 -0.03%
bert_large_uncased_fp16 32 292.67 292.73 -0.02%
bert_large_fp16 1 184.02 184.57 -0.30%
distilgpt2_fp16 16 1,639.46 1,644.53 -0.31%
yolov5s 1 495.20 488.02 1.47%
tinyllama 1 32.67 32.68 -0.03%
vicuna-fastchat 1 158.04 155.23 1.81%
whisper-tiny-encoder 1 335.02 336.49 -0.44%
whisper-tiny-decoder 1 375.00 374.83 0.04%

This build is OK for merge ✅

@migraphx-bot
Copy link
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

auto scale_lit = r.instructions["scale"];
scale_lit->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to flip the logic so that it returns if scale values are different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking something like this

Suggested change
if(std::all_of(
if(not std::all_of(....) return;
scale = s.front();

Copy link
Contributor Author

@shivadbhavsar shivadbhavsar Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But thats inside a visit, that return will not exit the apply method for this matcher
Oh you dont mean to modify the current functionallity, just rewrite it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah I saw that too. But looks like you rewrote it and looks fine now.

@kahmed10
Copy link
Collaborator

Can you add a test case?

@shivadbhavsar
Copy link
Contributor Author

Can you add a test case?

This pass only really runs with special flags. How would the test cases work for this?
There are no existing cases for this pass


migraphx::program p1 = create_program();
migraphx::program p2;
if(migraphx::gpu::mlir_attention_enabled())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont rely on this global setting. Instead add a flag to the prefuse_ops class like enable_attention that can be set to true during tests. And then you pass that along to the find_gemm_softmax_gemm so the matcher will match it.

auto dot2 = mm->add_instruction(migraphx::make_op("dot"), sm, z);
mm->add_return({dot2});
return p;
};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont create lambdas. Instead just construct the program directly:

migraphx::program p1;
{
    auto* mm = p1.get_main_module();
    ...
}

struct find_gemm_softmax_gemm
{
bool enable_attention;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set this to false, so its not uninitialized when default constructing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@causten causten merged commit c8b6c6d into develop Feb 16, 2024
18 of 19 checks passed
@causten causten deleted the atten_matcher branch February 16, 2024 21:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants