diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index bbea538dd..8953b3338 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -276,10 +276,6 @@ def symbolic_execution(self, x: Tensor): class DynamicQDQCastActQuantProxyHandlerMixin(DynamicQMixin, DQCastMixin, ABC): handled_layer = DynamicActQuantProxyFromInjector - def validate(self, module): - super().validate(module) - assert module.is_signed == False, "Only unsigned quantization is supported" - def prepare_for_export(self, module): if module.is_quant_enabled: self.validate(module) diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 79a2d95d9..7fe471489 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -94,7 +94,11 @@ def int32_dtype(cls): def validate(self, module): super().validate(module) - # ONNX DynamicQuantizeLinear supports only 8b output with round to nearest even. + + assert module.is_signed == False, "Only unsigned quantization supported" + assert module.quant_injector.scaling_stats_op == 'min_max', "Only min_max scaling op supported" + # ONNX DynamicQuantizeLinear supports only 8b output with round to nearest even. + # Below 8b quantization is not supported for dynamic quantization. assert module.rounding_mode.upper() == 'ROUND', 'Only round to nearest even supported' # Below 8b quantization is not supported. 
self.validate_8b_bit_width(module.bit_width(), le_then=False) diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index e794f6ffb..8b71ab6b3 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -127,11 +127,10 @@ class Int8ActDynamicPerTensorFloat(ActDynamicProxyMixin, Int8ActPerTensorFloat): """ Symmetric quantizer with per tensor dynamic scale. """ - proxy_class = DynamicActQuantProxyFromInjector scaling_impl = RuntimeDynamicStatsScaling scaling_stats_input_view_shape_impl = OverTensorView - scaling_stats_op = 'max' - dynamic_scaling_broadcastable_shape = (-1,) + scaling_stats_op = 'min_max' + dynamic_scaling_broadcastable_shape = this.scaling_shape stats_reduce_dim = 0 @@ -141,7 +140,7 @@ class Int8ActDynamicPerRowFloat(ActDynamicProxyMixin, Int8ActPerRowFloat): """ scaling_impl = RuntimeDynamicStatsScaling scaling_stats_input_view_shape_impl = OverBatchOverOutputChannelView - scaling_stats_op = 'max' + scaling_stats_op = 'min_max' class Int8ActDynamicPerGroupFloat(ActDynamicProxyMixin, Int8ActPerRowFloat): @@ -150,7 +149,7 @@ class Int8ActDynamicPerGroupFloat(ActDynamicProxyMixin, Int8ActPerRowFloat): """ scaling_impl = RuntimeDynamicStatsScaling keepdim = True - scaling_stats_op = 'max' + scaling_stats_op = 'min_max' @value def stats_reduce_dim(group_dim): @@ -159,11 +158,11 @@ def stats_reduce_dim(group_dim): class ShiftedUint8ActDynamicPerTensorFloat(ActDynamicProxyMixin, ShiftedUint8ActPerTensorFloat): """ - Asymmetric quantizer with per tensor dynamic scale. + Asymmetric quantizer with per tensor dynamic scale. 
""" scaling_impl = RuntimeDynamicStatsScaling scaling_stats_input_view_shape_impl = OverTensorView - scaling_stats_op = 'max' + scaling_stats_op = 'min_max' zero_point_stats_impl = NegativeMinOrZero - dynamic_scaling_broadcastable_shape = (-1,) + dynamic_scaling_broadcastable_shape = this.scaling_shape stats_reduce_dim = 0 diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index c05fd59b9..84a44c4df 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -30,6 +30,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas_examples.common.generative.quantizers import ShiftedUint8ActDynamicPerTensorFloat SEED = 123456 OUT_CH = 16 @@ -114,7 +115,8 @@ def recursive_allclose(ort_output, brevitas_output, tolerance): def is_brevitas_ort_close( model, np_input, export_name, export_type, tolerance=None, first_output_only=False): input_t = torch.from_numpy(np_input) - brevitas_output = model(input_t) + with torch.no_grad(): + brevitas_output = model(input_t) if tolerance is not None and export_type == 'qcdq': tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale diff --git a/tests/brevitas_ort/test_quant_module.py b/tests/brevitas_ort/test_quant_module.py index a4f7e7f5c..c9d142fab 100644 --- a/tests/brevitas_ort/test_quant_module.py +++ b/tests/brevitas_ort/test_quant_module.py @@ -27,6 +27,8 @@ def test_ort_wbiol(model, export_type, current_cases): impl = case_id.split('-')[ -2] # Inverse list of definition, 'export_type' is -1, 'impl' is -2, etc. 
quantizer = case_id.split('-')[-6] + o_bit_width = case_id.split('-')[-5] + i_bit_width = case_id.split('-')[-3] if impl in ('QuantConvTranspose1d', 'QuantConvTranspose2d') and export_type == 'qop': pytest.skip('Export of ConvTranspose is not supported for QOperation') @@ -34,6 +36,9 @@ def test_ort_wbiol(model, export_type, current_cases): pytest.skip('Per-channel zero-point is not well supported in ORT.') if 'QuantLinear' in impl and 'asymmetric' in quantizer: pytest.skip('ORT execution is unreliable and fails randomly on a subset of cases.') + if 'dynamic' in quantizer and ((o_bit_width != "o8" and i_bit_width != "i8") or + export_type != "qcdq"): + pytest.skip('Dynamic Act Quant supported only for 8bit and QCDQ export') if impl in ('QuantLinear'): in_size = (1, IN_CH)