Skip to content

Commit

Permalink
Fix (tests): qonnx moved in brevitas_ort
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 14, 2023
1 parent baccb2c commit b3b9941
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 132 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-ort-integration.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
onnx
onnxoptimizer
onnxruntime
qonnx
115 changes: 0 additions & 115 deletions tests/brevitas_finn/brevitas/test_wbiol.py

This file was deleted.

43 changes: 27 additions & 16 deletions tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import numpy as np
import onnxruntime as ort
import pytest
from qonnx.core.modelwrapper import ModelWrapper
import qonnx.core.onnx_exec as oxe
from qonnx.transformation.infer_shapes import InferShapes
import torch

from brevitas.export import export_onnx_qcdq
Expand All @@ -17,10 +20,10 @@
from brevitas.nn import QuantLinear
from brevitas.nn import QuantLSTM
from brevitas.nn import TruncAvgPool2d
from brevitas.quant.fixed_point import Int8AccumulatorAwareWeightQuant
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
Expand Down Expand Up @@ -105,31 +108,39 @@ def is_brevitas_ort_close(
input_t = torch.from_numpy(np_input)
brevitas_output = model(input_t)

if export_type == 'qop':
export_onnx_qop(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.int(float_datatype=False)
elif export_type == 'qcdq':
export_onnx_qcdq(model, input_t, export_path=export_name)
elif export_type == 'qcdq_opset14':
export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
elif export_type == 'qonnx_opset14':
export_qonnx(model, input_t, opset_version=14, export_path=export_name)
else:
raise RuntimeError(f"Export type {export_type} not recognized.")

if tolerance is not None and export_type == 'qcdq':
tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale

ort_output = compute_ort(export_name, np_input)
if export_type == 'qonnx':
exported_model = export_qonnx(model, input_t, export_path=export_name)
exported_model = ModelWrapper(exported_model)
exported_model = exported_model.transform(InferShapes())
idict = {exported_model.graph.input[0].name: np_input}
odict = oxe.execute_onnx(exported_model, idict, True)
ort_output = odict[exported_model.graph.output[0].name]
else:
if export_type == 'qop':
export_onnx_qop(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.int(float_datatype=False)
elif export_type == 'qcdq':
export_onnx_qcdq(model, input_t, export_path=export_name)
elif export_type == 'qcdq_opset14':
export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
elif export_type == 'qonnx_opset14':
export_qonnx(model, input_t, opset_version=14, export_path=export_name)
else:
raise RuntimeError(f"Export type {export_type} not recognized.")

ort_output = compute_ort(export_name, np_input)

if first_output_only:
if isinstance(ort_output, tuple):
if isinstance(ort_output, (tuple, list)):
ort_output = ort_output[0]
if isinstance(brevitas_output, tuple):
brevitas_output = brevitas_output[0]

# make sure we are not comparing 0s
if ort_output == 0 and (brevitas_output == 0).all():
if (ort_output == 0).all() and (brevitas_output == 0).all():
pytest.skip("Skip testing against all 0s.")

return recursive_allclose(ort_output, brevitas_output, tolerance)
Expand Down
3 changes: 2 additions & 1 deletion tests/brevitas_ort/test_quant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from functools import reduce
from operator import mul
import os

import pytest
from pytest_cases import get_case_id
Expand All @@ -18,7 +19,7 @@


@parametrize_with_cases('model', cases=QuantWBIOLCases)
@pytest.mark.parametrize('export_type', ['qcdq', 'qop'])
@pytest.mark.parametrize('export_type', ['qcdq', 'qonnx', 'qop'])
@requires_pt_ge('1.8.1')
def test_ort_wbiol(model, export_type, current_cases):
cases_generator_func = current_cases['model'][1]
Expand Down

0 comments on commit b3b9941

Please sign in to comment.