diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index ba4a474e2..8ec4be3aa 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -350,10 +350,9 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe output_tensor = self.inner_forward_impl( _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None) - if not self.is_output_quant_enabled: + if not self.is_output_quant_enabled and self.return_quant_tensor: if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor): - if (quant_input.zero_point != 0.0 - ).any() or (quant_weight.zero_point != 0.0).any() and self.return_quant_tensor: + if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any(): raise RuntimeError( "Computing zero point of output accumulator not supported yet.") elif output_zero_point is None: diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py index b0db249af..bbee8daca 100644 --- a/tests/brevitas/nn/test_nn_quantizers.py +++ b/tests/brevitas/nn/test_nn_quantizers.py @@ -185,7 +185,13 @@ def test_quant_mha(model_input, current_cases): with pytest.raises(RuntimeError, match='Input scale required'): output, _ = model(inp, inp, inp) return - + elif kwargs['weight_quant'] is not None and kwargs['io_quant'] is None: + if kwargs['weight_quant'] == 'quant_asym' and kwargs['return_quant_tensor']: + with pytest.raises( + RuntimeError, + match='Computing zero point of output accumulator not supported yet.'): + output, _ = model(inp, inp, inp) + return output, _ = model(inp, inp, inp) if kwargs['return_quant_tensor']: