From 512eb758d552242ec24b4dca0237e95821df35bb Mon Sep 17 00:00:00 2001
From: "Jiang, Yanbing"
Date: Thu, 14 Nov 2024 10:06:03 +0000
Subject: [PATCH] Use _convert_weight_to_int4pack_for_cpu for int4 weight
 packing on CPU

On CPU, pack int4 weights with torch.ops.aten._convert_weight_to_int4pack_for_cpu
and skip the uint8 nibble-packing step used by the tinygemm CUDA path; other
devices keep the existing torch.ops.aten._convert_weight_to_int4pack path.

---
 test/quantization/test_quant_primitives.py   |  5 +++--
 torchao/prototype/hqq/hqq_tinygemm_linear.py | 15 +++++++++----
 torchao/quantization/GPTQ.py                 | 23 +++++++++++++++-----
 torchao/quantization/qat/linear.py           | 15 +++++++++----
 torchao/quantization/utils.py                |  2 +-
 5 files changed, 43 insertions(+), 17 deletions(-)

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 65ce9c71e..f163edeb4 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -526,8 +526,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            if not is_device(input.device.type, "cpu"):
+                input = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py
index 8abdad039..15eddda39 100644
--- a/torchao/prototype/hqq/hqq_tinygemm_linear.py
+++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.dtypes.utils import is_device
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta):
         W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
             W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
         )
-        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            W_q_torch, self.inner_k_tiles
-        )
+        if is_device(W_q.device.type, "cpu"):
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                W_q_torch, self.inner_k_tiles
+            )
+        else:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                W_q_torch, self.inner_k_tiles
+            )
         self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)
 
         del W_q_torch, scales_torch, zeros_torch
@@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants(
             .contiguous()
         )
         if TORCH_VERSION_AT_LEAST_2_5:
-            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+            if not is_device(W_q.device.type, "cpu"):
+                W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 470e71ae3..069086470 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -36,6 +36,7 @@
     pack_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
+from torchao.dtypes.utils import is_device
 
 aten = torch.ops.aten
 
@@ -765,9 +766,14 @@ def _create_quantized_state_dict(
                     self.precision,  # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    w_int4x8.to(self.device), self.inner_k_tiles
-                )
+                if is_device(w_int4x8.device.type, "cpu"):
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
+                else:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                     self.device
@@ -851,9 +857,14 @@ def make_names_and_values_dict_func(q, qparams):
             # how much we need to pad the weight
             delta_k = int((new_k - k) / 2)
             q = q.to(self.device)
-            final_q = torch.ops.aten._convert_weight_to_int4pack(
-                F.pad(q, pad=(0, delta_k)), inner_k_tiles
-            )
+            if is_device(self.device.type, "cpu"):
+                final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
+            else:
+                final_q = torch.ops.aten._convert_weight_to_int4pack(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
             scales = qparams[0].to(torch.bfloat16).to(self.device)
             zeros = qparams[1].to(torch.bfloat16).to(self.device)
             scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index cbe629640..02f898ffd 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -29,6 +29,7 @@
 from .utils import (
     _get_qmin_qmax,
 )
+from torchao.dtypes.utils import is_device
 
 
 class FakeQuantizedLinear(torch.nn.Linear):
@@ -373,10 +374,16 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                     n_bit,
                     config.group_size,
                 )
-                q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                    q_weight.to(child.weight.device),
-                    child.inner_k_tiles,
-                )
+                if is_device(q_weight.device.type, "cpu"):
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
+                else:
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
                 quantized_linear.weight = q_weight
                 quantized_linear.scales_and_zeros = scales_and_zeros
             else:
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index dfa20532e..01793967e 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -418,7 +418,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
     if TORCH_VERSION_AT_LEAST_2_5 and (
         w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
-    ):
+    ) and not is_device(w_int4x8.device.type, "cpu"):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
         low_bits = data & 0x0F
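
Note (not part of the patch): a minimal sketch of the device dispatch this change
introduces, assuming a PyTorch build that exposes both packing ops. The helper name
pack_int4_weight is hypothetical; each op expects the layout produced by the
surrounding torchao code (the uint8 nibble packing is skipped on CPU, as in the
hunks above).

    import torch

    def pack_int4_weight(w_int4x8: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
        # Hypothetical helper mirroring the dispatch added by this patch:
        # CPU tensors go through the dedicated *_for_cpu packing op, every
        # other device keeps the original tinygemm packing op.
        if w_int4x8.device.type == "cpu":
            return torch.ops.aten._convert_weight_to_int4pack_for_cpu(
                w_int4x8, inner_k_tiles
            )
        return torch.ops.aten._convert_weight_to_int4pack(
            w_int4x8, inner_k_tiles
        )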