
Add --fp8 option to quantize models in FP8 inside migraphx-driver #2535

Merged
138 commits merged into develop from add_fp8_quantizer on Dec 12, 2023

Conversation

@umangyadav (Member) commented Dec 7, 2023

Depends on #2506

Follows the same scheme as Int8 quantization, except that it uses different scales than Int8.

@shivadbhavsar
Copy link
Contributor

We should also expose quantize_fp8 in the APIs the same way we have quantize_int8 and quantize_fp16.

@CharlieL7 (Collaborator) left a comment:

LGTM

@umangyadav umangyadav marked this pull request as ready for review December 8, 2023 14:58
@@ -41,11 +41,19 @@ struct program;
 MIGRAPHX_EXPORT void quantize_fp16(program& prog,
                                    const std::vector<std::string>& ins_names = {"all"});

+MIGRAPHX_EXPORT void quantize_8bits(program& prog,
Collaborator:

I don't think this needs to be declared in the header; it's only used internally by quantize_int8 and quantize_fp8.

Member Author:

Done.

    std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
        std::make_shared<std::vector<std::pair<float, float>>>();
    std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();

    auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
                                                                   std::vector<argument> args) {
        float quantized_range = (precision == shape::type_t::int8_type) ? 127.0 : 240.0;
Collaborator:

Can you put this in another function? Ideally we should use a visit+numeric_limits to get the quantized range, but we can leave it as is for now.

Member Author:

Not making this change for now.

-    auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
-                                                                   std::vector<argument> args) {
+    float quantized_range = (precision == shape::type_t::int8_type) ? 127.0 : 240.0;
+    auto calc_quant_params = [quant_8bit_params, max_abs_vals, quantized_range, &t](
Collaborator:

This can use [&] capture for the lambda.

Member Author:

Done.

     {
-        auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
+        auto zero_point = m.add_literal(
+            migraphx::literal{migraphx::shape{precision}, {static_cast<int8_t>(param.second)}});
Collaborator:

The cast is not needed anymore.

Member Author:

Removed.

@codecov-commenter commented Dec 8, 2023

Codecov Report

Attention: 18 lines in your changes are missing coverage. Please review.

Comparison is base (9d2003a) 91.50% compared to head (9692c57) 91.41%.
Report is 1 commit behind head on develop.

❗ Current head 9692c57 differs from the pull request's most recent head dc2263c. Consider uploading reports for commit dc2263c to get more accurate results.

Files                                    Patch %   Lines
src/quantization.cpp                     51.72%    14 Missing ⚠️
src/quantize_8bits.cpp                   80.00%     2 Missing ⚠️
src/include/migraphx/op/quant_dot.hpp    80.00%     1 Missing ⚠️
src/simplify_reshapes.cpp                50.00%     1 Missing ⚠️


Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #2535      +/-   ##
===========================================
- Coverage    91.50%   91.41%   -0.09%     
===========================================
  Files          453      452       -1     
  Lines        17183    17153      -30     
===========================================
- Hits         15723    15681      -42     
- Misses        1460     1472      +12     


@umangyadav umangyadav requested a review from pfultz2 December 12, 2023 14:25
@umangyadav umangyadav changed the base branch from quant_gemm_fp8 to develop December 12, 2023 14:33
        {
            continue;
        }
        else if(not starts_with(ins->name(), "@"))
Collaborator:

The else is redundant here.

@causten causten merged commit db3c07f into develop Dec 12, 2023
28 of 36 checks passed
@causten causten deleted the add_fp8_quantizer branch December 12, 2023 18:16
Labels: FP8 (issues related to FP8 implementation)
7 participants