diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 86ef58b77..b22994275 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -4,6 +4,7 @@
 import math
 
 from hypothesis import given
+import pytest_cases
 from pytest_cases import fixture
 import torch
 import torch.nn as nn
@@ -13,14 +14,23 @@
 from brevitas.graph.calibrate import load_quant_model_mode
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
 from brevitas.utils.torch_utils import kthvalue
 from tests.brevitas.hyp_helper import float_tensor_random_size_st
+from tests.conftest import SEED
 
+torch.manual_seed(SEED)
 
 IN_CH = 8
 OUT_CH = 16
 BATCH = 1
+REFERENCE_SCALES = {
+    'int_quant': (0.00935234408825635910, 0.01362917013466358185),
+    'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
+REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])
+REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010],
+                                  [1.4573, -0.9074, -0.2708]])
 
 
 def compute_quantile(x, q):
@@ -65,6 +75,41 @@ def forward(self, x):
     assert torch.allclose(expected_scale, scale)
 
 
+QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat}
+
+
+@pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys())
+def test_scale_factors_ptq_calibration_reference(act_quant):
+
+    reference, act_quant = act_quant
+
+    class TestModel(nn.Module):
+
+        def __init__(self):
+            super(TestModel, self).__init__()
+            self.act = qnn.QuantReLU(act_quant=act_quant)
+            self.linear_weights = REFERENCE_WEIGHTS
+            self.act_1 = qnn.QuantIdentity(act_quant=act_quant)
+
+        def forward(self, x):
+            o = self.act(x)
+            o = torch.matmul(o, self.linear_weights)
+            return self.act_1(o)
+
+    # Reference input
+    inp = REFERENCE_INP
+    model = TestModel()
+    model.eval()
+    with torch.no_grad():
+        with calibration_mode(model):
+            model(inp)
+
+    computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale()
+    reference_values = REFERENCE_SCALES[reference]
+    assert torch.allclose(computed_scale[0], torch.tensor(reference_values[0]))
+    assert torch.allclose(computed_scale[1], torch.tensor(reference_values[1]))
+
+
 def test_calibration_training_state():
 
     class TestModel(nn.Module):