
Improve performance of quantizelinear for int4 #1706

Open. Wants to merge 3 commits into develop.

Conversation
Conversation

@dhernandez0 (Contributor) commented Dec 17, 2024

This PR improves the performance of quantizelinear for int4. The changes are:

  • Get the right element type for the getVectorDim() call: when there are input fusions the type can change, and we should use the original type to decide how many k/d values to copy per thread.
  • Move generic ops before the LDS barrier when possible.
  • Pack scale and bias together in the same tensor (quantizelinear).

There's a PR in migraphx to also change the layout of the scale+bias tensor: ROCm/AMDMIGraphX#3718
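The idea behind the packed layout can be sketched with numpy. This is an illustration only, not the actual migraphx/rocMLIR implementation: the function names are hypothetical, and which 16-bit half of the int32 holds scale vs. bias is an assumption here.

```python
import numpy as np

def pack_scale_bias(scale: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Pack each f16 (scale, bias) pair into one int32: scale in the low
    16 bits, bias in the high 16 bits (layout assumed for illustration)."""
    s = scale.astype(np.float16).view(np.uint16).astype(np.uint32)
    b = bias.astype(np.float16).view(np.uint16).astype(np.uint32)
    return (s | (b << 16)).view(np.int32)

def unpack_scale_bias(packed: np.ndarray):
    """Recover the two f16 halves from the packed int32 tensor."""
    u = packed.view(np.uint32)
    scale = (u & 0xFFFF).astype(np.uint16).view(np.float16)
    bias = (u >> 16).astype(np.uint16).view(np.float16)
    return scale, bias
```

Packing halves the number of tensors the kernel has to read for dequantization, which is where the bandwidth saving comes from.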

This is the migraphx program of the layout change (packing scale and bias together into int32):

```python
import migraphx

p = migraphx.program()
m = p.get_main_module()
x_1 = m.add_parameter("x1", migraphx.shape(type="half_type", lens=[384, 32, 32, 2]))
x_2 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="uint8_type", lens=[12288, 2048]), 2))
p_x4 = m.add_parameter("x4", migraphx.shape(type="half_type", lens=[1, 1, 4096]))
x_1_transposed = m.add_instruction(migraphx.op("transpose", permutation=[0, 2, 1, 3]), [x_1])
x_1_new = m.add_instruction(migraphx.op("reshape", dims=[12288, 32, 2]), [x_1_transposed])
x_4 = m.add_instruction(migraphx.op("unpack_int4", axis=1), [x_2])  # uint8_type, lens=[12288, 4096]
x_6 = m.add_instruction(migraphx.op("unsqueeze", axes=[2]), [x_1_new])  # half_type, lens=[12288, 32, 1, 2]
x_7 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[12288, 32, 128, 2]), [x_6])  # half_type, lens=[12288, 32, 128, 2]
x_8 = m.add_instruction(migraphx.op("reshape", dims=[12288, 4096, 2]), [x_7])  # half_type, lens=[12288, 4096, 2]
scale = m.add_instruction(migraphx.op("slice", axes=[2], starts=[0], ends=[1]), [x_8])
bias = m.add_instruction(migraphx.op("slice", axes=[2], starts=[1], ends=[2]), [x_8])
scale_squeeze = m.add_instruction(migraphx.op("squeeze", axes=[2]), [scale])
bias_squeeze = m.add_instruction(migraphx.op("squeeze", axes=[2]), [bias])
x_12 = m.add_instruction(migraphx.op("dequantizelinear"), [x_4, scale_squeeze, bias_squeeze])  # half_type, lens=[12288, 4096]
x_13 = m.add_instruction(migraphx.op("unsqueeze", axes=[0]), [x_12])  # half_type, lens=[1, 12288, 4096]
x_14 = m.add_instruction(migraphx.op("transpose", permutation=[0, 2, 1]), [x_13])  # half_type, lens=[1, 4096, 12288]
x_15 = m.add_instruction(migraphx.op("dot"), [p_x4, x_14])  # half_type, lens=[1, 1, 12288]
m.add_return([x_15])
```
|                 | develop branch | this PR (+new layout) | speed up |
| --------------- | -------------- | --------------------- | -------- |
| gfx1101         | 0.1248ms       | 0.0897ms              | 1.39x    |
| gfx1150 (strix) | 0.5302ms       | 0.4071ms              | 1.30x    |
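As a sanity check on the numbers above, the speed-up column is just the develop time divided by the PR time:

```python
# Timings in ms, taken from the benchmark table.
timings = {
    "gfx1101": (0.1248, 0.0897),
    "gfx1150 (strix)": (0.5302, 0.4071),
}
# Recompute the speed-up column: develop / PR.
speedups = {gpu: round(dev / pr, 2) for gpu, (dev, pr) in timings.items()}
print(speedups)  # {'gfx1101': 1.39, 'gfx1150 (strix)': 1.3}
```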

@pfultz2 pointed out that we can use slice operations instead of changing quantizelinear to take a single parameter, which simplifies this PR a lot.

TODO:

Comment on lines 100 to 102:

```cpp
} else if (op.getOperatorName() == "unpack_scale") {
  assert(inElemType == b.getI32Type());
  assert(outElemType == b.getF16Type());
```

Member:

Looks good in general, but it would be nice to have a generic unpack operator that takes input_type and elemType and returns an array of elemType[] instead. For example, in this case unpack(input) could return vector<outElemType>(), and then scale would be the first element and bias the second.
@dhernandez0 (Contributor, Author):
I think there's a limitation in our codebase regarding the number of outputs of a linalg::GenericOp (for example, see findPostFusionTransforms in transformMapUtils.cpp): it's assumed the output is always a single value. We could do a generic unpack(input, element) so that it outputs output[element].
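The single-output shape of such an unpack(input, element) helper can be sketched in numpy. This is illustrative only, not the linalg lowering; the function name and the low-bits-first layout are assumptions:

```python
import numpy as np

def unpack(packed: np.ndarray, element: int, num_elements: int = 2) -> np.ndarray:
    """Return the element-th f16 value packed into each int32.

    element 0 selects the low 16 bits, element 1 the high 16 bits,
    so each call produces exactly one output tensor.
    """
    assert 0 <= element < num_elements
    u = packed.view(np.uint32)
    return ((u >> (16 * element)) & 0xFFFF).astype(np.uint16).view(np.float16)
```

With this shape, scale = unpack(x, 0) and bias = unpack(x, 1), and each GenericOp still has a single result.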

mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp (outdated; resolved)
Comment on lines +2572 to +2575:

```cpp
} else {
  LLVM_DEBUG(
      llvm::dbgs()
      << "Found a linalg.generic that takes as input the gemm A or B\n");
```

Member:

It's checking outputs, not inputs (genericOut).

codecov bot commented Dec 17, 2024

Codecov Report

Attention: Patch coverage is 23.80952% with 48 lines in your changes missing coverage. Please review.

Project coverage is 78.45%. Comparing base (a4e8230) to head (e663907).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...lir/lib/Dialect/Rock/utility/transformMapUtils.cpp | 6.12% | 45 Missing and 1 partial ⚠️ |
| ...ialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | 83.33% | 0 Missing and 2 partials ⚠️ |
Additional details and impacted files:

```
@@             Coverage Diff             @@
##           develop    #1706      +/-   ##
===========================================
- Coverage    78.52%   78.45%   -0.08%     
===========================================
  Files          100      100              
  Lines        28346    28405      +59     
  Branches      4130     4146      +16     
===========================================
+ Hits         22260    22285      +25     
- Misses        4426     4458      +32     
- Partials      1660     1662       +2     
```
| Flag | Coverage Δ |
| --- | --- |
| mfma | 78.45% <23.80%> (-0.08%) ⬇️ |


```cpp
Value output = op.getOutput();
Location loc = op->getLoc();

Type origBiasType;
if (bias)
```

Member:

Add tests in migraphx-to-tosa.mlir.

```cpp
assert(inElemType == b.getI32Type());
assert(outElemType == b.getF16Type());

Value offset = b.create<arith::ConstantIntOp>(loc, 16, inElemType);
```

Member:

Add tests in rocmlir-custom-tosa-to-linalg.mlir.

@@ -2549,7 +2549,16 @@ struct GridwiseGemmAccelRewritePattern

```cpp
// Obtain data types of inputs.
auto elementTypeA = op.getA().getType().getElementType();
auto maybeElementTypeALoad = getGemmInputElementType(op.getA());
```

Member:

It would be nice to have a unit test that uses getGemmInputElementType to exercise that logic along with GridwiseGemmToBlockwise.

@dhernandez0 force-pushed the 1665-quantizelinear-slower-than-f16-on-mi300x branch from 1748a96 to f87b60a on Dec 18, 2024.