From d906304e721379d759b0e67d2a9aa6b4c7fa0219 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 18 Jul 2023 09:25:06 +0100 Subject: [PATCH] Fix (fx): fix fx quantize for conv->bn --- src/brevitas/graph/quantize_impl.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 4b2eb2f79..d2e1c5c29 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -49,6 +49,7 @@ MAX_RESIDUAL_ITERS = 9999 +BATCH_NORM = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) def inp_placeholder_handler(model, input_quantizer): """ @@ -187,6 +188,18 @@ def output_quant_handler( user_module = get_module(model, user.target) if hasattr(user_module, 'act_quant'): output_quant = False + elif isinstance(user_module, BATCH_NORM): + # If the user is BatchNorm, check BN's users and potentially requentize at + # the output of BN + output_quant = False + output_quant_handler( + model, + user, + rewriters, + is_sign_preserving, + quant_identity_map, + quant_act_map, + unsigned_act_tuple) if output_quant: if quant_module_name is None and quant_module is None: if is_sign_preserving and are_inputs_unsigned(