Accuracy issue in quantized resnet50 after match_qlinear_reused #2949

Closed
shivadbhavsar opened this issue Apr 3, 2024 · 2 comments
Labels
bug Something isn't working TorchMIGraphX

Comments

@shivadbhavsar
Contributor

The accuracy test in torch_migraphx for asymmetrically quantized resnet50 has been failing since #2613 was added. Here is a simpler repro program for the issue:

import migraphx
import numpy as np

np.random.seed(10)
asymmetric = True

scale_scalar = 0.004 if asymmetric else 0.008
zp_scalar = -128 if asymmetric else 0

p = migraphx.program()
mm = p.get_main_module()
mgx_shape = migraphx.shape(lens=[1, 16, 7, 7], type='float_type')
in1 = mm.add_parameter('x0', mgx_shape)

# Clip the input to [-1, 1] so the relu output lies in [0, 1]; with scale 0.004 the quantized values never overflow int8
min_ins = mm.add_literal(-1 * np.ones(shape=[1, 16, 7, 7], dtype=np.float32))
max_ins = mm.add_literal(np.ones(shape=[1, 16, 7, 7], dtype=np.float32))
in1 = mm.add_instruction(migraphx.op('clip'), [in1, min_ins, max_ins])

w = mm.add_literal(np.random.randint(-128, 128, size=[16, 16, 1, 1]).astype(np.int8))
scale = mm.add_literal(np.array([scale_scalar]).astype(np.float32))
scale_mb = mm.add_instruction(migraphx.op('multibroadcast', out_lens=[1, 16, 7, 7]), [scale])
zp = mm.add_literal(np.array([zp_scalar]).astype(np.int8))
zp_mb = mm.add_instruction(migraphx.op('multibroadcast', out_lens=[1, 16, 7, 7]), [zp])

relu = mm.add_instruction(migraphx.op('relu'), [in1])
qlin = mm.add_instruction(migraphx.op('quantizelinear'), [relu, scale_mb, zp_mb])
qconv = mm.add_instruction(migraphx.op('quant_convolution', 
                                       padding=[0, 0, 0, 0],
                                       stride=[1, 1],
                                       dilation=[1, 1],
                                       padding_mode=0,
                                       group=1), [qlin, w])


# Assume weight scale is 0.02 (this is similar to values seen in real model)
scale_dq = mm.add_literal(np.array([scale_scalar * 0.02]).astype(np.float32))
scale_dq_mb = mm.add_instruction(migraphx.op('multibroadcast', out_lens=[1, 16, 7, 7]), [scale_dq])
zp_dq = mm.add_literal(np.array([zp_scalar]).astype(np.int32))
zp_dq_mb = mm.add_instruction(migraphx.op('multibroadcast', out_lens=[1, 16, 7, 7]), [zp_dq])
dqlin = mm.add_instruction(migraphx.op('dequantizelinear'), [qconv, scale_dq_mb, zp_dq_mb])

add = mm.add_instruction(migraphx.op('add'), [relu, dqlin])
mm.add_return([add])

Verify using:
migraphx-driver verify qlinear_reused_fail.py
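
As a cross-check that does not depend on MIGraphX, the graph in the repro can be approximated in NumPy. This is only a sketch: it assumes ONNX-style semantics for quantizelinear (round half to even, add the zero point, saturate to int8) and dequantizelinear (subtract the zero point, multiply by the scale), and the function names `quantize_linear`, `dequantize_linear` and `reference` are hypothetical helpers, not part of any MIGraphX API.

```python
import numpy as np

def quantize_linear(x, scale, zero_point):
    # Assumed ONNX-style behavior: round half to even, shift, saturate to int8.
    q = np.rint(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_linear(q, scale, zero_point):
    # Assumed ONNX-style behavior: (q - zero_point) * scale.
    return (q.astype(np.float32) - np.float32(zero_point)) * np.float32(scale)

def reference(x, w, scale=0.004, zero_point=-128, weight_scale=0.02):
    # Mirrors the repro graph: clip -> relu -> quantize -> 1x1 conv -> dequantize -> add.
    relu = np.maximum(np.clip(x, -1.0, 1.0), 0.0).astype(np.float32)
    q = quantize_linear(relu, scale, zero_point)
    # A 1x1 convolution is a contraction over the channel axis; accumulate in int32.
    acc = np.einsum('oc,nchw->nohw',
                    w[:, :, 0, 0].astype(np.int32), q.astype(np.int32))
    dq = dequantize_linear(acc, scale * weight_scale, zero_point)
    return relu + dq

np.random.seed(10)
w = np.random.randint(-128, 128, size=[16, 16, 1, 1]).astype(np.int8)
x = np.random.uniform(-2.0, 2.0, size=[1, 16, 7, 7]).astype(np.float32)
expected = reference(x, w)
```

The values in `expected` depend on the random input drawn here, so they are not directly comparable to the numbers printed by the driver unless the same input is used.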

Without #2613, this passes for both asymmetric = True and asymmetric = False.

With #2613, this fails when asymmetric = True:

FAILED: qlinear_reused_fail.py
RMS Error: 0.0423974
Max diff: 1.026
Mismatch at 1: -2.9816 != -4.0076

The error here does not come only from the extra dequantize added by this pass; there seems to be more going on.
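
A rough way to see why the extra dequantize alone cannot explain the failure: assuming the pass makes the add consume dequantize(quantize(relu)) instead of relu, and assuming ONNX-style quantize/dequantize semantics, the round-trip error is bounded by about half the scale. The sketch below uses the asymmetric values from the repro; the helper names are hypothetical.

```python
import numpy as np

scale, zero_point = 0.004, -128  # asymmetric case from the repro

def quantize_linear(x, scale, zero_point):
    # Assumed ONNX-style behavior: round half to even, shift, saturate to int8.
    q = np.rint(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_linear(q, scale, zero_point):
    # Assumed ONNX-style behavior: (q - zero_point) * scale.
    return (q.astype(np.float32) - zero_point) * scale

rng = np.random.default_rng(0)
# Stand-in for the relu output, which lies in [0, 1] after the clip.
relu_out = rng.uniform(0.0, 1.0, size=(1, 16, 7, 7)).astype(np.float32)

round_trip = dequantize_linear(
    quantize_linear(relu_out, scale, zero_point), scale, zero_point)
max_round_trip_error = float(np.abs(round_trip - relu_out).max())
print(max_round_trip_error)
```

Under these assumptions the error stays at or below roughly scale / 2 = 0.002, which is far smaller than the max diff of 1.026 reported above.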

@pfultz2

@CharlieL7
Collaborator

CharlieL7 commented Jul 12, 2024

This issue came up because the MLIR implementation of dequantizelinear disagrees with what MIGX and ONNX expect. See ROCm/rocMLIR#1567. Additionally, the qlinear_reused matcher in MIGX inherently reduces the accuracy of the model. The plan is to remove that matcher in the develop branch and instead merge pointwise operations with multiple outputs into MLIR.

@CharlieL7
Collaborator

This issue is resolved by ROCm/rocMLIR#1567 and by commenting out the qlinear_reused matcher for the 6.2 release (#3264). For a longer-term fix, we should remove the matcher altogether and fuse better: #3269.
