From d762c991a0a77f9b37a90513ac97a6887188c8c9 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 29 May 2024 13:54:52 +0100 Subject: [PATCH] Last fix --- tests/brevitas/export/test_onnx_fp8.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/brevitas/export/test_onnx_fp8.py b/tests/brevitas/export/test_onnx_fp8.py index b460bff24..140485256 100644 --- a/tests/brevitas/export/test_onnx_fp8.py +++ b/tests/brevitas/export/test_onnx_fp8.py @@ -10,8 +10,10 @@ import brevitas.nn as qnn from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat +from tests.marker import jit_disabled_for_mock +@jit_disabled_for_mock() def test_simple_fp8_export(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -21,6 +23,7 @@ def test_simple_fp8_export(): assert True +@jit_disabled_for_mock() def test_fp8_export_activation(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -30,6 +33,7 @@ def test_fp8_export_activation(): assert True +@jit_disabled_for_mock() def test_fp8_export_export_activation(): if torch_version < version.parse('2.1.0'): pytest.skip(f"OCP FP8 types not supported by {torch_version}") @@ -38,9 +42,3 @@ def test_fp8_export_export_activation(): 3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat) export_onnx_qcdq(model, torch.randn(1, 3), 'weight_act_fp8.onnx', export_weight_q_node=True) assert True - - -if __name__ == "__main__": - #test_fp8_export_activation() - test_fp8_export_export_activation() - print("Done")