Skip to content

Commit

Permalink
Fix (export/qonnx): shape propagation in custom ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 26, 2023
1 parent 740d9d4 commit 3df5396
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/brevitas/export/onnx/qonnx/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
from torch.autograd import Function
from torch.onnx.symbolic_helper import _get_tensor_sizes

from brevitas.core.bit_width import BitWidthConst
from brevitas.core.function_wrapper.clamp import TensorClamp
Expand All @@ -19,6 +20,7 @@ class BrevitasBinaryQuantFn(Function):
@staticmethod
def symbolic(g, x, scale, zero_point, bit_width, narrow_range, signed, rounding_mode):
ret = g.op(f'{DOMAIN_STRING}::BipolarQuant', x, scale)
ret.setType(x.type())
return ret

@staticmethod
Expand All @@ -40,6 +42,7 @@ def symbolic(g, x, scale, zero_point, bit_width, narrow_range, signed, rounding_
rounding_mode_s=rounding_mode,
signed_i=int(signed),
narrow_i=int(narrow_range))
ret.setType(x.type())
return ret

@staticmethod
Expand All @@ -66,6 +69,7 @@ def symbolic(g, x, scale, zero_point, input_bit_width, output_bit_width, roundin
input_bit_width,
output_bit_width,
rounding_mode_s=rounding_mode)
ret.setType(x.type())
return ret

@staticmethod
Expand Down

0 comments on commit 3df5396

Please sign in to comment.