From 2ce9fdf51c733d63b292907b02a0dd3a4ce101b8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 21 Jan 2024 21:43:07 +0000 Subject: [PATCH] Fix tests --- src/brevitas_examples/common/generative/quantizers.py | 2 -- tests/brevitas_ort/common.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/common/generative/quantizers.py b/src/brevitas_examples/common/generative/quantizers.py index 8b71ab6b3..9913586c2 100644 --- a/src/brevitas_examples/common/generative/quantizers.py +++ b/src/brevitas_examples/common/generative/quantizers.py @@ -131,7 +131,6 @@ class Int8ActDynamicPerTensorFloat(ActDynamicProxyMixin, Int8ActPerTensorFloat): scaling_stats_input_view_shape_impl = OverTensorView scaling_stats_op = 'min_max' dynamic_scaling_broadcastable_shape = this.scaling_shape - stats_reduce_dim = 0 class Int8ActDynamicPerRowFloat(ActDynamicProxyMixin, Int8ActPerRowFloat): @@ -165,4 +164,3 @@ class ShiftedUint8ActDynamicPerTensorFloat(ActDynamicProxyMixin, ShiftedUint8Act scaling_stats_op = 'min_max' zero_point_stats_impl = NegativeMinOrZero dynamic_scaling_broadcastable_shape = this.scaling_shape - stats_reduce_dim = 0 diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 84a44c4df..c08175edd 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -58,7 +58,9 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): 'a2q': (A2QWeightQuantizerForTests, Int8ActPerTensorFloat), 'symmetric_per_tensor_fixed_point': (Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint), 'symmetric_per_channel_fixed_point': - (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint)} + (Int8WeightPerChannelFixedPoint, Int8ActPerTensorFixedPoint), + 'weight_symmetric_activation_dynamic_asymmetric_per_tensor_float': + (Int8WeightPerTensorFloat, ShiftedUint8ActDynamicPerTensorFloat)} LSTM_QUANTIZERS = { 'asymmetric_per_tensor_float': (ShiftedUint8WeightPerTensorFloat, ShiftedUint8ActPerTensorFloat),