Commit

Update
yanbing-j committed Nov 14, 2024
1 parent fbb2cae commit 512eb75
Showing 5 changed files with 43 additions and 17 deletions.
5 changes: 3 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -526,8 +526,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            if not is_device(input.device.type, "cpu"):
+                input = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
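For context, the branch above now applies the uint8 packing step only off-CPU: under torch >= 2.5, two 4-bit values are folded into one byte before dequantization, while the CPU path keeps the unpacked layout. A minimal, self-contained sketch of that pack/unpack round trip in plain PyTorch (toy shapes; the re-interleaving at the end is ours, just to check the round trip, not torchao's helper):

import torch

# Toy tensor of 4-bit values (0..15), one value per element; shape is illustrative only.
w = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Non-CPU path under torch >= 2.5: fold each pair of 4-bit values into a single uint8.
packed = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)

# The uint8 branch of the dequantize helper splits each byte back into two nibbles.
data = packed.to(torch.int32)
high_bits = data >> 4
low_bits = data & 0x0F

# Re-interleave the nibbles (our own reconstruction, for verification only).
unpacked = torch.stack([high_bits, low_bits], dim=-1).reshape(w.shape)
assert torch.equal(unpacked, w)  # the round trip is lossless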
15 changes: 11 additions & 4 deletions torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -13,6 +13,7 @@
 
 import torch.nn.functional as F
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.dtypes.utils import is_device
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta):
         W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
             W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
         )
-        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            W_q_torch, self.inner_k_tiles
-        )
+        if is_device(W_q.device.type, "cpu"):
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                W_q_torch, self.inner_k_tiles
+            )
+        else:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                W_q_torch, self.inner_k_tiles
+            )
         self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)
 
         del W_q_torch, scales_torch, zeros_torch
@@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants(
             .contiguous()
         )
         if TORCH_VERSION_AT_LEAST_2_5:
-            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+            if not is_device(W_q.device.type, "cpu"):
+                W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
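The same CPU-vs-default dispatch appears in the hunk above and recurs in the files below. Roughly, it could be factored into a helper like the sketch here (the helper name is ours, not part of this commit; the _for_cpu variant of the aten op is assumed to exist only in newer PyTorch builds):

import torch
from torchao.dtypes.utils import is_device

def _pack_int4_weight(w: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # Sketch of the pattern this commit repeats at each call site:
    # CPU tensors go through the dedicated CPU packing op, everything else
    # keeps using the original aten op.
    if is_device(w.device.type, "cpu"):
        return torch.ops.aten._convert_weight_to_int4pack_for_cpu(w, inner_k_tiles)
    return torch.ops.aten._convert_weight_to_int4pack(w, inner_k_tiles)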
23 changes: 17 additions & 6 deletions torchao/quantization/GPTQ.py
@@ -36,6 +36,7 @@
     pack_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
+from torchao.dtypes.utils import is_device
 
 aten = torch.ops.aten
 
@@ -765,9 +766,14 @@ def _create_quantized_state_dict(
                     self.precision,  # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    w_int4x8.to(self.device), self.inner_k_tiles
-                )
+                if is_device(w_int4x8.device.type, "cpu"):
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
+                else:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                     self.device
@@ -851,9 +857,14 @@ def make_names_and_values_dict_func(q, qparams):
             # how much we need to pad the weight
             delta_k = int((new_k - k) / 2)
             q = q.to(self.device)
-            final_q = torch.ops.aten._convert_weight_to_int4pack(
-                F.pad(q, pad=(0, delta_k)), inner_k_tiles
-            )
+            if is_device(self.device.type, "cpu"):
+                final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
+            else:
+                final_q = torch.ops.aten._convert_weight_to_int4pack(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
             scales = qparams[0].to(torch.bfloat16).to(self.device)
             zeros = qparams[1].to(torch.bfloat16).to(self.device)
             scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
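The GPTQ hunk above also pads the reduction dimension before packing. A small illustration of that F.pad arithmetic, with made-up sizes (new_k would come from the quantizer's own alignment logic, and the last dim of q is assumed to be k // 2 here):

import torch
import torch.nn.functional as F

k, new_k = 1000, 1024                                         # assumed sizes, illustration only
q = torch.randint(0, 16, (4096, k // 2), dtype=torch.int32)   # toy quantized weight
delta_k = int((new_k - k) / 2)                                # same arithmetic as in the diff
q_padded = F.pad(q, pad=(0, delta_k))                         # zero-pad the right edge of the last dim
print(q_padded.shape)                                         # torch.Size([4096, 512]) == (4096, new_k // 2)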
15 changes: 11 additions & 4 deletions torchao/quantization/qat/linear.py
@@ -29,6 +29,7 @@
 from .utils import (
     _get_qmin_qmax,
 )
+from torchao.dtypes.utils import is_device
 
 
 class FakeQuantizedLinear(torch.nn.Linear):
@@ -373,10 +374,16 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                     n_bit,
                     config.group_size,
                 )
-                q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                    q_weight.to(child.weight.device),
-                    child.inner_k_tiles,
-                )
+                if is_device(q_weight.device.type, "cpu"):
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
+                else:
+                    q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                        q_weight.to(child.weight.device),
+                        child.inner_k_tiles,
+                    )
                 quantized_linear.weight = q_weight
                 quantized_linear.scales_and_zeros = scales_and_zeros
             else:
2 changes: 1 addition & 1 deletion torchao/quantization/utils.py
@@ -418,7 +418,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
     if TORCH_VERSION_AT_LEAST_2_5 and (
         w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
-    ):
+    ) and not is_device(w_int4x8.device.type, "cpu"):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
         low_bits = data & 0x0F
