Skip to content

Commit

Permalink
Feat (nn): avoid computing output scale/zp when not needed
Browse files (browse the repository at this point in the history)
  • Loading branch information
volcacius committed Jul 4, 2023
1 parent 5e2d00a commit fd4fb20
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/brevitas/nn/quant_layer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -310,12 +310,17 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
quant_input = self.input_quant(inp)
quant_weight = self.quant_weight(quant_input)

- if quant_input.bit_width is not None and quant_weight.bit_width is not None:
-     output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
- if quant_input.scale is not None and quant_weight.scale is not None:
-     output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
- if quant_input.signed is not None:
-     output_signed = inp.signed or quant_weight.signed
+ if (self.return_quant_tensor or
+     (self.is_bias_quant_enabled and
+      (self.bias_quant.requires_input_scale or self.bias_quant.requires_input_bit_width))):
+     if quant_input.bit_width is not None and quant_weight.bit_width is not None:
+         output_bit_width = self.max_acc_bit_width(
+             quant_input.bit_width, quant_weight.bit_width)
+     if quant_input.scale is not None and quant_weight.scale is not None:
+         output_scale = self.quant_output_scale_impl(
+             inp, quant_input.scale, quant_weight.scale)
+     if quant_input.signed is not None:
+         output_signed = inp.signed or quant_weight.signed

if self.bias is not None:
quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
Expand Down

0 comments on commit fd4fb20

Please sign in to comment.