Commit 98b8f8c

Update
yanbing-j committed Nov 14, 2024
1 parent fbb2cae commit 98b8f8c
Showing 8 changed files with 101 additions and 43 deletions.
5 changes: 3 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -526,8 +526,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128

         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            if not is_device(input.device.type, "cpu"):
+                input = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
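The packing step that the CPU branch of this test now skips folds two 4-bit values into one uint8. A minimal toy example of that arithmetic (illustrative values, not taken from the test):

```python
import torch

# Even columns become the high nibble, odd columns the low nibble,
# mirroring the `w[::, ::2] << 4 | w[::, 1::2]` expression above.
w = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)     # 4-bit values in [0, 15]
packed = (w[:, ::2] << 4 | w[:, 1::2]).to(torch.uint8)  # tensor([[18, 52]])
```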
11 changes: 8 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -2056,9 +2056,14 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):

     # groupwise int4 quantization
     groupsize = weight_tensor.block_size[1]
-    y = torch.ops.aten._weight_int4pack_mm(
-        act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
-    )
+    if is_device(input_tensor.device.type, "cpu"):
+        y = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
+        )
+    else:
+        y = torch.ops.aten._weight_int4pack_mm(
+            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
+        )

     # remove out_feature padding
     orig_out_features = weight_tensor.shape[-2]
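The same device dispatch recurs in the files below; a condensed sketch of the pattern, using a hypothetical helper name (the op names and argument order are taken from the diff, everything else is illustrative):

```python
import torch
from torchao.dtypes.utils import is_device

def _int4_mm_dispatch(act_mat, packed_weight, groupsize, scale_and_zero):
    # CPU tensors go through the dedicated CPU op; other devices keep the
    # original tinygemm op. Sketch only, not code from this commit.
    if is_device(act_mat.device.type, "cpu"):
        return torch.ops.aten._weight_int4pack_mm_for_cpu(
            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
        )
    return torch.ops.aten._weight_int4pack_mm(
        act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
    )
```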
2 changes: 1 addition & 1 deletion torchao/prototype/hqq/README.md
@@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f

 - Times are in `ms`, see `benchmarks/benchmark_hqq.py`.
 - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul).
-- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.
+- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.

 GPU details:

26 changes: 19 additions & 7 deletions torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -13,6 +13,7 @@

 import torch.nn.functional as F
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.dtypes.utils import is_device


 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta):
         W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
             W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
         )
-        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            W_q_torch, self.inner_k_tiles
-        )
+        if is_device(W_q.device.type, "cpu"):
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                W_q_torch, self.inner_k_tiles
+            )
+        else:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                W_q_torch, self.inner_k_tiles
+            )
         self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)

         del W_q_torch, scales_torch, zeros_torch
@@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants(
             .contiguous()
         )
         if TORCH_VERSION_AT_LEAST_2_5:
-            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+            if not is_device(W_q.device.type, "cpu"):
+                W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)

         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
@@ -232,9 +239,14 @@ def pack_scales_and_zeros(self, scales, zeros):
     def matmul(self, x):
         origin_x_size = x.size()
         x = x.reshape(-1, origin_x_size[-1])
-        c = torch.ops.aten._weight_int4pack_mm(
-            x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
-        )
+        if is_device(x.device.type, "cpu"):
+            c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
+            )
+        else:
+            c = torch.ops.aten._weight_int4pack_mm(
+                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
+            )
         new_shape = origin_x_size[:-1] + (self.out_features,)
         c = c.reshape(new_shape)
         return c
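The weight-packing side follows the same shape. A sketch with a hypothetical helper (op names and arguments from the hunks above; the layout comments reflect the buffer shapes registered in GPTQ.py below):

```python
import torch
from torchao.dtypes.utils import is_device

def _pack_int4_weight(w_q, inner_k_tiles):
    # CPU produces a (out_features, in_features // 2) uint8 layout; other
    # devices keep the tensor-core tiled int32 layout. Sketch only.
    if is_device(w_q.device.type, "cpu"):
        return torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_q, inner_k_tiles)
    return torch.ops.aten._convert_weight_to_int4pack(w_q, inner_k_tiles)
```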
80 changes: 56 additions & 24 deletions torchao/quantization/GPTQ.py
@@ -36,6 +36,7 @@
     pack_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
+from torchao.dtypes.utils import is_device

 aten = torch.ops.aten

@@ -542,12 +543,20 @@ def linear_forward_int4(
 ):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
-    c = torch.ops.aten._weight_int4pack_mm(
-        x.to(precision),
-        weight_int4pack,
-        groupsize,
-        scales_and_zeros.to(scales_precision),
-    ).to(dtype=x.dtype)
+    if is_device(x.device.type, "cpu"):
+        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
+    else:
+        c = torch.ops.aten._weight_int4pack_mm(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
@@ -596,19 +605,32 @@ def __init__(
         assert (
             in_features % (inner_k_tiles * 16) == 0
         ), "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.zeros(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
-                ),
-                dtype=torch.int32,
-                device=device,
-            ),
-        )
+        if is_device(device.type, "cpu"):
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features,
+                        in_features // 2,
+                    ),
+                    dtype=torch.uint8,
+                    device=device,
+                ),
+            )
+        else:
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features // 8,
+                        in_features // (inner_k_tiles * 16),
+                        32,
+                        inner_k_tiles // 2,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                ),
+            )
         self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
Expand Down Expand Up @@ -765,9 +787,14 @@ def _create_quantized_state_dict(
self.precision, # dtype for scales_and_zeros
)
# TODO: just get the device from mod.weight.device?
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
w_int4x8.to(self.device), self.inner_k_tiles
)
if is_device(w_int4x8.device.type, "cpu"):
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
w_int4x8.to(self.device), self.inner_k_tiles
)
else:
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
w_int4x8.to(self.device), self.inner_k_tiles
)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
self.device
Expand Down Expand Up @@ -851,9 +878,14 @@ def make_names_and_values_dict_func(q, qparams):
# how much we need to pad the weight
delta_k = int((new_k - k) / 2)
q = q.to(self.device)
final_q = torch.ops.aten._convert_weight_to_int4pack(
F.pad(q, pad=(0, delta_k)), inner_k_tiles
)
if is_device(self.device.type, "cpu"):
final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
F.pad(q, pad=(0, delta_k)), inner_k_tiles
)
else:
final_q = torch.ops.aten._convert_weight_to_int4pack(
F.pad(q, pad=(0, delta_k)), inner_k_tiles
)
scales = qparams[0].to(torch.bfloat16).to(self.device)
zeros = qparams[1].to(torch.bfloat16).to(self.device)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
Expand Down
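Both weight buffers registered in `__init__` above store one nibble per weight, so they occupy the same number of bytes. A quick sanity check with assumed example dimensions (4096x4096, inner_k_tiles=8; the numbers are illustrative, not from the repo):

```python
out_features, in_features, inner_k_tiles = 4096, 4096, 8

# Tensor-core tiled layout: int32 elements, 4 bytes each.
cuda_bytes = (
    (out_features // 8)
    * (in_features // (inner_k_tiles * 16))
    * 32
    * (inner_k_tiles // 2)
    * 4
)
# CPU layout: uint8 elements, 1 byte each.
cpu_bytes = out_features * (in_features // 2)

assert cuda_bytes == cpu_bytes == out_features * in_features // 2  # 8 MiB each
```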
15 changes: 11 additions & 4 deletions torchao/quantization/qat/linear.py
@@ -29,6 +29,7 @@
 from .utils import (
     _get_qmin_qmax,
 )
+from torchao.dtypes.utils import is_device


 class FakeQuantizedLinear(torch.nn.Linear):
@@ -373,10 +374,16 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                     n_bit,
                     config.group_size,
                 )
-                q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                    q_weight.to(child.weight.device),
-                    child.inner_k_tiles,
-                )
+                if is_device(q_weight.device.type, "cpu"):
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
+                else:
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
                 quantized_linear.weight = q_weight
                 quantized_linear.scales_and_zeros = scales_and_zeros
             else:
3 changes: 2 additions & 1 deletion torchao/quantization/quant_api.py
@@ -575,7 +575,8 @@ def int4_weight_only(
         "tensor_core_tiled" layout for speedup with tinygemm kernel

     Note:
-        This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference
+        This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`
+        and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference
         of quantization algorithm compared to the more traditional type of integer quantization is the following:
         1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`)
         2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`)
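To make point 1) concrete, here is a rough sketch of a 4-bit group dequantized with a float-domain zero_point. The midpoint convention and the values are assumptions for illustration, not code copied from torchao:

```python
import torch

q = torch.randint(0, 16, (128,), dtype=torch.int32)      # 4-bit codes for one group
scale = torch.tensor(0.01, dtype=torch.bfloat16)
zero_point = torch.tensor(0.1234, dtype=torch.bfloat16)  # stays in floating point
mid_point = 8                                             # assumed midpoint for uint4 codes

# zero_point never becomes an integer offset, so an exact representation of
# 0.0 is not guaranteed (preserve_zero=False).
w_approx = (q - mid_point) * scale + zero_point
```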
2 changes: 1 addition & 1 deletion torchao/quantization/utils.py
@@ -418,7 +418,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
     if TORCH_VERSION_AT_LEAST_2_5 and (
         w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
-    ):
+    ) and not is_device(w_int4x8.device.type, "cpu"):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
         low_bits = data & 0x0F
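A toy round trip for the unpacking above, the inverse of the nibble packing shown in the test file (values are illustrative):

```python
import torch

packed = torch.tensor([[18, 52]], dtype=torch.uint8)  # two bytes = four 4-bit values
data = packed.to(torch.int32)
high_bits = data >> 4   # tensor([[1, 3]])
low_bits = data & 0x0F  # tensor([[2, 4]])
# Interleaving the high/low columns restores the original [[1, 2, 3, 4]].
```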
