From bb324662ac5837eb95edb81385e9a7c5d0ffdb2f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 10:07:50 +0100 Subject: [PATCH 1/5] calibration with reference values --- tests/brevitas/graph/test_calibration.py | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 86ef58b77..1da67ff3a 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -4,6 +4,7 @@ import math from hypothesis import given +import pytest_cases from pytest_cases import fixture import torch import torch.nn as nn @@ -13,6 +14,7 @@ from brevitas.graph.calibrate import load_quant_model_mode import brevitas.nn as qnn from brevitas.quant import Int8ActPerTensorFixedPoint +from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.scaled_int import Int8ActPerTensorFloat # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations from brevitas.utils.torch_utils import kthvalue @@ -21,6 +23,10 @@ IN_CH = 8 OUT_CH = 16 BATCH = 1 +REFERENCE_SCALES = { + 'int_quant': (0.00935234408825635910, 0.00859776325523853302), + 'fp_quant': (0.00249395845457911491, 0.00190271728206425905)} +REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) def compute_quantile(x, q): @@ -65,6 +71,42 @@ def forward(self, x): assert torch.allclose(expected_scale, scale) +QUANTS = {'int_quant': Int8ActPerTensorFloat, 'fp_quant': Fp8e4m3ActPerTensorFloat} + + +@pytest_cases.parametrize("act_quant", QUANTS.items(), ids=QUANTS.keys()) +def test_scale_factors_ptq_calibration_reference(act_quant): + + reference, act_quant = act_quant + + class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.act = qnn.QuantReLU(act_quant=act_quant) + self.linear = qnn.QuantLinear(3, 8) + self.act_1 = qnn.QuantIdentity(act_quant=act_quant) + + def forward(self, x): + o = self.act(x) + o = self.linear(o) + return self.act_1(o) + + # Reference input + inp = REFERNECE_INP + model = TestModel() + model.eval() + with torch.no_grad(): + with calibration_mode(model): + model(inp) + + computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale() + reference_values = REFERENCE_SCALES[reference] + assert all([ + torch.allclose(comp, torch.tensor(ref)) for comp, + ref in zip(computed_scale, reference_values)]) + + def test_calibration_training_state(): class TestModel(nn.Module): From 346c1552600fe0ab96c112e1979cf83b605e0a1f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 11:06:11 +0100 Subject: [PATCH 2/5] Cleaner test eval --- tests/brevitas/graph/test_calibration.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 1da67ff3a..f48043558 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -102,9 +102,8 @@ def forward(self, x): computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale() reference_values = REFERENCE_SCALES[reference] - assert all([ - torch.allclose(comp, torch.tensor(ref)) for comp, - ref in zip(computed_scale, reference_values)]) + assert torch.allclose(computed_scale[0], torch.tensor(reference_values[0])) + assert torch.allclose(computed_scale[1], torch.tensor(reference_values[1])) def test_calibration_training_state(): From 213dc29e1a8cdeaa4b2df2d7aafff6b713d0fbc5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 12:10:08 +0100 Subject: [PATCH 3/5] Setting seed --- tests/brevitas/graph/test_calibration.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index f48043558..1709f1c1b 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -19,7 +19,9 @@ # Use custom implementation of kthvalue as work around to (b)float16 kernel limitations from brevitas.utils.torch_utils import kthvalue from tests.brevitas.hyp_helper import float_tensor_random_size_st +from tests.conftest import SEED +torch.manual_seed(SEED) IN_CH = 8 OUT_CH = 16 BATCH = 1 From a0732640f3f98899d46c988d63f74f5ea8072110 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 12:35:00 +0100 Subject: [PATCH 4/5] Reference impl --- tests/brevitas/graph/test_calibration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 1709f1c1b..95d6cc5d9 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -26,8 +26,8 @@ OUT_CH = 16 BATCH = 1 REFERENCE_SCALES = { - 'int_quant': (0.00935234408825635910, 0.00859776325523853302), - 'fp_quant': (0.00249395845457911491, 0.00190271728206425905)} + 'int_quant': (0.00935234408825635910, 0.01362917013466358185), + 'fp_quant': (0.00249395845457911491, 0.00363444536924362183)} REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) @@ -86,12 +86,14 @@ class TestModel(nn.Module): def __init__(self): super(TestModel, self).__init__() self.act = qnn.QuantReLU(act_quant=act_quant) - self.linear = qnn.QuantLinear(3, 8) + self.linear_weights = torch.tensor([[1.0023, 0.0205, + 1.4604], [-0.2918, -1.8218, -0.7010], + [1.4573, -0.9074, -0.2708]]) self.act_1 = qnn.QuantIdentity(act_quant=act_quant) def forward(self, x): o = self.act(x) - o = self.linear(o) + o = torch.matmul(o, self.linear_weights) return self.act_1(o) # Reference input From 25b95fcac9fe546a5083fbd3b4396b6131fafe82 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 13:01:40 +0100 Subject: [PATCH 5/5] Formatting --- tests/brevitas/graph/test_calibration.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 95d6cc5d9..b22994275 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -28,7 +28,9 @@ REFERENCE_SCALES = { 'int_quant': (0.00935234408825635910, 0.01362917013466358185), 'fp_quant': (0.00249395845457911491, 0.00363444536924362183)} -REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) +REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) +REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010], + [1.4573, -0.9074, -0.2708]]) def compute_quantile(x, q): @@ -86,9 +88,7 @@ class TestModel(nn.Module): def __init__(self): super(TestModel, self).__init__() self.act = qnn.QuantReLU(act_quant=act_quant) - self.linear_weights = torch.tensor([[1.0023, 0.0205, - 1.4604], [-0.2918, -1.8218, -0.7010], - [1.4573, -0.9074, -0.2708]]) + self.linear_weights = REFERENCE_WEIGHTS self.act_1 = qnn.QuantIdentity(act_quant=act_quant) def forward(self, x): @@ -97,7 +97,7 @@ def forward(self, x): return self.act_1(o) # Reference input - inp = REFERNECE_INP + inp = REFERENCE_INP model = TestModel() model.eval() with torch.no_grad():