Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show mlir program when tracing benchmarking #2741

Merged
merged 18 commits into from
Jun 18, 2024
Merged

Conversation

pfultz2
Copy link
Collaborator

@pfultz2 pfultz2 commented Feb 8, 2024

This will show the mlir program when using MIGRAPHX_TRACE_BENCHMARKING=3.

Copy link

codecov bot commented Feb 8, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.01%. Comparing base (2c4dd4a) to head (c5a4040).
Report is 144 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #2741   +/-   ##
========================================
  Coverage    92.01%   92.01%           
========================================
  Files          490      490           
  Lines        19434    19434           
========================================
  Hits         17883    17883           
  Misses        1551     1551           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@migraphx-bot
Copy link
Collaborator

migraphx-bot commented Feb 9, 2024

Test Batch Rate new
c5a404
Rate old
2c4dd4
Diff Compare
torchvision-resnet50 64 1,745.37 1,744.86 0.03%
torchvision-resnet50_fp16 64 4,047.51 4,050.81 -0.08%
torchvision-densenet121 32 1,465.24 1,462.21 0.21%
torchvision-densenet121_fp16 32 2,525.93 2,527.03 -0.04%
torchvision-inceptionv3 32 877.52 877.83 -0.03%
torchvision-inceptionv3_fp16 32 1,485.55 1,485.81 -0.02%
cadene-inceptionv4 16 407.19 407.35 -0.04%
cadene-resnext64x4 16 419.37 419.43 -0.01%
slim-mobilenet 64 4,092.80 4,093.22 -0.01%
slim-nasnetalarge 64 101.12 101.12 0.00%
slim-resnet50v2 64 1,685.62 1,686.03 -0.02%
bert-mrpc-onnx 8 615.90 615.30 0.10%
bert-mrpc-tf 1 279.60 276.34 1.18%
pytorch-examples-wlang-gru 1 321.79 322.21 -0.13%
pytorch-examples-wlang-lstm 1 293.27 294.16 -0.30%
torchvision-resnet50_1 1 468.48 476.51 -1.69%
cadene-dpn92_1 1 249.13 248.85 0.11%
cadene-resnext101_1 1 200.41 205.81 -2.62%
onnx-taau-downsample 1 204.75 204.68 0.03%
dlrm-criteoterabyte 1 22.90 22.89 0.06%
dlrm-criteoterabyte_fp16 1 42.69 42.67 0.03%
agentmodel 1 6,179.25 6,007.95 2.85%
unet_fp16 2 34.20 34.25 -0.13%
resnet50v1_fp16 1 594.56 599.72 -0.86%
resnet50v1_int8 1 580.17 586.21 -1.03%
bert_base_cased_fp16 64 646.20 646.09 0.02%
bert_large_uncased_fp16 32 199.03 198.99 0.02%
bert_large_fp16 1 117.53 117.68 -0.13%
distilgpt2_fp16 16 1,209.40 1,210.37 -0.08%
yolov5s 1 301.44 301.54 -0.03%
tinyllama 1 23.32 23.33 -0.05%
vicuna-fastchat 1 133.21 133.06 0.11%
whisper-tiny-encoder 1 244.20 244.46 -0.10%
whisper-tiny-decoder 1 256.82 256.52 0.12%

This build is OK for merge ✅

@migraphx-bot
Copy link
Collaborator

migraphx-bot commented Feb 9, 2024


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

if(not inputs.empty())
{
mm = m;
mr = &mm;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it require const_ref ? Shouldn't just copied module mm work ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only want to copy the module if there is no input shapes because we wont be adjusting the parameters.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module with empty input shape is unlikely case it would have been const-folded.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to assume inputs are not empty. it would be simpler

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is two overloads to dump_mlir. One just takes the module(which we dont want to copy) and the other overload takes the input shapes, which might be different than the shapes in the module so we need to modify the module which we will use a copy for this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the other overload takes the input shapes, which might be different than the shapes in the module so we need to modify the module which we will use a copy for this case.

Yes but those input shapes parameter would be input arguments to the precompile_op instruction. If they are empty that means MLIR module also doesn't take any inputs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the inputs are empty it means skip doing param adjustments.

@@ -179,6 +185,8 @@ struct compile_plan
std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max();
}
if(trace_level > 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add this option to documents

@pfultz2 pfultz2 requested a review from causten as a code owner May 14, 2024 17:53
@pfultz2 pfultz2 requested a review from a team as a code owner May 14, 2024 18:16
@pfultz2 pfultz2 requested a review from umangyadav May 14, 2024 18:29
@kahmed10 kahmed10 requested review from kahmed10 and bpickrel May 20, 2024 19:17
const compiled_result& benchmark() const
{
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
if(trace_level > 0 and not results.empty())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest you consolidate some of these checks with if ... or ...

@causten causten merged commit 5a0cf97 into develop Jun 18, 2024
45 of 46 checks passed
@causten causten deleted the trace-mlir-benchmark branch June 18, 2024 21:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants