From f4b62bf803e925d9430a94b63af8e0069873e985 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Sat, 21 Dec 2024 17:20:37 -0800 Subject: [PATCH] quant: update tpu_int8 to use AphroditeParameters (#959) --- aphrodite/modeling/layers/linear.py | 2 +- aphrodite/quantization/tpu_int8.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/aphrodite/modeling/layers/linear.py b/aphrodite/modeling/layers/linear.py index 281c1a03f..9790f9848 100644 --- a/aphrodite/modeling/layers/linear.py +++ b/aphrodite/modeling/layers/linear.py @@ -27,7 +27,7 @@ "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod", "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", - "GPTQMarlin24LinearMethod", + "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod" ] diff --git a/aphrodite/quantization/tpu_int8.py b/aphrodite/quantization/tpu_int8.py index f30590c3b..2271f0b67 100644 --- a/aphrodite/quantization/tpu_int8.py +++ b/aphrodite/quantization/tpu_int8.py @@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase -from aphrodite.modeling.utils import set_weight_attrs +from aphrodite.modeling.parameter import ModelWeightParameter from aphrodite.quantization.base_config import QuantizationConfig ACTIVATION_SCHEMES = ["none"] @@ -63,16 +63,16 @@ def create_weights(self, layer: Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + + weight_loader = extra_weight_attrs.get("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }) def _quantize_weight( self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -91,6 +91,7 @@ def _quantize_weight( return qweight, qscale def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = Parameter(layer.weight.data, requires_grad=False) device = layer.weight.device qweight, qscale = self._quantize_weight(layer.weight) qweight = qweight.to(device)