Skip to content

Commit

Permalink
Fix (graph): fix for residual quantization with MHA (#681)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Jul 18, 2023
1 parent d2f3832 commit c1617b5
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/brevitas/graph/quantize_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Function objects that perform addition: the torch functional form plus the
# stdlib operator forms (operator.iadd is the in-place `+=` variant).
ADD_FNS = [torch.add, operator.add, operator.iadd]

# Method names for addition, out-of-place and in-place.
# NOTE(review): presumably compared against call_method node targets elsewhere
# in this file -- confirm.
ADD_METHODS = ['add', 'add_']

# NOTE(review): presumably a handle on the original (unpatched) concatenation
# function kept by brevitas -- confirm against the brevitas package init.
CAT = brevitas.original_cat

SIGN_PRESERVING_MODULES = (
Expand Down Expand Up @@ -46,6 +47,8 @@
nn.PixelUnshuffle,
nn.Identity)

# Iteration cap for residual_handler: when its convergence loop counter reaches
# this value it raises RuntimeError instead of looping indefinitely.
MAX_RESIDUAL_ITERS = 9999


def inp_placeholder_handler(model, input_quantizer):
"""
Expand Down Expand Up @@ -253,8 +256,7 @@ def recursive_input_handler(
else:
assert align_output is None, f"align_output {str(align_output)} not supported."
elif inp_node.op == 'call_function' and inp_node.target in [
torch.flatten, torch.reshape, torch.transpose, operator.getitem,
operator.__getitem__]:
torch.flatten, torch.reshape, torch.transpose]:
recursive_input_handler(
model,
inp_node,
Expand Down Expand Up @@ -307,6 +309,7 @@ def _get_quant_module(model, node, quant_identity_map, quant_act_map, unsigned_a

def residual_handler(
model, quant_identity_map, quant_act_map, unsigned_act_tuple, align_input_quant_fn):
iter = 0

def is_converged(model):

Expand Down Expand Up @@ -349,7 +352,11 @@ def is_converged(model):
return True

while not is_converged(model):
continue
iter += 1
if iter == MAX_RESIDUAL_ITERS:
raise RuntimeError(
"Residual handler could not find a solution to align scale factors "
"across ADDs and CATs")

return model

Expand Down

0 comments on commit c1617b5

Please sign in to comment.