Quantization Fixes (#35)
* initial fix

* fix for counter

* using aminmax to fix discrepancies

* minor improvements

* remove prints
Satrat authored Apr 25, 2024
1 parent d4787e2 commit b3edb25
Showing 3 changed files with 35 additions and 31 deletions.
21 changes: 11 additions & 10 deletions src/compressed_tensors/quantization/observers/helpers.py
@@ -33,20 +33,21 @@ def calculate_qparams(
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
     """
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
 
     bit_range = 2**quantization_args.num_bits - 1
+    bit_min = -(bit_range + 1) / 2
+    bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        symmetric_range = 2 * max(min_vals.abs(), max_vals.abs())
-        scales = symmetric_range / bit_range
         zero_points = torch.tensor(0).to(torch.int8)
+        max_val_pos = torch.max(-min_vals, max_vals)
+        scales = max_val_pos / (float(bit_range) / 2)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
     else:
-        # non-symmetric
-        observed_range = max_vals - min_vals
-        scales = observed_range / bit_range
-
-        # scales from a 0 range should be set to 1
-        scales[observed_range == 0] = 1
-
-        zero_points = torch.round(((0.0 - min_vals) / scales + bit_min)).to(torch.int8)
+        scales = (max_vals - min_vals) / float(bit_range)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = bit_min - torch.round(min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
 
     return scales, zero_points
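For num_bits=8 the new constants work out to bit_range = 255, bit_min = -128.0, and bit_max = 127.0. Below is a minimal standalone sketch of the patched scale/zero-point math with illustrative observed values (the inputs are assumptions for demonstration, not part of the diff):

import torch

num_bits = 8
bit_range = 2**num_bits - 1     # 255
bit_min = -(bit_range + 1) / 2  # -128.0
bit_max = bit_min + bit_range   # 127.0

min_vals = torch.tensor([-0.4])  # illustrative observed minimum
max_vals = torch.tensor([1.2])   # illustrative observed maximum

# symmetric branch: scale from the largest absolute value, zero point stays 0
max_val_pos = torch.max(-min_vals, max_vals)                  # 1.2
sym_scales = max_val_pos / (float(bit_range) / 2)             # ~0.00941
sym_scales = torch.clamp(sym_scales, min=torch.finfo(torch.float32).eps)

# asymmetric branch: scale spans the full observed range
scales = (max_vals - min_vals) / float(bit_range)             # ~0.00627
scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
zero_points = bit_min - torch.round(min_vals / scales)        # -64.0
zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)

Clamping scales to float32 eps replaces the old scales[observed_range == 0] = 1 special case: a zero observed range now yields a tiny positive scale instead of a scale of 1.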
7 changes: 4 additions & 3 deletions src/compressed_tensors/quantization/observers/memoryless.py
@@ -27,18 +27,19 @@
 class MemorylessObserver(Observer):
     """
     Implements a dynamic quantization observer that sets the scale and
-    zero point based on the latest observed value
+    zero point based on the latest observed value without tracking state
     """
 
     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
+        Returns the min and max values of observed
+
         :param observed: observed tensor to calculate quantization parameters for
         :return: tuple of scale and zero point derived from the observed tensor
         """
         # TODO: Add support for full range of quantization Args, only supports 8bit
         # per tensor
-        min_val = observed.min()
-        max_val = observed.max()
+        min_val, max_val = torch.aminmax(observed)
 
         # ensure zero is in the range
         min_val = torch.min(min_val, torch.zeros_like(min_val))
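The torch.aminmax call returns the minimum and maximum in a single pass over the tensor instead of two separate reductions. A small sketch of the equivalence, with illustrative values:

import torch

x = torch.tensor([0.5, -1.0, 2.0, 0.0])

# one reduction producing both statistics
min_val, max_val = torch.aminmax(x)
assert min_val == x.min() and max_val == x.max()  # tensor(-1.), tensor(2.)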
38 changes: 20 additions & 18 deletions src/compressed_tensors/quantization/observers/min_max.py
@@ -21,43 +21,45 @@
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["MinMaxObserver"]
+__all__ = ["MovingAverageMinMaxObserver"]
 
 
 @Observer.register("minmax")
-class MinMaxObserver(Observer):
+class MovingAverageMinMaxObserver(Observer):
     """
     Implements a dynamic quantization observer that sets the scale and
-    zero point based on the overall min and max value
+    zero point based on a moving average of the overall min and max observed values
     """
 
-    def __init__(self, quantization_args: QuantizationArgs):
+    def __init__(
+        self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
+    ):
         super().__init__(quantization_args=quantization_args)
 
         self.min_val = float("inf")
         self.max_val = -float("inf")
-        self.counter = 0
+        self.averaging_constant = averaging_constant
 
     def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
+        Updates the observed min and max using a moving average smoothed by the
+        averaging_constant
+
         :param observed: observed tensor to calculate quantization parameters for
         :return: tuple of scale and zero point derived from the observed tensor
         """
 
-        min_val = torch.tensor([observed.min()])
-        max_val = torch.tensor([observed.max()])
+        min_val, max_val = torch.aminmax(observed)
 
-        # update global min and max
-        if self.counter > 0:
-            self.min_val = torch.min(min_val, self.min_val)
-            self.max_val = torch.max(max_val, self.max_val)
-        else:
+        if self.min_val == float("inf") and self.max_val == float("-inf"):
             self.min_val = min_val
             self.max_val = max_val
+        else:
+            self.min_val = self.min_val + self.averaging_constant * (
+                min_val - self.min_val
+            )
+            self.max_val = self.max_val + self.averaging_constant * (
+                max_val - self.max_val
+            )
 
-        # ensure that the zeros are in the range
-        min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
-        max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
-
-        self.counter += 1
-        return calculate_qparams(min_val, max_val, self.quantization_args)
+        return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
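The replacement update rule is a standard exponential moving average: each batch pulls the running statistic a fraction averaging_constant of the way toward the newest observation, and the inf/-inf initial values let the first batch seed the statistics directly. A minimal sketch of one update step, with an illustrative running value:

import torch

averaging_constant = 0.01
running_max = torch.tensor(1.0)   # illustrative current running maximum

observed_max = torch.tensor(3.0)  # new batch maximum
# equivalent to (1 - c) * old + c * new
running_max = running_max + averaging_constant * (observed_max - running_max)
print(running_max)  # tensor(1.0200)

Compared with the old global min/max tracking, an outlier batch now shifts the quantization range only slightly instead of permanently widening it.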
