FP8 OCP to FP8 FNUZ on hardware with only FP8 FNUZ support #3684

Open
CharlieL7 wants to merge 27 commits into develop
Conversation

CharlieL7
Collaborator

@CharlieL7 CharlieL7 commented Dec 5, 2024

  • NANOO is short for NaN On Overflow; the data type comes from this paper: https://arxiv.org/pdf/2206.02915
  • Implements the method described in "Convert OCP FP8 model to FNUZ model inside MIGraphX" #2717
  • This pass must run before simplify_qdq so that the adjusted scales and zero points are propagated past the quantized operator (see the sketch after this list).
  • The test in test/fp8_ocp_to_nanoo.cpp checks that the pass works with simplify_qdq and performs the expected operations.
  • The test in test/ref/fp8_ocp_to_nanoo.cpp checks that the pass produces the same results before and after it runs.
  • I will make a separate PR that removes the gpu context changes used to get the gfx number.
  • Fixed the cpp_generator, which was using __builtin_nan incorrectly.
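For context, here is a minimal, self-contained sketch (not MIGraphX code; the decoder and the example values are my own illustration) of why reinterpreting the same bits as fp8e4m3fnuz halves the value, and why the dequantize scale therefore has to be doubled:

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

// Decode a normal e4m3 bit pattern with a given exponent bias (ignores NaN/subnormals).
// fp8e4m3fn (OCP) uses bias 7, fp8e4m3fnuz uses bias 8.
double decode_e4m3(std::uint8_t bits, int bias)
{
    int sign     = (bits >> 7) & 0x1;
    int exponent = (bits >> 3) & 0xF;
    int mantissa = bits & 0x7;
    double value = std::ldexp(1.0 + mantissa / 8.0, exponent - bias);
    return sign ? -value : value;
}

int main()
{
    std::uint8_t bits = 0x48;           // exponent field = 9, mantissa = 0
    double ocp  = decode_e4m3(bits, 7); // fp8e4m3fn   -> 4.0
    double fnuz = decode_e4m3(bits, 8); // fp8e4m3fnuz -> 2.0 (same bits, half the value)

    double scale = 0.5;                      // original dequantize scale
    std::cout << ocp * scale << '\n';        // 2.0: original dequantized value
    std::cout << fnuz * (scale * 2) << '\n'; // 2.0: unchanged after doubling the scale
}
```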

@CharlieL7 CharlieL7 added the FP8 issues related to FP8 implementation label Dec 5, 2024
@CharlieL7 CharlieL7 self-assigned this Dec 5, 2024
@CharlieL7 CharlieL7 marked this pull request as ready for review December 10, 2024 20:01
@CharlieL7 CharlieL7 requested a review from causten as a code owner December 10, 2024 20:01
@CharlieL7 CharlieL7 changed the title FP8 OCP to FP8 NANOO on hardware with only FP8 NANOO support FP8 OCP to FP8 FNUZ on hardware with only FP8 FNUZ support Dec 10, 2024

codecov bot commented Dec 11, 2024

Codecov Report

Attention: Patch coverage is 97.46835% with 2 lines in your changes missing coverage. Please review.

Project coverage is 92.23%. Comparing base (f56b1b4) to head (3c36b9b).
Report is 15 commits behind head on develop.

Files with missing lines Patch % Lines
src/fp8_ocp_to_fnuz.cpp 98.50% 1 Missing ⚠️
src/simplify_qdq.cpp 75.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3684      +/-   ##
===========================================
+ Coverage    92.21%   92.23%   +0.02%     
===========================================
  Files          514      517       +3     
  Lines        21750    21819      +69     
===========================================
+ Hits         20056    20124      +68     
- Misses        1694     1695       +1     


@CharlieL7
Collaborator Author

Fixed the bug in the pointwise compilation. __builtin_nan requires a string argument, which is parsed into the significand (payload) bits of the NaN.
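As a quick illustration of the builtin being discussed (GCC/Clang builtin; the payload values below are arbitrary examples, not from this PR):

```cpp
#include <cassert>
#include <cmath>

int main()
{
    double q = __builtin_nan("");    // quiet NaN with an empty (zero) payload
    double p = __builtin_nan("0x1"); // quiet NaN carrying payload 0x1 in the significand
    assert(std::isnan(q) && std::isnan(p));
}
```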

@CharlieL7 CharlieL7 requested a review from pfultz2 December 13, 2024 20:09
@CharlieL7 CharlieL7 requested a review from ahsan-ca December 13, 2024 20:09
@migraphx-bot
Collaborator

Test  Batch  Rate new (3c36b9)  Rate old (79a256)  Diff  Compare
torchvision-resnet50 64 3,258.74 3,254.68 0.12%
torchvision-resnet50_fp16 64 6,984.47 6,988.98 -0.06%
torchvision-densenet121 32 2,434.85 2,435.20 -0.01%
torchvision-densenet121_fp16 32 4,078.51 4,089.35 -0.27%
torchvision-inceptionv3 32 1,628.47 1,629.15 -0.04%
torchvision-inceptionv3_fp16 32 2,747.68 2,750.39 -0.10%
cadene-inceptionv4 16 765.40 765.66 -0.03%
cadene-resnext64x4 16 812.27 812.33 -0.01%
slim-mobilenet 64 7,465.27 7,465.36 -0.00%
slim-nasnetalarge 64 209.02 209.02 0.00%
slim-resnet50v2 64 3,439.49 3,439.27 0.01%
bert-mrpc-onnx 8 1,150.69 1,145.81 0.43%
bert-mrpc-tf 1 503.08 466.39 7.87% 🔆
pytorch-examples-wlang-gru 1 410.04 421.94 -2.82%
pytorch-examples-wlang-lstm 1 387.84 381.01 1.79%
torchvision-resnet50_1 1 805.75 763.70 5.51% 🔆
cadene-dpn92_1 1 401.37 434.24 -7.57% 🔴
cadene-resnext101_1 1 382.46 383.59 -0.29%
onnx-taau-downsample 1 346.38 346.01 0.11%
dlrm-criteoterabyte 1 33.35 33.33 0.06%
dlrm-criteoterabyte_fp16 1 52.76 52.73 0.05%
agentmodel 1 8,215.78 8,229.23 -0.16%
unet_fp16 2 58.87 58.93 -0.09%
resnet50v1_fp16 1 975.05 1,025.60 -4.93% 🔴
resnet50v1_int8 1 1,027.54 1,052.08 -2.33%
bert_base_cased_fp16 64 1,170.31 1,169.64 0.06%
bert_large_uncased_fp16 32 363.18 363.34 -0.04%
bert_large_fp16 1 198.76 198.80 -0.02%
distilgpt2_fp16 16 2,199.98 2,201.30 -0.06%
yolov5s 1 533.93 529.55 0.83%
tinyllama 1 43.41 43.36 0.10%
vicuna-fastchat 1 174.80 170.51 2.51%
whisper-tiny-encoder 1 417.32 417.88 -0.13%
whisper-tiny-decoder 1 433.93 425.43 2.00%

This build is not recommended to merge 🔴

@migraphx-bot
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

* intrinsically. Conversion uses the same bit representation and adjusts scaling factors at the
* dequantization. Using the same bit representation from fp8e4m3fn to fp8e4m3fnuz halves the
* floating point representation. This pass should run before simplify_qdq so that the scales and
* zero points calculated by simplify_qdq have the correct adjusted scaling factors
Collaborator

Appreciate this comment.
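For reference, a minimal sketch of the ordering the quoted comment describes (pass, header, and helper names are assumed from the files touched in this PR and MIGraphX's usual pass-manager API, not verified against the actual target pipeline):

```cpp
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fp8_ocp_to_fnuz.hpp> // assumed header added by this PR
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_qdq.hpp>

void run_fp8_passes(migraphx::module& m)
{
    // fp8_ocp_to_fnuz rewrites the bit patterns and adjusts the scales first,
    // so that simplify_qdq later folds Q/DQ pairs using the adjusted values.
    migraphx::run_passes(m,
                         {migraphx::fp8_ocp_to_fnuz{},
                          migraphx::simplify_qdq{},
                          migraphx::dead_code_elimination{}});
}
```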

@@ -220,8 +220,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
if(x < 0)
string_literal = "-__builtin_huge_val()";
}
else if(std::isnan(static_cast<double>(x)))
Collaborator

I think static_cast is needed for Windows.
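A small illustration of why the explicit cast matters (a hypothetical helper, not the generator code): std::isnan only has overloads for the built-in floating-point and integral types, so for a wrapper type such as a half or fp8 class the call can fail to resolve on some standard libraries (MSVC in particular) unless the value is first converted to double.

```cpp
#include <cmath>

template <class T>
bool literal_is_nan(T x)
{
    // Portable: always resolves to the double overload of std::isnan,
    // regardless of what implicit conversions T provides.
    return std::isnan(static_cast<double>(x));
}
```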

run_propagate_constant(m1);
run_propagate_constant(m3);
run_cse(m1);
run_cse(m3);
Collaborator

Can you just combine all these passes into one function call?
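Something along these lines might work (a sketch only; run_propagate_constant and run_cse are the helpers already used above, and the combined helper name is made up):

```cpp
// Fold the repeated per-module calls into one helper so each test makes a single call.
static void run_common_passes(migraphx::module& a, migraphx::module& b)
{
    for(auto* m : {&a, &b})
    {
        run_propagate_constant(*m); // existing helper from this test file
        run_cse(*m);                // existing helper from this test file
    }
}

// usage, replacing the four separate calls:
// run_common_passes(m1, m3);
```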

Collaborator

@TedThemistokleous TedThemistokleous left a comment

Makes sense. Add some additional tests for some of the coverage warnings and that's about it.

Collaborator

@TedThemistokleous TedThemistokleous left a comment

Disregard, you're already like 97% covered here. Looks good

@TedThemistokleous TedThemistokleous added the roadmap Tasks to finish for a release label Dec 21, 2024