Skip to content

Commit

Permalink
Add test for DynamicActQuant
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 21, 2024
1 parent 080465c commit ef83ffb
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 14 deletions.
4 changes: 0 additions & 4 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,6 @@ def symbolic_execution(self, x: Tensor):
class DynamicQDQCastActQuantProxyHandlerMixin(DynamicQMixin, DQCastMixin, ABC):
handled_layer = DynamicActQuantProxyFromInjector

def validate(self, module):
super().validate(module)
assert module.is_signed == False, "Only unsigned quantization is supported"

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
Expand Down
6 changes: 5 additions & 1 deletion src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def int32_dtype(cls):

def validate(self, module):
super().validate(module)
# ONNX DynamicQuantizeLinear supports only 8b output with round to nearest even.

assert module.is_signed == False, "Only unsigned quantization supported"
assert module.quant_injector.scaling_stats_op == 'min_max', "Only min_max scaling op supported"
# ONNX QuantizeLinear supports only 8b output with round to nearest even.
# Below 8b quantization is supported through clipping.
assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported'
# Below 8b quantization is not supported.
self.validate_8b_bit_width(module.bit_width(), le_then=False)
Expand Down
15 changes: 7 additions & 8 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,10 @@ class Int8ActDynamicPerTensorFloat(ActDynamicProxyMixin, Int8ActPerTensorFloat):
"""
Symmetric quantizer with per tensor dynamic scale.
"""
proxy_class = DynamicActQuantProxyFromInjector
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverTensorView
scaling_stats_op = 'max'
dynamic_scaling_broadcastable_shape = (-1,)
scaling_stats_op = 'min_max'
dynamic_scaling_broadcastable_shape = this.scaling_shape
stats_reduce_dim = 0


Expand All @@ -141,7 +140,7 @@ class Int8ActDynamicPerRowFloat(ActDynamicProxyMixin, Int8ActPerRowFloat):
"""
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView
scaling_stats_op = 'max'
scaling_stats_op = 'min_max'


class Int8ActDynamicPerGroupFloat(ActDynamicProxyMixin, Int8ActPerRowFloat):
Expand All @@ -150,7 +149,7 @@ class Int8ActDynamicPerGroupFloat(ActDynamicProxyMixin, Int8ActPerRowFloat):
"""
scaling_impl = RuntimeDynamicGroupStatsScaling
keepdim = True
scaling_stats_op = 'max'
scaling_stats_op = 'min_max'

@value
def stats_reduce_dim(group_dim):
Expand All @@ -159,11 +158,11 @@ def stats_reduce_dim(group_dim):

class ShiftedUint8ActDynamicPerTensorFloat(ActDynamicProxyMixin, ShiftedUint8ActPerTensorFloat):
"""
Asymmetric quantizer with per tensor dynamic scale.
Asymmetric quantizer with per tensor dynamic scale.
"""
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverTensorView
scaling_stats_op = 'max'
scaling_stats_op = 'min_max'
zero_point_stats_impl = NegativeMinOrZero
dynamic_scaling_broadcastable_shape = (-1,)
dynamic_scaling_broadcastable_shape = this.scaling_shape
stats_reduce_dim = 0
4 changes: 3 additions & 1 deletion tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat

SEED = 123456
OUT_CH = 16
Expand Down Expand Up @@ -114,7 +115,8 @@ def recursive_allclose(ort_output, brevitas_output, tolerance):
def is_brevitas_ort_close(
model, np_input, export_name, export_type, tolerance=None, first_output_only=False):
input_t = torch.from_numpy(np_input)
brevitas_output = model(input_t)
with torch.no_grad():
brevitas_output = model(input_t)

if tolerance is not None and export_type == 'qcdq':
tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale
Expand Down
5 changes: 5 additions & 0 deletions tests/brevitas_ort/test_quant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ def test_ort_wbiol(model, export_type, current_cases):
impl = case_id.split('-')[
-2] # Inverse list of definition, 'export_type' is -1, 'impl' is -2, etc.
quantizer = case_id.split('-')[-6]
o_bit_width = case_id.split('-')[-5]
i_bit_width = case_id.split('-')[-3]

if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop':
pytest.skip('Export of ConvTranspose is not supported for QOperation')
if 'per_channel' in quantizer and 'asymmetric' in quantizer:
pytest.skip('Per-channel zero-point is not well supported in ORT.')
if 'QuantLinear' in impl and 'asymmetric' in quantizer:
pytest.skip('ORT execution is unreliable and fails randomly on a subset of cases.')
if 'dynamic' in quantizer and ((o_bit_width != "o8" and i_bit_width != "i8") or
export_type != "qcdq"):
pytest.skip('Dynamic Act Quant supported only for 8bit and QCDQ export')

if impl in ('QuantLinear'):
in_size = (1, IN_CH)
Expand Down

0 comments on commit ef83ffb

Please sign in to comment.