Fix (ptq/bias_correction): remove unnecessary forward pass (#980)
Giuseppe5 authored Jul 14, 2024
1 parent f7d634d commit 6f752a3
Showing 2 changed files with 25 additions and 9 deletions.
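Before this change, the bias-correction forward hook re-ran each hooked layer via module.forward(*inp) just to obtain the quantized output it needed, even though PyTorch already passes that output to the hook. The diff below drops the extra forward pass, reuses the hook's output argument, and moves the output-quant disabling into the bias_correction_mode context manager. A minimal sketch of the hook pattern, using a plain nn.Linear and a hypothetical hook function rather than the Brevitas classes:

    import torch
    import torch.nn as nn

    def forward_hook(module, inp, output):
        # Old approach: recompute the layer output, doubling the cost per hooked layer:
        #     out_quant = module.forward(*inp)
        # New approach: reuse the output tensor the hook already receives.
        out_quant = output
        print("hooked output shape:", tuple(out_quant.shape))

    layer = nn.Linear(4, 4)
    handle = layer.register_forward_hook(forward_hook)
    layer(torch.randn(2, 4))
    handle.remove()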
src/brevitas/graph/calibrate.py (12 additions, 6 deletions)
@@ -113,13 +113,23 @@ def __init__(self, model, enabled=True, skip_if_no_bias=False):
         self.bias_correction = _BiasCorrection(skip_if_no_bias=skip_if_no_bias)
         self.enabled = enabled
         self.hooks = []
+        self.output_quant_modules = []
 
     def __enter__(self):
         if self.enabled:
+            for module in self.model.modules():
+                # Disable output quant so that the bias correction can be merged in the bias
+                if hasattr(module, 'output_quant') and module.output_quant.is_quant_enabled:
+                    self.bias_correction.disable_act_quantization(
+                        module.output_quant, is_training=False)
+                    self.output_quant_modules.append(module)
             self.bias_correction.register_hook_to_wbiol(self.model, self.hooks)
 
     def __exit__(self, type, value, traceback):
         self.bias_correction.apply_correction(self.model)
+        for module in self.output_quant_modules:
+            # Re-enable output quantization
+            self.bias_correction.enable_act_quantization(module.output_quant, is_training=False)
         for hook in self.hooks:
             hook.remove()

@@ -339,14 +349,10 @@ def forward_hook_wbiol(self, module, inp, output, name):
         self.collect_float_mean(module, out_float, name)
         self.enable_act_quantization(module, is_training=False)
         self.enable_param_quantization(module, is_training=False)
 
-        # Compute quant output
-        # We need to disable output_quant while out_quant is being computed
-        # or we are going to apply bias correction on post quant values instead of pre quant
-        self.disable_act_quantization(module.output_quant, is_training=False)
-        out_quant = module.forward(*inp)  # Required to avoid infinite recursion
+        # Keep output quant disabled until further notice
+        out_quant = output
         self.compute_correct_bias(module, out_quant, name)
-        self.enable_act_quantization(module.output_quant, is_training=False)
         self.iterations[name] += 1
         return out_float

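With this change, output quantizers are disabled once when the context manager is entered and re-enabled once on exit, instead of being toggled (and the layer re-run) inside every hook call. A usage sketch of bias_correction_mode during post-training quantization; quant_model and calib_loader are assumed to exist and are not part of this commit:

    import torch
    from brevitas.graph.calibrate import bias_correction_mode

    quant_model.eval()
    with torch.no_grad():
        with bias_correction_mode(quant_model):
            # While the mode is active, output_quant is disabled on every wrapped layer,
            # so the collected statistics reflect pre-output-quant values.
            for images in calib_loader:
                quant_model(images)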
tests/brevitas/graph/test_calibration.py (13 additions, 3 deletions)
@@ -13,6 +13,7 @@
 from brevitas.graph.calibrate import load_quant_model_mode
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
+from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
 from brevitas.utils.torch_utils import kthvalue
 from tests.brevitas.hyp_helper import float_tensor_random_size_st
@@ -108,8 +109,9 @@ class MyQuantModel(nn.Module):
     def __init__(self) -> None:
         super().__init__()
         self.module_list = nn.ModuleList([
-            qnn.QuantLinear(IN_CH, OUT_CH, bias=False),
-            qnn.QuantLinear(OUT_CH, OUT_CH, bias=False)])
+            qnn.QuantLinear(IN_CH, OUT_CH, bias=False, output_quant=Int8ActPerTensorFloat),
+            qnn.QuantLinear(OUT_CH, OUT_CH, bias=False,
+                output_quant=Int8ActPerTensorFloat)])
 
     def forward(self, inp):
         out_0 = self.module_list[0](inp)
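The test model now builds each QuantLinear with an explicit output quantizer so that there is something for bias_correction_mode to disable and re-enable. A standalone sketch of that construction, with hypothetical layer sizes:

    import brevitas.nn as qnn
    from brevitas.quant.scaled_int import Int8ActPerTensorFloat

    layer = qnn.QuantLinear(8, 8, bias=False, output_quant=Int8ActPerTensorFloat)
    # Outside of bias_correction_mode the output quantizer is active.
    assert layer.output_quant.is_quant_enabled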
@@ -138,18 +140,26 @@ def test_bias_correction_results(self, models):
         error = torch.zeros(num_layers, OUT_CH)
 
         # Reference Implementation of bias correction
+        quant_model.module_list[0].output_quant.disable_quant = True
+        quant_model.module_list[1].output_quant.disable_quant = True
         for b, inp in enumerate(inp_list):
             fp_outs[b, :, :] = fp_model(inp)
 
             quant_outs[b, 0, :] = quant_model.module_list[0](inp)
 
             quant_outs[b, 1, :] = quant_model.module_list[1](
                 fp_outs[b, 0, :])  # The second layer takes as input the "corrected" output
             error += fp_outs[b] - quant_outs[b]
+        quant_model.module_list[0].output_quant.disable_quant = False
+        quant_model.module_list[1].output_quant.disable_quant = False
 
         with bias_correction_mode(quant_model):
+            assert not quant_model.module_list[0].output_quant.is_quant_enabled
+            assert not quant_model.module_list[1].output_quant.is_quant_enabled
             for inp in inp_list:
                 quant_model(inp)
 
+        assert quant_model.module_list[0].output_quant.is_quant_enabled
+        assert quant_model.module_list[1].output_quant.is_quant_enabled
         assert quant_model.module_list[0].bias is not None
         assert quant_model.module_list[1].bias is not None
         assert torch.allclose(quant_model.module_list[0].bias, error[0] / len(inp_list))
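The final assertions encode the bias-correction rule itself: each layer's new bias should equal the average gap between the float output and the pre-output-quant quantized output over the calibration inputs. A small numeric sketch of that arithmetic, with made-up tensors standing in for fp_outs and quant_outs:

    import torch

    num_samples, out_ch = 4, 16
    fp_outs = torch.randn(num_samples, out_ch)     # float-model outputs, one row per calibration input
    quant_outs = torch.randn(num_samples, out_ch)  # quantized-model outputs for the same inputs
    error = (fp_outs - quant_outs).sum(dim=0)
    bias_correction = error / num_samples          # mirrors error[l] / len(inp_list) in the test
    print(bias_correction.shape)                   # torch.Size([16])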