Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Observers] group size + channel wise + per token #32

Merged
merged 25 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 79 additions & 8 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# limitations under the License.

from functools import wraps
from math import ceil

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module
Expand All @@ -32,10 +36,9 @@ def quantize(
q_min: torch.Tensor,
q_max: torch.Tensor,
) -> torch.Tensor:

return torch.clamp(
torch.round(
x / scale + zero_point,
),
torch.round(x / scale + zero_point),
q_min,
q_max,
)
Expand All @@ -57,12 +60,81 @@ def fake_quantize(
zero_point: torch.Tensor,
args: QuantizationArgs,
) -> torch.Tensor:
"""
Fake quantize the input tensor x depending on the group_size.
if group_size is greater than 0, then q/dq by groups. The groups
must be divisible by the column size
if group_size is -1, then channel wise q/dq. THe input scale and
zero_points are reshaped to support vectorization (Assumes 1 is
the channel dimension)

:param x: Input tensor
:param scale: scale tensor
:param zero_point: zero point tensor
:param args: quantization args that contain group_size info
:return: fake quantized tensor

"""
bit_range = 2**args.num_bits
max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
min_q = torch.tensor(-bit_range / 2, device=x.device)
Q = torch.zeros_like(x)
Q = quantize(x, scale, zero_point, min_q, max_q)
return dequantize(Q, scale, zero_point)

group_size = args.group_size

# group
if args.strategy == QuantizationStrategy.GROUP:

DQ = torch.zeros_like(x)

# TODO: vectorize the for loop
# TODO: fix genetric assumption about the tensor size for computing group

horheynm marked this conversation as resolved.
Show resolved Hide resolved
# TODO: make validation step for inputs

while scale.ndim < 2:
scale = scale.unsqueeze(1)
horheynm marked this conversation as resolved.
Show resolved Hide resolved
zero_point = zero_point.unsqueeze(1)

columns = x.shape[1]
if columns >= group_size:
assert columns % group_size == 0
horheynm marked this conversation as resolved.
Show resolved Hide resolved
for i in range(ceil(columns / group_size)):
horheynm marked this conversation as resolved.
Show resolved Hide resolved

# scale.shape should be [nchan, ndim]
# sc.shape should be [nchan, 1] after unsqueeze
sc = scale[:, i].unsqueeze(1)
zp = zero_point[:, i].unsqueeze(1)
horheynm marked this conversation as resolved.
Show resolved Hide resolved

idx = i * group_size
Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
horheynm marked this conversation as resolved.
Show resolved Hide resolved

# channel-wise
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
# before: scale shape = [channel_size]
# after: scale shape = [1, channel_size]
scale = scale.unsqueeze(0)
zero_point = zero_point.unsqueeze(0)

Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [channel_size]
# after: scale shape = [channel_size, 1]
horheynm marked this conversation as resolved.
Show resolved Hide resolved

scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

else:
Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

return DQ


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
Expand Down Expand Up @@ -138,5 +210,4 @@ def _maybe_calibrate_or_quantize(
device = next(module.parameters()).device
scale.data = updated_scale.to(device)
zero_point.data = updated_zero_point.to(device)

return fake_quantize(value, scale, zero_point, args)
67 changes: 64 additions & 3 deletions src/compressed_tensors/quantization/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

from typing import Optional, Tuple

from compressed_tensors.quantization.quant_args import QuantizationArgs
import torch
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.registry.registry import RegistryMixin
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module
Expand Down Expand Up @@ -52,6 +56,12 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
"""
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

def post_calculate_qparams(self) -> None:
"""
Run any logic specific to its observers after running calculate_qparams
"""
...

def get_qparams(
self, observed: Optional[Tensor] = None
) -> Tuple[FloatTensor, IntTensor]:
Expand All @@ -64,6 +74,57 @@ def get_qparams(
:return: tuple of scale and zero point based on last observed value
"""
if observed is not None:
# re-calcualte scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)
group_size = self.quantization_args.group_size

if self.quantization_args.strategy == QuantizationStrategy.TENSOR:

# re-calculate scale and zero point, update the stored value
self._scale, self._zero_point = self.calculate_qparams(observed)

elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
columns = observed.shape[1]
scales, zero_points = [], []
for i in range(0, columns, self.quantization_args.group_size):
scale, zero_point = self.get_qparams_along_dim(
observed[:, i : (i + group_size)],
0,
)
scales.append(scale)
zero_points.append(zero_point)

self._scale = torch.stack(scales, dim=1)
self._zero_point = torch.stack(zero_points, dim=1)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed, because its the output, hence use dim 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:

# use dim 1, assume the obsersed.shape = [batch, token, hidden]
# should be batch, token

self._scale, self._zero_point = self.get_qparams_along_dim(
observed, dim=1
)

return self._scale, self._zero_point

def get_qparams_along_dim(self, observed, dim: int):
# TODO: add documentation that specifies the shape must
# be padded with 1-dims so the scales are along the right channel
# TODO: generalize the logic for reduce_dims
scales, zero_points = [], []

# TODO: make a more generic way to get the channel
num_dims = observed.shape[dim]

for dim_idx in range(num_dims):
scale, zero_point = self.calculate_qparams(
observed.select(dim=dim, index=dim_idx)
)

scales.append(scale)
zero_points.append(zero_point)
# breakpoint()
return torch.stack(scales), torch.stack(zero_points)
34 changes: 32 additions & 2 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator


__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
Expand All @@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
TOKEN = "token"


class QuantizationArgs(BaseModel):
Expand All @@ -63,7 +64,7 @@ class QuantizationArgs(BaseModel):
num_bits: int = 8
type: QuantizationType = QuantizationType.INT
symmetric: bool = True
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
strategy: Optional[QuantizationStrategy] = None
group_size: Optional[int] = None
block_structure: Optional[str] = None
dynamic: bool = False
Expand Down Expand Up @@ -94,3 +95,32 @@ def get_observer(self):
self.observer = "memoryless"

return Observer.load_from_registry(self.observer, quantization_args=self)

@validator("strategy", pre=True)
def validate_strategy(cls, value, values):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
group_size = values.get("group_size")

# use group_size to determinine strategy if not given explicity
if group_size is not None and value is None:
if group_size > 0:
return QuantizationStrategy.GROUP

elif group_size == -1:
return QuantizationStrategy.CHANNEL

else:
raise ValueError(
f"group_size={group_size} with strategy {value} is invald. "
"group_size > 0 for strategy='group' and "
"group_size = -1 for 'channel'"
)
# breakpoint()
group_size = 128
if value == QuantizationStrategy.GROUP:
if group_size is None:
raise ValueError(f"strategy {value} requires group_size to be set.")

if value is None:
return QuantizationStrategy.TENSOR

return value
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def calculate_compression_ratio(model: Module) -> float:
compressed_bits = uncompressed_bits
if is_module_quantized(submodule):
compressed_bits = submodule.quantization_scheme.weights.num_bits

num_weights = parameter.numel()
total_compressed += compressed_bits * num_weights
total_uncompressed += uncompressed_bits * num_weights
Expand Down
Loading