Skip to content

Commit

Permalink
Last fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed May 29, 2024
1 parent 4b78543 commit d762c99
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tests/brevitas/export/test_onnx_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import brevitas.nn as qnn
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from tests.marker import jit_disabled_for_mock


@jit_disabled_for_mock()
def test_simple_fp8_export():
if torch_version < version.parse('2.1.0'):
pytest.skip(f"OCP FP8 types not supported by {torch_version}")
Expand All @@ -21,6 +23,7 @@ def test_simple_fp8_export():
assert True


@jit_disabled_for_mock()
def test_fp8_export_activation():
if torch_version < version.parse('2.1.0'):
pytest.skip(f"OCP FP8 types not supported by {torch_version}")
Expand All @@ -30,6 +33,7 @@ def test_fp8_export_activation():
assert True


@jit_disabled_for_mock()
def test_fp8_export_export_activation():
if torch_version < version.parse('2.1.0'):
pytest.skip(f"OCP FP8 types not supported by {torch_version}")
Expand All @@ -38,9 +42,3 @@ def test_fp8_export_export_activation():
3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat)
export_onnx_qcdq(model, torch.randn(1, 3), 'weight_act_fp8.onnx', export_weight_q_node=True)
assert True


if __name__ == "__main__":
#test_fp8_export_activation()
test_fp8_export_export_activation()
print("Done")

0 comments on commit d762c99

Please sign in to comment.