Skip to content

Commit

Permalink
Fix ORT tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 2, 2024
1 parent cec372c commit 687230b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas.quant_tensor import QuantTensor
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

SEED = 123456
Expand Down Expand Up @@ -116,10 +117,15 @@ def is_brevitas_ort_close(
input_t = torch.from_numpy(np_input)
with torch.no_grad():
brevitas_output = model(input_t)
computed_out = brevitas_output.value
if isinstance(brevitas_output, QuantTensor):
computed_out = brevitas_output.value
scale = brevitas_output.scale
else:
computed_out = brevitas_output
scale = 1.

if tolerance is not None and export_type == 'qcdq':
tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale
tolerance = tolerance * scale # Float Output, tolerance is +/- output scale

if export_type == 'qonnx':
exported_model = export_qonnx(model, input_t, export_path=export_name)
Expand Down

0 comments on commit 687230b

Please sign in to comment.