diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 65ce9c71e..f163edeb4 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -526,8 +526,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): groupsize = 128 if TORCH_VERSION_AT_LEAST_2_5: - input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize) + if not is_device(input.device.type, "cpu"): + input = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) else: w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index d2ee61c18..fb9833c65 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -2056,9 +2056,14 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) + if is_device(input_tensor.device.type, "cpu"): + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 8bf1d3426..1bdbcd96e 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f - Times are in `ms`, see `benchmarks/benchmark_hqq.py`. - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul). -- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. +- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. GPU details: diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index 8abdad039..09e623fdc 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -13,6 +13,7 @@ import torch.nn.functional as F from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.dtypes.utils import is_device class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta): W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits ) - self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - W_q_torch, self.inner_k_tiles - ) + if is_device(W_q.device.type, "cpu"): + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + W_q_torch, self.inner_k_tiles + ) + else: + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + W_q_torch, self.inner_k_tiles + ) self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch) del W_q_torch, scales_torch, zeros_torch @@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants( .contiguous() ) if TORCH_VERSION_AT_LEAST_2_5: - W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) + if not is_device(W_q.device.type, "cpu"): + W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) # group_dequantize_tensor_from_qparams # W_r = W_q*scales + min_val @@ -232,9 +239,14 @@ def pack_scales_and_zeros(self, scales, zeros): def matmul(self, x): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x, self.weight_int4pack, self.groupsize, self.scales_and_zeros - ) + if is_device(x.device.type, "cpu"): + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) + else: + c = torch.ops.aten._weight_int4pack_mm( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) new_shape = origin_x_size[:-1] + (self.out_features,) c = c.reshape(new_shape) return c diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 470e71ae3..da9a44e7a 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -36,6 +36,7 @@ pack_tinygemm_scales_and_zeros, per_token_dynamic_quant, ) +from torchao.dtypes.utils import is_device aten = torch.ops.aten @@ -542,12 +543,20 @@ def linear_forward_int4( ): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x.to(precision), - weight_int4pack, - groupsize, - scales_and_zeros.to(scales_precision), - ).to(dtype=x.dtype) + if is_device(x.device.type, "cpu"): + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) + else: + c = torch.ops.aten._weight_int4pack_mm( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c @@ -596,19 +605,32 @@ def __init__( assert ( in_features % (inner_k_tiles * 16) == 0 ), "require in_features % (innerKTiles * 16) == 0" - self.register_buffer( - "weight", - torch.zeros( - ( - out_features // 8, - in_features // (inner_k_tiles * 16), - 32, - inner_k_tiles // 2, + if is_device(device.type, "cpu"): + self.register_buffer( + "weight", + torch.zeros( + ( + out_features, + in_features // 2, + ), + dtype=torch.uint8, + device=device, ), - dtype=torch.int32, - device=device, - ), - ) + ) + else: + self.register_buffer( + "weight", + torch.zeros( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + device=device, + ), + ) self.dtype = dtype self.register_buffer( "scales_and_zeros", @@ -765,9 +787,14 @@ def _create_quantized_state_dict( self.precision, # dtype for scales_and_zeros ) # TODO: just get the device from mod.weight.device? - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - w_int4x8.to(self.device), self.inner_k_tiles - ) + if is_device(w_int4x8.device.type, "cpu"): + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + w_int4x8.to(self.device), self.inner_k_tiles + ) + else: + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + w_int4x8.to(self.device), self.inner_k_tiles + ) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( self.device @@ -851,9 +878,14 @@ def make_names_and_values_dict_func(q, qparams): # how much we need to pad the weight delta_k = int((new_k - k) / 2) q = q.to(self.device) - final_q = torch.ops.aten._convert_weight_to_int4pack( - F.pad(q, pad=(0, delta_k)), inner_k_tiles - ) + if is_device(self.device.type, "cpu"): + final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) + else: + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) scales = qparams[0].to(torch.bfloat16).to(self.device) zeros = qparams[1].to(torch.bfloat16).to(self.device) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index cbe629640..02f898ffd 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -29,6 +29,7 @@ from .utils import ( _get_qmin_qmax, ) +from torchao.dtypes.utils import is_device class FakeQuantizedLinear(torch.nn.Linear): @@ -373,10 +374,16 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): n_bit, config.group_size, ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), - child.inner_k_tiles, - ) + if is_device(q_weight.device.type, "cpu"): + q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) + else: + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) quantized_linear.weight = q_weight quantized_linear.scales_and_zeros = scales_and_zeros else: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 476cc229f..b92cacea3 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -575,7 +575,8 @@ def int4_weight_only( "tensor_core_tiled" layout for speedup with tinygemm kernel Note: - This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference + This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm` + and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference of quantization algorithm compared to the more traditional type of integer quantization is the following: 1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`) 2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index dfa20532e..01793967e 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -418,7 +418,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path if TORCH_VERSION_AT_LEAST_2_5 and ( w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1 - ): + ) and not is_device(w_int4x8.device.type, "cpu"): data = w_int4x8.to(torch.int32) high_bits = data >> 4 low_bits = data & 0x0F