diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index e01596dc9..fa324f0be 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -27,6 +27,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat +from brevitas.quant_tensor import QuantTensor from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat SEED = 123456 @@ -116,10 +117,15 @@ def is_brevitas_ort_close( input_t = torch.from_numpy(np_input) with torch.no_grad(): brevitas_output = model(input_t) - computed_out = brevitas_output.value + if isinstance(brevitas_output, QuantTensor): + computed_out = brevitas_output.value + scale = brevitas_output.scale + else: + computed_out = brevitas_output + scale = 1. if tolerance is not None and export_type == 'qcdq': - tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale + tolerance = tolerance * scale # Float Output, tolerance is +/- output scale if export_type == 'qonnx': exported_model = export_qonnx(model, input_t, export_path=export_name)