From aaf3cf460481dc54d4d4ab607d5d2b795cef87bc Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 21 Oct 2024 15:27:04 -0700 Subject: [PATCH] Subclass API (#995) Summary: Adds new int8_dynamic_activation_intx_weight quantization with subclass API Differential Revision: D62464487 --- ...8bit_act_xbit_weight_subclass_quantizer.py | 397 ++++++++++++++++++ torchao/experimental/docs/readme.md | 54 ++- torchao/experimental/quant_api.py | 53 +++ ...8bit_act_xbit_weight_subclass_quantizer.py | 180 ++++++++ torchao/quantization/quant_primitives.py | 28 +- 5 files changed, 698 insertions(+), 14 deletions(-) create mode 100644 torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py create mode 100644 torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py b/torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py new file mode 100644 index 000000000..9cd5564a7 --- /dev/null +++ b/torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from enum import auto, Enum + +import logging +from typing import List, Optional, Tuple + +import torch +from torch.ao.quantization.fx._decomposed import ( + dequantize_per_channel_group, + quantize_per_channel_group, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.affine_quantized_tensor import ( + AQTTensorImpl, + register_aqt_quantized_linear_dispatch, + register_layout, +) +from torchao.dtypes.utils import Layout +from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + ZeroPointDomain, +) +from torchao.utils import TorchAOBaseTensor + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +import sys + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +class Target(Enum): + """Enum that indicates the backend target + """ + NATIVE = auto() + FALLBACK = auto() + +def target_from_str(target: str) -> Target: + if target.lower() == "native": + return Target.NATIVE + elif target.lower() == "fallback": + return Target.FALLBACK + else: + raise ValueError(f"Invalid target: {target}") + + +# This format is intended for use with int8 dynamic quantization +class Linear8BitActXBitWeightLayout(Layout): + nbit: int + group_size: int + + # The target platform for the layout, either 'native' or 'fallback'. + target: Target + + def __init__( + self, + nbit: int, + group_size: int, + target: str, + ): + assert nbit <= 7 + self.nbit = nbit + self.group_size = group_size + self.target = target_from_str(target) + + def extra_repr(self): + return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}" + + +def _pack_weights_native( + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout: Layout, +): + assert isinstance(layout, Linear8BitActXBitWeightLayout) + assert layout.target == Target.NATIVE + nbit = layout.nbit + group_size = layout.group_size + has_weight_zeros = zero_point is not None + + if has_weight_zeros: + args = [ + int_data.to(torch.int8), + scale.reshape(-1), + zero_point.reshape(-1).to(torch.int8), + torch.empty(0, group_size, dtype=torch.int8), + ] + else: + args = [ + int_data.to(torch.int8), + scale.reshape(-1), + torch.empty(0, group_size, dtype=torch.int8), + ] + + wzp_suffix = "" if has_weight_zeros else "0zp" + return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")( + *args + ) + + +@register_layout(Linear8BitActXBitWeightLayout) +class Linear8BitActXBitWeightAQTTensorImpl(AQTTensorImpl): + def __new__( + cls, + packed_weight: torch.Tensor, + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["dtype"] = packed_weight.dtype + assert not packed_weight.requires_grad + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Linear8BitActXBitWeightLayout) + + # In the native case, scale and zero_point information is inside + # the packed_weight + if _layout.target == Target.NATIVE: + assert scale is None + assert zero_point is None + + self.packed_weight = packed_weight + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __repr__(self): + layout = self.get_layout() + return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout={layout})" + + def get_layout(self) -> Layout: + return self._layout + + def get_plain(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if self.get_layout().target == Target.FALLBACK: + return self.packed_weight, self.scale, self.zero_point + raise NotImplementedError("get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback") + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout: Layout, + ): + assert isinstance(layout, Linear8BitActXBitWeightLayout) + + try: + if layout.target == Target.NATIVE: + packed_weight = _pack_weights_native( + int_data, scale, zero_point, layout + ) + scale = None + zero_point = None + return cls(packed_weight, scale, zero_point, layout) + except Exception as e: + logger.warning( + f"A failure occurred when packing weights with Linear8BitActXBitWeightLayout.target={layout.target}: {e}\n" + + "Falling back to **slow** implementation Linear8BitActXBitWeightLayout.target=fallback." + ) + layout.target = Target.FALLBACK + + # Fallback + assert layout.target == Target.FALLBACK + packed_weight = int_data.to(torch.int8) + return cls(packed_weight, scale, zero_point, layout) + + def _apply_fn_to_data(self, fn): + self.packed_weight = fn(self.packed_weight) + if self.scale is not None: + self.scale = fn(self.scale) + + if self.zero_point is not None: + self.zero_point = fn(self.zero_point) + return self + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is torch.ops.aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is torch.ops.aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + raise NotImplementedError( + f"Linear8BitActXBitWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + if self.get_layout().target == Target.NATIVE: + return ["packed_weight"], [self.get_layout()] + + # fallback + assert self.get_layout().target == Target.FALLBACK + if self.zero_point is None: + return ["packed_weight", "scale"], [self.get_layout()] + return ["packed_weight", "scale", "zero_point"], [self.get_layout()] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale, zero_point = ( + tensor_data_dict["packed_weight"], + tensor_data_dict.get("scale", None), + tensor_data_dict.get("zero_point", None), + ) + (layout,) = tensor_attributes + return cls(packed_weight, scale, zero_point, layout) + + +def _linear_int8_dynamic_activation_intx_weight_check( + input_tensor, weight_tensor, bias +): + layout = weight_tensor.tensor_impl.get_layout() + return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None + + +def _linear_int8_dynamic_activation_intx_weight_fallback_impl( + input_tensor, weight_tensor, bias +): + assert weight_tensor.tensor_impl.get_layout().target == Target.FALLBACK + assert bias is None + + def _impl_2d(input_tensor, weight_tensor): + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + + weight_qvals = weight_tensor.tensor_impl.packed_weight.to(torch.int32) + weight_scales = weight_tensor.tensor_impl.scale + weight_zeros = weight_tensor.tensor_impl.zero_point + group_size = weight_tensor.tensor_impl.get_layout().group_size + has_weight_zeros = weight_zeros is not None + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + + weights_dequantized = weight_tensor.dequantize() + + # Quantize activations + activation_scales, activation_zeros = choose_qparams_affine( + input=input_tensor, + mapping_type=MappingType.ASYMMETRIC, + block_size=(1, k), + target_dtype=torch.int32, + quant_min=-128, + quant_max=127, + eps=0.0, + scale_dtype=torch.float32, + zero_point_dtype=torch.int32, + preserve_zero=True, + zero_point_domain=ZeroPointDomain.INT, + ) + activation_qvals = quantize_per_channel_group( + input=input_tensor, + scales=activation_scales, + zero_points=activation_zeros, + quant_min=-128, + quant_max=127, + dtype=torch.int8, + group_size=k, + ) + activations_dequantized = dequantize_per_channel_group( + w_int8=activation_qvals, + scales=activation_scales, + zero_points=activation_zeros, + quant_min=None, # TODO: why is this an arg for this function + quant_max=None, # TODO: why is this an arg for this function + dtype=None, # TODO: why is this an arg for this function + group_size=k, + output_dtype=torch.float32, + ) + + return torch.matmul( + activations_dequantized, weights_dequantized.transpose(1, 0) + ) + + if input_tensor.dim() == 2: + return _impl_2d(input_tensor, weight_tensor) + + assert input_tensor.dim() >= 3 + lead_shape = input_tensor.shape[0:-2] + m, k = input_tensor.shape[-2], input_tensor.shape[-1] + n, k_ = weight_tensor.shape + assert k_ == k + + res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) + res = res.reshape(*lead_shape, m, n) + + return res + + +def _linear_int8_dynamic_activation_intx_weight_native_impl( + input_tensor, weight_tensor, bias +): + assert weight_tensor.tensor_impl.get_layout().target == Target.NATIVE + assert bias is None + + def _impl_2d(input_tensor, weight_tensor): + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + group_size = weight_tensor.tensor_impl.get_layout().group_size + packed_weight = weight_tensor.tensor_impl.packed_weight + + # TODO(T200095131): convert self.n, self.k, self.group_size to + # int when supported by AOTI + args = ( + input_tensor, + packed_weight, + torch.empty(0, group_size, dtype=torch.int8), + torch.empty(0, n, dtype=torch.int8), + torch.empty(0, k, dtype=torch.int8), + ) + + has_weight_zeros = (weight_tensor.zero_point_domain != ZeroPointDomain.ZERO) + + assert len(weight_tensor.block_size) == 2 + assert weight_tensor.block_size[0] == 1 + group_size = weight_tensor.block_size[1] + assert group_size == weight_tensor.tensor_impl.get_layout().group_size + nbit = weight_tensor.tensor_impl.get_layout().nbit + + n, k = weight_tensor.shape + m, k_ = input_tensor.shape + assert k_ == k + + packed_weight = weight_tensor.tensor_impl.packed_weight + wzp_suffix = "" if has_weight_zeros else "0zp" + return getattr( + torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight" + )(*args) + + if input_tensor.dim() == 2: + return _impl_2d(input_tensor, weight_tensor) + + assert input_tensor.dim() >= 3 + lead_shape = input_tensor.shape[0:-2] + m, k = input_tensor.shape[-2], input_tensor.shape[-1] + n, k_ = weight_tensor.shape + assert k_ == k + + res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) + res = res.reshape(*lead_shape, m, n) + return res + + +def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias): + target = weight_tensor.tensor_impl.get_layout().target + if target == Target.NATIVE: + return _linear_int8_dynamic_activation_intx_weight_native_impl( + input_tensor, weight_tensor, bias + ) + + if target == Target.FALLBACK: + return _linear_int8_dynamic_activation_intx_weight_fallback_impl( + input_tensor, weight_tensor, bias + ) + + assert False, f"Unknown target {target}" + + +register_aqt_quantized_linear_dispatch( + _linear_int8_dynamic_activation_intx_weight_check, + _linear_int8_dynamic_activation_intx_weight_impl, +) diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md index e7e3cddf0..53f6e3ddf 100644 --- a/torchao/experimental/docs/readme.md +++ b/torchao/experimental/docs/readme.md @@ -1,21 +1,29 @@ # TorchAO experimental -TorchAO experimental contains lowbit ARM CPU and Metal kernels for linear and embedding ops. +TorchAO experimental contains lowbit ARM CPU and Metal kernels for linear and +embedding ops. ## Building ARM CPU kernels -To build torch ops that use the lowbit kernels, run `sh build_torchao_ops.sh ` from torchao/experimental. +To build torch ops that use the lowbit kernels, run +`sh build_torchao_ops.sh ` from torchao/experimental. -For example, to build ATen ops, run `sh build_torchao_ops.sh aten` (this requires PyTorch). Similarly, to build the ExecuTorch ops, run `sh build_torchao_ops executorch` (this requires ExecuTorch). +For example, to build ATen ops, run `sh build_torchao_ops.sh aten` (this +requires PyTorch). Similarly, to build the ExecuTorch ops, run +`sh build_torchao_ops executorch` (this requires ExecuTorch). After running the script, the op libraries will be in + ``` cmake-out/lib/libtorchao_ops_aten.{dylib|so} # ATen op library cmake-out/lib/libtorchao_ops_executorch.a # ExecuTorch op library ``` ## Quantizing models -Once the ATen ops are built, you can quantize PyTorch models with them. The quantized models can be run in eager model, compiled, used with AOTI, or exported. The exported models can be lowered to ExecuTorch. + +Once the ATen ops are built, you can quantize PyTorch models with them. The +quantized models can be run in eager model, compiled, used with AOTI, or +exported. The exported models can be lowered to ExecuTorch. ```python import torch @@ -43,8 +51,42 @@ linear_quantizer = Int8DynActIntxWeightLinearQuantizer( quantized_model = linear_quantizer.quantize(quantized_model) ``` -If you get stuck on the above steps, working examples for both linear and embedding are in torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py and torchao/experimental/tests/test_embedding_xbit_quantizer.py. For example, running `python tests/test_linear_8bit_act_xbit_weight_quantizer.py` loads the ops, creates a toy model, quantizes the model, and runs it in eager, compile, AOTI, and exports the model. +If you get stuck on the above steps, working examples for both linear and +embedding are in +torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py and +torchao/experimental/tests/test_embedding_xbit_quantizer.py. For example, +running `python tests/test_linear_8bit_act_xbit_weight_quantizer.py` loads the +ops, creates a toy model, quantizes the model, and runs it in eager, compile, +AOTI, and exports the model. + +### Subclass API + +For linear, you can also use the new subclass API in torchao. + +```python +import torch +torch.ops.load_library("cmake-out/lib/libtorchao_ops_aten.dylib") # make sure this path is correct on your machine + +my_model = Model() + +from torchao.experimental.quant_api import int8_dyn_act_intx_weight +from torchao.quantization.quant_api import quantize_ +quantize_( + my_model, + int8_dyn_act_intx_weight( + group_size=256, + nbit=4, + has_weight_zeros=False, + ), +) +``` + +If you get stuck, consult +`tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py`. ## Available in torchchat -TorchAO experimental kernels are [available in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels), PyTorch's solution for running LLMs locally. Torchchat integration uses similar steps to above. +TorchAO experimental kernels are +[available in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels), +PyTorch's solution for running LLMs locally. Torchchat integration uses similar +steps to above. diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index e22c97e05..cb2dc9b79 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -317,3 +317,56 @@ def quantize(self, model: nn.Module) -> nn.Module: }, ) return model + + +from torchao.experimental._linear_8bit_act_xbit_weight_subclass_quantizer import Linear8BitActXBitWeightLayout +from torchao.quantization.quant_api import ( + _get_linear_subclass_inserter, + MappingType, + to_affine_quantized_intx, + ZeroPointDomain, +) + + +def int8_dyn_act_intx_weight( + group_size: int = 128, + nbit: int = 4, + has_weight_zeros: bool = False, + target: str = "native", +): + + def apply(weight): + assert weight.shape[-1] % group_size == 0 + assert weight.device == torch.device("cpu"), "Only CPU is supported" + use_hqq = False + layout = Linear8BitActXBitWeightLayout( + nbit=nbit, group_size=group_size, target=target + ) + mapping_type = MappingType.ASYMMETRIC + eps = torch.finfo(torch.float32).eps + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = -(1 << (nbit - 1)) + quant_max = (1 << (nbit - 1)) - 1 + zero_point_dtype = torch.int8 + preserve_zero = has_weight_zeros + zero_point_domain = ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.ZERO + # Note: this works differently than other quantizers because the dynamic + # activation quantization is fused with the kernel/op (and static activation quantization + # is not supported). + return to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + _layout=layout, + use_hqq=use_hqq, + ) + + return _get_linear_subclass_inserter(apply) diff --git a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py new file mode 100644 index 000000000..7e337a49c --- /dev/null +++ b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import glob +import os +import subprocess + +import sys +import tempfile +import unittest + +import torch + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) + +from torchao.experimental.quant_api import int8_dyn_act_intx_weight +from torchao.quantization.quant_api import quantize_ + +from torchao.utils import unwrap_tensor_subclass +from torchao.experimental.quant_api import ( + _Int8DynActIntxWeightQuantizedLinearFallback, +) + +def cmake_build_torchao_ops(temp_build_dir): + from distutils.sysconfig import get_python_lib + + print("Building torchao ops for ATen target") + cmake_prefix_path = get_python_lib() + dir_path = os.path.dirname(os.path.realpath(__file__)) + subprocess.run( + [ + "cmake", + "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, + "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name, + "-S " + dir_path + "/../", + "-B " + temp_build_dir.name, + ] + ) + subprocess.run( + [ + "cmake", + "--build", + temp_build_dir.name, + "-j 16", + "--target install", + "--config Release", + ] + ) + + +temp_build_dir = tempfile.TemporaryDirectory() +cmake_build_torchao_ops(temp_build_dir) +libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*") +libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) +assert len(libs) == 1 +torch.ops.load_library(libs[0]) + + +class TestInt8DynamicActivationIntxWeight(unittest.TestCase): + def test_accuracy(self): + group_size = 128 + m = 1 + n = 1071 + k = 4096 + activations = torch.randn(m, k, dtype=torch.float32) + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + + for nbit in [1, 2, 3, 4, 5, 6, 7]: + for has_weight_zeros in [True, False]: + print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}") + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + int8_dyn_act_intx_weight( + group_size=group_size, + nbit=nbit, + has_weight_zeros=has_weight_zeros, + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + int8_dyn_act_intx_weight( + group_size=group_size, + nbit=nbit, + has_weight_zeros=has_weight_zeros, + target="fallback", + ), + ) + + with torch.no_grad(): + result = quantized_model(activations) + expected_result = quantized_model_reference(activations) + + #TODO: remove expected_result2 checks when we deprecate non-subclass API + reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback() + reference_impl.quantize_and_pack_weights( + model[0].weight, nbit, group_size, has_weight_zeros + ) + expected_result2 = reference_impl(activations) + + num_mismatch_at_low_tol = 0 + num_mismatch_at_low_tol2 = 0 + num_total = result.reshape(-1).shape[0] + for i in range(num_total): + actual_val = result.reshape(-1)[i] + expected_val = expected_result.reshape(-1)[i] + expected_val2 = expected_result2.reshape(-1)[i] + self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) + if not torch.allclose(actual_val, expected_val): + num_mismatch_at_low_tol += 1 + + self.assertTrue(torch.allclose(expected_val, expected_val2, atol=1e-2, rtol=1e-1)) + if not torch.allclose(expected_val, expected_val2): + num_mismatch_at_low_tol2 += 1 + + # Assert at most 5% of entries are not close at a low tolerance + self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) + self.assertTrue(num_mismatch_at_low_tol2 / num_total <= 0.01) + + def test_export_compile_aoti(self): + group_size = 32 + m = 3 + k0 = 512 + k1 = 256 + k2 = 128 + k3 = 1024 + nbit = 4 + has_weight_zeros = True + layers = [ + torch.nn.Linear(k0, k1, bias=False), + torch.nn.Linear(k1, k2, bias=False), + torch.nn.Linear(k2, k3, bias=False), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn(2, 1, m, k0, dtype=torch.float32) + + print("Quantizing model") + quantize_( + model, + int8_dyn_act_intx_weight( + group_size=group_size, + nbit=nbit, + has_weight_zeros=has_weight_zeros, + target="native", + ), + ) + + unwrapped_model = copy.deepcopy(model) + unwrap_tensor_subclass(model) + + print("Exporting quantized model") + exported = torch.export.export(model, (activations,)) + + print("Compiling quantized model") + compiled = torch.compile(unwrapped_model) + with torch.no_grad(): + compiled(activations) + + with tempfile.TemporaryDirectory() as tmpdirname: + print("Exporting quantized model with AOTI") + torch._export.aot_compile( + model, + (activations,), + options={"aot_inductor.output_path": f"{tmpdirname}/model.so"}, + ) + + print("Running quantized model in AOTI") + fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu") + fn(activations) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index dfd3bcaad..5a4bcf527 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -60,9 +60,11 @@ class ZeroPointDomain(Enum): integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + zero domain: quantized_val = (float_val / scale) """ INT = auto() FLOAT = auto() + ZERO = auto() class TorchAODType(Enum): """ @@ -344,6 +346,8 @@ def _quantize_affine_no_dtype_cast( quant = torch.clamp( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max ) + elif zero_point_domain == ZeroPointDomain.ZERO.name: + quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) elif zero_point_domain is None: # This case handles quantization for float8 we expect no zero point and no zero point domain assert zero_point is None, "zero_point should be None when zero_point_domain is None" @@ -477,6 +481,9 @@ def _dequantize_affine_no_dtype_check( dequant = dequant - zero_point.to(torch.int32) dequant = dequant.to(output_dtype) dequant = dequant * scale + elif zero_point_domain == ZeroPointDomain.ZERO.name: + dequant = input.to(output_dtype) + dequant = dequant * scale elif zero_point_domain is None: # This case handles dequantization for float8 we expect no zero point and no zero point domain assert zero_point is None, "zero_point should be None when zero_point_domain is None" @@ -813,15 +820,20 @@ def _choose_qparams_affine( assert mapping_type == MappingType.ASYMMETRIC.name scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) scale = torch.clamp(scale, min=eps) - if preserve_zero: - zero_point = quant_min - torch.round(min_val_neg / scale) - zero_point = torch.clamp(zero_point, quant_min, quant_max) + if zero_point_domain == ZeroPointDomain.ZERO.name: + zero_point = None else: - assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain" - mid_point = (quant_max + quant_min + 1) / 2 - zero_point = min_val_neg + scale * mid_point - - return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype) + if preserve_zero: + zero_point = quant_min - torch.round(min_val_neg / scale) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + else: + assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain" + mid_point = (quant_max + quant_min + 1) / 2 + zero_point = min_val_neg + scale * mid_point + + if zero_point is not None: + zero_point = zero_point.to(dtype=zero_point_dtype) + return scale.to(dtype=scale_dtype), zero_point # HQQ