Replace WeightOnlyInt8Linear with TorchAO int8_weight_only quantization
vmpuri authored and committed Oct 24, 2024
1 parent 7fe2c86 commit d43d52e
Showing 1 changed file with 11 additions and 142 deletions.
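
In short: torchchat's hand-rolled int8 weight-only handler is dropped and the work is delegated to torchao's quantize_ API. A minimal before/after sketch of the change in usage (illustrative, not part of the diff; assumes a torchao build that exposes quantize_ and int8_weight_only):

import torch.nn as nn
from torchao.quantization.quant_api import int8_weight_only, quantize_

model = nn.Sequential(nn.Linear(64, 64))

# before this commit: torchchat's own handler swapped each nn.Linear for WeightOnlyInt8Linear
# model = WeightOnlyInt8QuantHandler(model, device="cpu").quantized_model()

# after this commit: torchao rewrites the Linear weights in place as int8 tensor subclasses
quantize_(model, int8_weight_only())
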
153 changes: 11 additions & 142 deletions torchchat/utils/quantize.py
@@ -36,6 +36,7 @@
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_weight_only,
    Int4WeightOnlyQuantizer,
    Int8DynActInt4WeightQuantizer,
    quantize_,
@@ -110,12 +111,20 @@ def quantize_model(
        if quantizer not in quantizer_class_dict:
            raise RuntimeError(f"unknown quantizer {quantizer} specified")
        else:
            ao_quant = True
            # Use tensor subclass API for int4 weight only.
            if device == "cuda" and quantizer == "linear:int4":
                quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
            elif quantizer == "linear:int8":
                print("quantizer is linear int8")
                quantize_(model, int8_weight_only())
            else:
                ao_quant = False
            if ao_quant:
                if not support_tensor_subclass:
                    unwrap_tensor_subclass(model)
                continue


        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
            # These quantizers require float32 input weights. Note that after quantization,
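
The new else branch above routes linear:int4 (on CUDA) and linear:int8 through torchao's quantize_ and then continues past the legacy handler machinery; unwrap_tensor_subclass is only applied when the caller cannot consume tensor-subclass weights (e.g. some export paths). A standalone sketch of what the linear:int8 branch amounts to (illustrative, outside the diff):

import torch
import torch.nn as nn
from torchao.quantization.quant_api import int8_weight_only, quantize_

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
quantize_(model, int8_weight_only())  # Linear weights become torchao int8 tensor subclasses, in place

with torch.no_grad():
    out = model(torch.randn(2, 128))  # inference still runs in the original activation dtype
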
@@ -529,147 +538,6 @@ def linear_int8_et(input, weight, scales):
)


class WeightOnlyInt8Linear(nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor
    scales: torch.Tensor

    def __init__(
        self,
        in_features,
        out_features,
        bias=None,
        device=None,
        dtype=None,
        *,
        weight: Optional[torch.Tensor] = None,
        scales: Optional[torch.Tensor] = None,
        groupsize: Optional[int] = None,
    ):
        super().__init__()
        if dtype is None:
            dtype = torch.get_default_dtype()

        if device is None:
            device = "cpu"

        assert not bias, "Bias is not supported by LinearInt8"
        self.in_features = in_features
        self.out_features = out_features

        assert (weight is None) == bool(
            scales is None
        ), "must specify both weights and scales, or neither"
        if weight is None:
            weight = torch.empty(
                (out_features, in_features),
                dtype=torch.int8,
                device=device,
            )
            if groupsize is None or (groupsize == 0):
                scales = torch.empty(out_features, dtype=dtype, device=device)
            else:
                n_groups = (in_features + groupsize - 1) // groupsize
                scales = torch.empty(out_features, n_groups, dtype=dtype, device=device)

        self.register_buffer("weight", weight.to(device))
        self.register_buffer("scales", scales.to(device))

        if use_et_backend():
            self.forward = self.et_forward
        else:
            self.forward = self.aoti_forward

    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_int8_aoti(input, self.weight, self.scales)

    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_int8_et(input, self.weight, self.scales)


class WeightOnlyInt8QuantHandler(QuantHandler):
    def __init__(
        self,
        model: Optional[nn.Module] = None,
        device = None,
        precision=None,
        tokenizer=None,
        *,
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        groupsize: Optional[int] = None,
    ):
        self.model_ = model
        self.device = device
        self.groupsize = groupsize
        self.node_type = node_type
        if bitwidth is None:
            self.bitwidth = 8
        else:
            self.bitwidth = bitwidth

    @torch.no_grad()
    def quantize(self, module):
        # cur_state_dict = state_dict_device(self.model_.state_dict())
        # dict_device = "cpu"  # self.device

        if self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for name, child in module.named_children():
            # print(f"name: {name}")
            if isinstance(child, nn.Linear):
                if (
                    (self.node_type == "*")
                    or (self.node_type == "output" and name == "output")
                    or (self.node_type == "!output" and name != "output")
                ):
                    # print(f"{name, child}")
                    input_weight = child.weight.float()
                    # print(f"{name, child}")
                    # print(f"in_features: {child.in_features}")
                    # print(f"out_features: {child.out_features}")

                    # print(f"expanded weight shape {input_weight.shape}")
                    weight, scales, _ = dynamically_quantize_per_channel(
                        input_weight,
                        range_min,
                        range_max,
                        torch.int8,
                        self.groupsize,
                        scales_dtype=child.weight.dtype,
                    )

                    setattr(
                        module,
                        name,
                        WeightOnlyInt8Linear(
                            in_features=child.in_features,
                            out_features=child.out_features,
                            device=self.device,
                            # update variables from quantization
                            weight=weight,
                            scales=scales,
                            groupsize=self.groupsize,
                        ),
                    )
            else:
                self.quantize(child)

        return module

    def quantized_model(self) -> nn.Module:
        return self.quantize(self.model_)
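
For context on what is being deleted: the handler above used dynamically_quantize_per_channel to turn every float weight matrix into int8 values plus one scale per output channel, and WeightOnlyInt8Linear applied those scales in its forward (via linear_int8_aoti / linear_int8_et). A rough sketch of that per-channel scheme (simplified symmetric quantization, not the exact torchao kernel):

import torch

def quantize_per_channel_sym(w: torch.Tensor, qmin: int = -128, qmax: int = 127):
    # one scale per output channel (row), chosen so the largest |w| in the row maps near qmax
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
    return q, scales.squeeze(1)

w = torch.randn(4, 16)
q, scales = quantize_per_channel_sym(w)
w_hat = q.to(torch.float32) * scales.unsqueeze(1)  # dequantized approximation of w
print((w - w_hat).abs().max())  # small quantization error
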


#########################################################################
##### embedding table quantization ######
### (unify with torchao in future) ###
@@ -886,10 +754,10 @@ def quantized_model(self) -> nn.Module:
# class references
quantizer_class_dict = {
"embedding": EmbeddingOnlyQuantHandler,
"linear:int8": WeightOnlyInt8QuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
"linear:int4": Int4WeightOnlyQuantizer,
"linear:int8": int8_weight_only,
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}
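
Note the table still needs a "linear:int8" entry (now pointing at torchao's int8_weight_only) only so the membership check at the top of quantize_model passes; with this commit the int8 path short-circuits through quantize_ before any handler from this table is instantiated. A trimmed mirror of that dispatch (illustrative, not the torchchat source):

import torch.nn as nn
from torchao.quantization.quant_api import int8_weight_only, quantize_

quantizer_class_dict = {"linear:int8": int8_weight_only}  # entry only keeps the lookup valid
quantize_options = {"linear:int8": {}}

model = nn.Sequential(nn.Linear(32, 32))
for quantizer, q_kwargs in quantize_options.items():
    if quantizer not in quantizer_class_dict:
        raise RuntimeError(f"unknown quantizer {quantizer} specified")
    quantize_(model, int8_weight_only())  # handled directly by torchao; dict entry never called
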

@@ -917,6 +785,7 @@ def quantized_model(self) -> nn.Module:
        IntxWeightEmbeddingQuantizer,
    )


    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
    quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer

