From 382c50b35724110159572eaa1b05761f720dae0b Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:06:32 +0000 Subject: [PATCH] fixed missing parts in test code causing test failures --- tests/brevitas/export/quant_module_fixture.py | 3 ++- tests/brevitas/export/test_torch_qcdq.py | 4 +++- tests/brevitas/nn/nn_quantizers_fixture.py | 7 +++++-- tests/brevitas/nn/test_wbiol.py | 3 ++- tests/brevitas_ort/common.py | 3 ++- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/brevitas/export/quant_module_fixture.py b/tests/brevitas/export/quant_module_fixture.py index 31524729f..9e61c555b 100644 --- a/tests/brevitas/export/quant_module_fixture.py +++ b/tests/brevitas/export/quant_module_fixture.py @@ -61,7 +61,8 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8 BIAS_BIT_WIDTHS = [8, 16, 32] diff --git a/tests/brevitas/export/test_torch_qcdq.py b/tests/brevitas/export/test_torch_qcdq.py index 4f737f8d7..6019bf417 100644 --- a/tests/brevitas/export/test_torch_qcdq.py +++ b/tests/brevitas/export/test_torch_qcdq.py @@ -36,8 +36,10 @@ def test_torch_qcdq_wbiol_export( in_size = (1, IN_CH) elif quant_module_impl == QuantConv1d or quant_module_impl == QuantConvTranspose1d: in_size = (1, IN_CH, FEATURES) - else: + elif quant_module_impl == QuantConv2d or quant_module_impl == QuantConvTranspose2d: in_size = (1, IN_CH, FEATURES, FEATURES) + else: + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) inp = torch.randn(in_size) quant_module(inp) # Collect scale factors diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index 4d4983ba1..95c640de8 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -96,7 +96,8 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32] @@ -161,8 +162,10 @@ def forward(self, x): in_size = (1, IN_CH) elif impl in ('QuantConv1d', 'QuantConvTranspose1d'): in_size = (1, IN_CH, FEATURES) - else: + elif impl in ('QuantConv2d', 'QuantConvTranspose2d'): in_size = (1, IN_CH, FEATURES, FEATURES) + else: + in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES) if input_quantized: quant_inp = QuantTensor( diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py index 58b9a86ca..9df9faa4d 100644 --- a/tests/brevitas/nn/test_wbiol.py +++ b/tests/brevitas/nn/test_wbiol.py @@ -33,7 +33,8 @@ QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] @pytest_cases.fixture() diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index 4c148e96b..ef46cfcbf 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -77,7 +77,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant): QuantConv3d, QuantConvTranspose1d, QuantConvTranspose2d, - QuantConvTranspose3d,] + QuantConvTranspose3d, + ] def compute_ort(export_name, np_input):