GPFQ support
Giuseppe5 committed Jul 25, 2023
1 parent 4d0852d commit a85f5f9
Showing 4 changed files with 404 additions and 111 deletions.
src/brevitas/graph/calibrate.py (7 additions, 1 deletion)
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from copy import deepcopy
from functools import partial
import sys

@@ -279,11 +280,16 @@ def forward_hook_wbiol(self, module, inp, output, name):
        # Compute float reference
        self.disable_act_quantization(module, is_training=False)
        self.disable_param_quantization(module, is_training=False)
        quant_weight = dict()
        if hasattr(module, 'weight_orig_data'):
            # The module carries a stashed copy of its original weights: set the
            # current weights aside and use the stashed copy for the float reference
            quant_weight[module] = deepcopy(module.weight.data)
            module.weight.data = module.weight_orig_data
        out_float = module.forward(*inp)  # Required to avoid infinite recursion
        self.collect_float_mean(module, out_float, name)
        self.enable_act_quantization(module, is_training=False)
        self.enable_param_quantization(module, is_training=False)

        # Put back the weights that were set aside above, if any
        for module, value in quant_weight.items():
            module.weight.data = value
        # Compute quant output
        # We need to disable output_quant while out_quant is being computed,
        # or we would apply bias correction on post-quant values instead of pre-quant values
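
For context, the hook above temporarily swaps in a stashed copy of the layer's weights for a single float forward pass and then restores the current weights. Below is a minimal, standalone sketch of that swap-and-restore pattern in plain PyTorch; the float_reference_output helper and the way weight_orig_data is populated here are assumptions for illustration, not the Brevitas API.

# Minimal sketch of the swap-and-restore pattern used in the hook above.
# Assumes the original weights were stashed on the module as `weight_orig_data`
# by an earlier step; how Brevitas populates that attribute is not shown here.
from copy import deepcopy

import torch
import torch.nn as nn


def float_reference_output(module: nn.Module, inp: torch.Tensor) -> torch.Tensor:
    """Run one forward pass with the stashed original weights, then restore."""
    set_aside = {}
    if hasattr(module, 'weight_orig_data'):
        # Keep the current weights aside and swap in the stashed originals.
        set_aside[module] = deepcopy(module.weight.data)
        module.weight.data = module.weight_orig_data
    try:
        out = module(inp)
    finally:
        # Put back whatever weights the module had before the swap.
        for m, value in set_aside.items():
            m.weight.data = value
    return out


if __name__ == '__main__':
    layer = nn.Linear(4, 2)
    # Hypothetical stash mimicking `weight_orig_data`, for the sake of the demo.
    layer.weight_orig_data = layer.weight.data.clone()
    out = float_reference_output(layer, torch.randn(1, 4))
    print(out.shape)  # torch.Size([1, 2])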
