Skip to content

Commit

Permalink
fixed missing parts in test code causing test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
costigt-dev committed Feb 1, 2024
1 parent 5ad14d4 commit 382c50b
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 6 deletions.
3 changes: 2 additions & 1 deletion tests/brevitas/export/quant_module_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d,
QuantConvTranspose3d,]
QuantConvTranspose3d,
]
BIT_WIDTHS = [4, 8, 10] # below 8, equal 8, above 8
BIAS_BIT_WIDTHS = [8, 16, 32]

Expand Down
4 changes: 3 additions & 1 deletion tests/brevitas/export/test_torch_qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def test_torch_qcdq_wbiol_export(
in_size = (1, IN_CH)
elif quant_module_impl == QuantConv1d or quant_module_impl == QuantConvTranspose1d:
in_size = (1, IN_CH, FEATURES)
else:
elif quant_module_impl == QuantConv2d or quant_module_impl == QuantConvTranspose2d:
in_size = (1, IN_CH, FEATURES, FEATURES)
else:
in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES)

inp = torch.randn(in_size)
quant_module(inp) # Collect scale factors
Expand Down
7 changes: 5 additions & 2 deletions tests/brevitas/nn/nn_quantizers_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d,
QuantConvTranspose3d,]
QuantConvTranspose3d,
]

ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32]

Expand Down Expand Up @@ -161,8 +162,10 @@ def forward(self, x):
in_size = (1, IN_CH)
elif impl in ('QuantConv1d', 'QuantConvTranspose1d'):
in_size = (1, IN_CH, FEATURES)
else:
elif impl in ('QuantConv2d', 'QuantConvTranspose2d'):
in_size = (1, IN_CH, FEATURES, FEATURES)
else:
in_size = (1, IN_CH, FEATURES, FEATURES, FEATURES)

if input_quantized:
quant_inp = QuantTensor(
Expand Down
3 changes: 2 additions & 1 deletion tests/brevitas/nn/test_wbiol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d,
QuantConvTranspose3d,]
QuantConvTranspose3d,
]


@pytest_cases.fixture()
Expand Down
3 changes: 2 additions & 1 deletion tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class A2QWeightQuantizerForTests(Int8AccumulatorAwareWeightQuant):
QuantConv3d,
QuantConvTranspose1d,
QuantConvTranspose2d,
QuantConvTranspose3d,]
QuantConvTranspose3d,
]


def compute_ort(export_name, np_input):
Expand Down

0 comments on commit 382c50b

Please sign in to comment.