Add pass to rewrite pow2 div #2844

Merged
6 commits merged into develop from pow2_div_rewrite on Mar 7, 2024

Conversation

gyulaz-htec
Collaborator

This pass is needed for Llama2 fp16, where the RMSNorm calculation can go out of bounds and output inf values.
The pass rewrites x^2/n to (x/sqrt(n))^2.
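For context, a minimal numeric illustration (values chosen for illustration, not taken from the PR): fp16 overflows above 65504, so squaring first can already produce inf, while dividing by sqrt(n) first keeps the intermediate small. With n = 4096 and x = 300:

$$\frac{x^2}{n}=\Bigl(\frac{x}{\sqrt{n}}\Bigr)^2,\qquad 300^2 = 90000 \rightarrow \infty \text{ in fp16},\qquad \Bigl(\frac{300}{\sqrt{4096}}\Bigr)^2 = 4.6875^2 \approx 21.97$$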

@gyulaz-htec gyulaz-htec requested a review from causten as a code owner February 29, 2024 12:46
@gyulaz-htec gyulaz-htec requested review from pfultz2 and removed request for causten February 29, 2024 12:46

codecov bot commented Feb 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.76%. Comparing base (effedcd) to head (eb80d63).

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #2844      +/-   ##
===========================================
+ Coverage    91.75%   91.76%   +0.01%     
===========================================
  Files          473      475       +2     
  Lines        17958    17982      +24     
===========================================
+ Hits         16478    16502      +24     
  Misses        1480     1480              

☔ View full report in Codecov by Sentry.

@migraphx-bot
Collaborator

migraphx-bot commented Feb 29, 2024

Test    Batch    Rate new (eb80d6)    Rate old (effedc)    Diff    Compare
torchvision-resnet50 64 2,854.20 2,853.96 0.01%
torchvision-resnet50_fp16 64 6,520.70 6,521.98 -0.02%
torchvision-densenet121 32 2,089.09 2,104.22 -0.72%
torchvision-densenet121_fp16 32 3,698.01 3,697.17 0.02%
torchvision-inceptionv3 32 1,604.72 1,605.56 -0.05%
torchvision-inceptionv3_fp16 32 2,576.82 2,574.71 0.08%
cadene-inceptionv4 16 724.46 725.17 -0.10%
cadene-resnext64x4 16 682.84 682.95 -0.02%
slim-mobilenet 64 5,934.83 5,954.04 -0.32%
slim-nasnetalarge 64 152.85 152.96 -0.07%
slim-resnet50v2 64 2,667.40 2,669.14 -0.07%
bert-mrpc-onnx 8 826.02 826.64 -0.07%
bert-mrpc-tf 1 382.53 383.07 -0.14%
pytorch-examples-wlang-gru 1 236.06 239.32 -1.36%
pytorch-examples-wlang-lstm 1 244.95 243.94 0.41%
torchvision-resnet50_1 1 607.04 612.09 -0.83%
cadene-dpn92_1 1 392.10 393.43 -0.34%
cadene-resnext101_1 1 332.05 331.82 0.07%
onnx-taau-downsample 1 305.47 305.46 0.00%
dlrm-criteoterabyte 1 21.58 21.55 0.11%
dlrm-criteoterabyte_fp16 1 40.71 40.70 0.01%
agentmodel 1 4,825.21 4,243.71 13.70% 🔆
unet_fp16 2 56.16 56.15 0.01%
resnet50v1_fp16 1 900.16 904.94 -0.53%
resnet50v1_int8 1 797.44 796.67 0.10%
bert_base_cased_fp16 64 936.57 936.72 -0.02%
bert_large_uncased_fp16 32 292.71 292.73 -0.01%
bert_large_fp16 1 183.94 184.04 -0.06%
distilgpt2_fp16 16 1,640.47 1,640.92 -0.03%
yolov5s 1 483.62 492.30 -1.76%
tinyllama 1 32.64 32.62 0.05%
vicuna-fastchat 1 156.70 154.57 1.38%
whisper-tiny-encoder 1 335.92 335.30 0.19%
whisper-tiny-decoder 1 373.58 373.60 -0.01%

Check results before merge 🔆

@migraphx-bot
Collaborator

migraphx-bot commented Feb 29, 2024


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@pfultz2
Collaborator

pfultz2 commented Feb 29, 2024

This should go into a separate pass. Maybe it could be called rewrite_low_precision?

src/simplify_algebra.cpp (two outdated review comments, resolved)
@gyulaz-htec
Collaborator Author

gyulaz-htec commented Mar 1, 2024

@pfultz2 I've updated the PR:

  • Moved the rewrite into a new pass (rewrite_low_precision).
  • Updated the matcher to match mul(x,x) (see the sketch after this list).
  • Enabled the rewrite only for fp16; should I include fp8 as well?
  • Added a test to cover mul(x,x) and some more tests to check that the new pass doesn't apply to other types or to mul(x,y).
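For readers following along, here is a minimal sketch of what such a matcher-based pass can look like, using the MIGraphX matcher idioms seen elsewhere in simplify_algebra.cpp. The struct and variable names are illustrative, not the exact code merged in this PR:

```cpp
// Illustrative sketch only, not the merged code. In the real sources this
// would live inside the MIGraphX inline-versioned namespace.
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>

namespace migraphx {

// Hypothetical name: a matcher/apply pair for the x*x / n pattern.
struct find_pow2_div
{
    auto matcher() const
    {
        // div whose first argument is a mul (the candidate x*x) and whose
        // second argument is the divisor n
        return match::name("div")(
            match::args(match::name("mul").bind("pow2"), match::any().bind("n")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins  = r.result;               // the div instruction
        auto pow2 = r.instructions["pow2"]; // the mul feeding it
        auto n    = r.instructions["n"];
        // Restrict the rewrite to mul(x,x) in fp16, as described above.
        if(pow2->inputs().at(0) != pow2->inputs().at(1))
            return;
        if(pow2->get_shape().type() != shape::half_type)
            return;
        auto x = pow2->inputs().front();
        // x*x / n  ->  (x / sqrt(n))^2
        auto sqrt_n  = m.insert_instruction(ins, make_op("sqrt"), n);
        auto x_div_s = m.insert_instruction(ins, make_op("div"), x, sqrt_n);
        m.replace_instruction(ins, make_op("mul"), x_div_s, x_div_s);
    }
};

// The pass itself just runs the matcher over a module.
struct rewrite_low_precision
{
    std::string name() const { return "rewrite_low_precision"; }
    void apply(module& m) const { match::find_matches(m, find_pow2_div{}); }
};

} // namespace migraphx
```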

@gyulaz-htec gyulaz-htec force-pushed the pow2_div_rewrite branch 2 times, most recently from c5977a2 to 8973085 Compare March 1, 2024 15:40
src/quantization.cpp (outdated review comment, resolved)
This pass is needed for Llama2 fp16 where RMSNorm calc can go out of bounds and output inf values.
The pass rewrites x^2/n to (x/sqrt(n))^2.
Rewrite operators in low precision types to avoid going out of precision bounds.
@gyulaz-htec
Collaborator Author

@pfultz2 I've made the proposed changes:

auto ins = r.result;
auto n = r.instructions["n"];
auto x = r.instructions["x"];

Contributor

@lakhinderwalia lakhinderwalia Mar 5, 2024


If instead of
x^2/n --> (x/sqrt(n))^2
the following were applied, would there be any loss of accuracy?
x^2/n --> (x/n) * x
Thanks.
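(Editor's note: both forms are algebraically identical, x^2/n = (x/sqrt(n))^2 = (x/n)·x; the difference is only in which intermediate values get materialized in fp16 and in instruction count, sqrt + div + mul versus div + mul.)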

Collaborator Author


Yes, that actually works; the accuracy is the same with that. This solution saves us one instruction and also removes the sqrt. Thanks for the idea.

Contributor


Thanks.

@gyulaz-htec
Collaborator Author

@pfultz2 @umangyadav @lakhinderwalia I've updated the PR to use x^2/n --> (x/n) * x as @lakhinderwalia proposed. The tests are passing with that change.
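For reference, a minimal sketch of what the updated rewrite step can look like with the x^2/n --> (x/n) * x form, reusing the ins/x/n bindings shown in the earlier code excerpt (illustrative only, not the exact merged code):

```cpp
// Sketch only: assumes the matcher binds the squared value as "x" and the
// divisor as "n", as in the excerpt above, inside the apply() of the pass.
auto ins = r.result;                 // the div instruction being replaced
auto n   = r.instructions["n"];
auto x   = r.instructions["x"];
// x*x / n  ->  (x / n) * x: the intermediate x/n stays small, so fp16 no
// longer overflows, and the sqrt from the earlier form is gone entirely.
auto x_div_n = m.insert_instruction(ins, make_op("div"), x, n);
m.replace_instruction(ins, make_op("mul"), x_div_n, x);
```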

Contributor

@lakhinderwalia lakhinderwalia left a comment


Thanks.

@causten causten merged commit bbe1c56 into develop Mar 7, 2024
19 checks passed
@causten causten deleted the pow2_div_rewrite branch March 7, 2024 05:56