Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Observers] group size + channel wise + per token #32

Merged
merged 25 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from math import ceil

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module
Expand Down Expand Up @@ -77,12 +80,8 @@ def fake_quantize(

group_size = args.group_size

if group_size is None or group_size == 0:
Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

# group
elif group_size > 0:
if args.strategy == QuantizationStrategy.GROUP:

DQ = torch.zeros_like(x)

Expand All @@ -101,15 +100,32 @@ def fake_quantize(
DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)

# channel-wise
else: # group_size == -1
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
# before: scale shape = [channel_size]
# after: scale shape = [1, channel_size]

breakpoint()
horheynm marked this conversation as resolved.
Show resolved Hide resolved

scale = scale.unsqueeze(0)
zero_point = zero_point.unsqueeze(0)

Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [channel_size]
# after: scale shape = [channel_size, 1]
horheynm marked this conversation as resolved.
Show resolved Hide resolved
scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

else:
Q = quantize(x, scale, zero_point, min_q, max_q)
DQ = dequantize(Q, scale, zero_point)

return DQ


Expand Down
21 changes: 10 additions & 11 deletions src/compressed_tensors/quantization/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,31 +97,30 @@ def get_qparams(
elif (
self.quantization_args.strategy == QuantizationStrategy.CHANNEL
): # channel-wise quantization

# TODO: make a generic way to get the channel
channel = 1
self._scale, self._zero_point = self.get_qparams_per_channel(
observed, channel
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 1)
elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
dims = observed.ndim
self._scale, self._zero_point = self.get_qparams_along_dim(
observed, dim=dims - 2
horheynm marked this conversation as resolved.
Show resolved Hide resolved
)

self.post_calculate_qparams()
return self._scale, self._zero_point

def get_qparams_per_channel(self, observed, channel: int):
def get_qparams_along_dim(self, observed, dim: int):
# TODO: add documentation that specifies the shape must
# be padded with 1-dims so the scales are along the right channel
# TODO: generalize the logic for reduce_dims
scales, zero_points = [], []

# TODO: make a more generic way to get the channel
num_channels = observed.shape[channel]
num_dims = observed.shape[dim]

for channel_idx in range(num_channels):
for dim_idx in range(num_dims):
scale, zero_point = self.calculate_qparams(
observed.select(dim=channel, index=channel_idx)
observed.select(dim=dim, index=dim_idx)
)

scales.append(scale)
zero_points.append(zero_point)

return torch.cat(scales), torch.cat(zero_points)
return torch.stack(scales), torch.stack(zero_points)
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/observers/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
self.max_val = -float("inf")
self.averaging_constant = averaging_constant


def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
"""
Updates the observed min and max using a moving average smoothed by the
Expand Down
23 changes: 11 additions & 12 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
TOKEN = "token"


class QuantizationArgs(BaseModel):
Expand All @@ -63,7 +64,7 @@ class QuantizationArgs(BaseModel):
num_bits: int = 8
type: QuantizationType = QuantizationType.INT
symmetric: bool = True
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
strategy: Optional[QuantizationStrategy] = None
group_size: Optional[int] = None
block_structure: Optional[str] = None
dynamic: bool = False
Expand Down Expand Up @@ -98,21 +99,13 @@ def get_observer(self):
@validator("strategy", pre=True)
def validate_strategy(cls, value, values):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
group_size = values.get("group_size")
if group_size is not None:

# use group_size to determine strategy if not given explicitly
if group_size is not None and value is None:
if group_size > 0:
if value != QuantizationStrategy.GROUP:
raise ValueError(
f"group_size={group_size} with strategy {value} is invalid. "
"Please set strategy to 'group'"
)
return QuantizationStrategy.GROUP

elif group_size == -1:
if value != QuantizationStrategy.CHANNEL:
raise ValueError(
f"group_size={group_size} with strategy {value} is invalid. "
"Please set strategy to 'channel'"
)
return QuantizationStrategy.CHANNEL

else:
Expand All @@ -121,5 +114,11 @@ def validate_strategy(cls, value, values):
"group_size > 0 for strategy='group' and "
"group_size = -1 for 'channel'"
)
if value == QuantizationStrategy.GROUP:
if group_size is None:
raise ValueError(f"strategy {value} requires group_size to be set.")
horheynm marked this conversation as resolved.
Show resolved Hide resolved

if value is None:
return QuantizationStrategy.TENSOR

return value
Loading