From 687230b54575836157a95473edacf76e0598fad1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 2 Feb 2024 14:09:31 +0000 Subject: [PATCH] Fix ORT tests --- tests/brevitas_ort/common.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index e01596dc9..fa324f0be 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -27,6 +27,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant_tensor import QuantTensor from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat SEED = 123456 @@ -116,10 +117,15 @@ def is_brevitas_ort_close( input_t = torch.from_numpy(np_input) with torch.no_grad(): brevitas_output = model(input_t) - computed_out = brevitas_output.value + if isinstance(brevitas_output, QuantTensor): + computed_out = brevitas_output.value + scale = brevitas_output.scale + else: + computed_out = brevitas_output + scale = 1. if tolerance is not None and export_type == 'qcdq': - tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale + tolerance = tolerance * scale # Float Output, tolerance is +/- output scale if export_type == 'qonnx': exported_model = export_qonnx(model, input_t, export_path=export_name)