From b3edb258ae582c570f42bf2e8f29d4c61ddd3f01 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 25 Apr 2024 09:59:10 -0400
Subject: [PATCH] Quantization Fixes (#35)

* initial fix

* fix for counter

* using aminmax to fix discrepancies

* minor improvements

* remove prints
---
 .../quantization/observers/helpers.py    | 21 +++++-----
 .../quantization/observers/memoryless.py |  7 ++--
 .../quantization/observers/min_max.py    | 38 ++++++++++---------
 3 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index d0fca813..f548dba3 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -33,20 +33,21 @@ def calculate_qparams(
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
     """
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+
     bit_range = 2**quantization_args.num_bits - 1
     bit_min = -(bit_range + 1) / 2
+    bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
-        scales = symmetric_range / bit_range
         zero_points = torch.tensor(0).to(torch.int8)
+        max_val_pos = torch.max(-min_vals, max_vals)
+        scales = max_val_pos / (float(bit_range) / 2)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
     else:
-        # non-symmetric
-        observed_range = max_vals - min_vals
-        scales = observed_range / bit_range
-
-        # scales from a 0 range should be set to 1
-        scales[observed_range == 0] = 1
-
-        zero_points = torch.round(((0.0 - min_vals) / scales + bit_min)).to(torch.int8)
+        scales = (max_vals - min_vals) / float(bit_range)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = bit_min - torch.round(min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
 
     return scales, zero_points
diff --git a/src/compressed_tensors/quantization/observers/memoryless.py b/src/compressed_tensors/quantization/observers/memoryless.py
index 04026807..f5400675 100644
--- a/src/compressed_tensors/quantization/observers/memoryless.py
+++ b/src/compressed_tensors/quantization/observers/memoryless.py
@@ -27,18 +27,19 @@ class MemorylessObserver(Observer):
     """
     Implements a dynamic quantization observer that sets the scale and
-    zero point based on the latest observed value
+    zero point based on the latest observed value without tracking state
     """
 
     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
+        Returns the min and max values of observed
+
         :param observed: observed tensor to calculate quantization parameters for
         :return: tuple of scale and zero point derived from the observed tensor
         """
         # TODO: Add support for full range of quantization Args, only supports 8bit
         #  per tensor
-        min_val = observed.min()
-        max_val = observed.max()
+        min_val, max_val = torch.aminmax(observed)
 
         # ensure zero is in the range
         min_val = torch.min(min_val, torch.zeros_like(min_val))
         max_val = torch.max(max_val, torch.zeros_like(max_val))
diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 3496bb77..de8735ed 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -21,43 +21,45 @@
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["MinMaxObserver"]
+__all__ = ["MovingAverageMinMaxObserver"]
= ["MovingAverageMinMaxObserver"] @Observer.register("minmax") -class MinMaxObserver(Observer): +class MovingAverageMinMaxObserver(Observer): """ Implements a dynamic quantization observer that sets the scale and - zero point based on the overall min and max value + zero point based on a moving average of the overall min and max observed values """ - def __init__(self, quantization_args: QuantizationArgs): + def __init__( + self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01 + ): super().__init__(quantization_args=quantization_args) self.min_val = float("inf") self.max_val = -float("inf") - self.counter = 0 + self.averaging_constant = averaging_constant def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: """ + Updates the observed min and max using a moving average smoothed by the + averaging_constant + :param observed: observed tensor to calculate quantization parameters for :return: tuple of scale and zero point derived from the observed tensor """ - min_val = torch.tensor([observed.min()]) - max_val = torch.tensor([observed.max()]) + min_val, max_val = torch.aminmax(observed) - # update global min and max - if self.counter > 0: - self.min_val = torch.min(min_val, self.min_val) - self.max_val = torch.max(max_val, self.max_val) - else: + if self.min_val == float("inf") and self.max_val == float("-inf"): self.min_val = min_val self.max_val = max_val + else: + self.min_val = self.min_val + self.averaging_constant * ( + min_val - self.min_val + ) + self.max_val = self.max_val + self.averaging_constant * ( + max_val - self.max_val + ) - # ensure that the zeros are in the range - min_val = torch.min(self.min_val, torch.zeros_like(self.min_val)) - max_val = torch.max(self.max_val, torch.zeros_like(self.max_val)) - - self.counter += 1 - return calculate_qparams(min_val, max_val, self.quantization_args) + return calculate_qparams(self.min_val, self.max_val, self.quantization_args)