diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 55124fd30..f5c73a2a4 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -36,6 +36,14 @@ FLOAT_TOLERANCE = 1e-6 KERNEL_SIZE = 1 # keep float error during fake-quantization under control BIT_WIDTHS = range(2, 9) +ACCUMULATOR_BIT_WIDTH_FOR_TESTS = 16 + + +# For testing purposes, we create a custom quantizer with a reduced bit width for the accumulator +class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): + accumulator_bit_width = ACCUMULATOR_BIT_WIDTH_FOR_TESTS + + WBIOL_QUANTIZERS = { 'asymmetric_per_tensor_float': (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat), @@ -43,7 +51,7 @@ 'asymmetric_per_channel_float': (ShiftedUint8WeightPerChannelFloat, ShiftedUint8ActPerTensorFloat), 'symmetric_per_channel_float': (Int8WeightPerChannelFloat, Int8ActPerTensorFloat), - 'a2q': (Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat), + 'a2q': (A2QWeightQuantizerForTests, Int8ActPerTensorFloat), 'symmetric_per_tensor_fixed_point': (Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint), 'symmetric_per_channel_fixed_point': (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)}