Add Int4CPULayout and update int4 woq
yanbing-j committed Nov 13, 2024
1 parent 01dc7da commit 1b26f26
Showing 6 changed files with 290 additions and 34 deletions.
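For reference, the tests in this commit exercise the new layout through the existing quantization API; the snippet below is a minimal usage sketch reconstructed from those test changes (the toy model and shapes are illustrative assumptions, not part of the diff):

import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import Int4CPULayout

# Int4 weight-only quantization expects bfloat16 weights.
model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16))

if torch.cuda.is_available():
    # CUDA keeps the default tensor-core tiled packing.
    model = model.cuda()
    quantize_(model, int4_weight_only(group_size=32))
else:
    # CPU uses the Int4CPULayout added in this commit.
    quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout()))

device = next(model.parameters()).device
out = model(torch.randn(32, 128, dtype=torch.bfloat16, device=device))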
40 changes: 22 additions & 18 deletions test/dtypes/test_affine_quantized.py
@@ -11,7 +11,7 @@
float8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.dtypes import SemiSparseLayout
from torchao.dtypes import SemiSparseLayout, Int4CPULayout
from torch.testing._internal import common_utils
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -22,15 +22,18 @@
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def get_quantization_functions(do_sparse: bool, do_int4: bool):
def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(),
int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
]
if do_int4:
base_functions.append(int4_weight_only(group_size=32))
if device == "cpu":
base_functions.append(int4_weight_only(group_size=32, layout=Int4CPULayout()))
else:
base_functions.append(int4_weight_only(group_size=32))

if do_sparse:
base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
@@ -139,23 +142,24 @@ class TestAffineQuantizedBasic(TestCase):
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
COMMON_DTYPES = [torch.bfloat16]

@common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
@common_utils.parametrize("device", COMMON_DEVICES)
@common_utils.parametrize("dtype", COMMON_DTYPES)
def test_flatten_unflatten(self, apply_quant, device, dtype):
l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(l)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
ref = ql(*example_inputs)
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
reconstruct_res = ql(*example_inputs)
self.assertEqual(reconstruct_res, ref)
def test_flatten_unflatten(self, device, dtype):
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(l)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
outer_size = lp_tensor.size()
outer_stride = lp_tensor.stride()
reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
ref = ql(*example_inputs)
ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
reconstruct_res = ql(*example_inputs)
self.assertEqual(reconstruct_res, ref)

common_utils.instantiate_parametrized_tests(TestAffineQuantized)
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
18 changes: 14 additions & 4 deletions test/integration/test_integration.py
@@ -19,7 +19,7 @@
from torchao.quantization.dynamic_quant import (
DynamicallyPerAxisQuantizedLinear,
)
from torchao.dtypes import TensorCoreTiledLayout
from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only,
@@ -132,7 +132,11 @@ def _int8da_int8w_api(mod):

def _int4wo_api(mod):
if TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int4_weight_only(), set_inductor_config=False)
device_type = next(mod.parameters()).device
if device_type == torch.device("cpu"):
quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
else:
quantize_(mod, int4_weight_only(), set_inductor_config=False)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)
else:
@@ -911,10 +915,16 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
layout_list = []
if device == 'cuda':
for inner_k_tiles in [4, 2]:
layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
elif device == 'cpu':
layout_list.append(Int4CPULayout())
for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
for groupsize in [64, 32]:
for inner_k_tiles in [4, 2]:
kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
for layout in layout_list:
kwargs = {"groupsize": groupsize, "layout": layout}

def api(mod):
kwargs_copy = kwargs.copy()
1 change: 1 addition & 0 deletions torchao/dtypes/__init__.py
@@ -7,6 +7,7 @@
PlainLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
Int4CPULayout,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
229 changes: 229 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -630,6 +630,11 @@ def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"


@dataclass(frozen=True)
class Int4CPULayout(Layout):
def pre_process(self, input: torch.Tensor) -> torch.Tensor:
return input

@dataclass(frozen=True)
class Float8Layout(Layout):
mm_config: Optional[Float8MMConfig] = None
@@ -1616,6 +1621,230 @@ def get_layout(self) -> Layout:
return self._layout


@register_layout(Int4CPULayout)
class Int4CPUAQTTensorImpl(AQTTensorImpl):
"""
TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
used by tinygemm kernels `_weight_int4pack_mm`
It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
dimension: [n][k / 2] (uint8 dtype)
(unpacked Tensor shape is n * k)
Note: we also pack scale and zero point together here for tinygemm kernel
Note: technically Int4 CPU layout should be the layout for the underlying packed weight
(int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used
in plain layout, we just created a layout for AQT right now, this could be improved if we split out
int4 aqt into a separate tensor subclass
fields:
packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout
scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor
"""

def __new__(
cls,
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
_layout: Layout,
):
kwargs = {}
kwargs["device"] = packed_weight.device
kwargs["layout"] = (
kwargs.get("layout")
if kwargs.get("layout", False)
else packed_weight.layout
)
kwargs["dtype"] = packed_weight.dtype
kwargs["requires_grad"] = False
shape = packed_weight.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
_layout: Layout,
):
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero
self.transposed = False
self._layout = _layout

def __tensor_flatten__(self):
return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = (
tensor_data_dict["packed_weight"],
tensor_data_dict["scale_and_zero"],
)
(
transposed,
_layout,
) = tensor_attributes
return cls(packed_weight, scale_and_zero, transposed, _layout)

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
assert isinstance(_layout, Int4CPULayout)

assert (
int_data.dtype == torch.int32
), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
int_data, 1 # TODO:remove
)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)

scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
return cls(packed_weight, scale_and_zero, False, _layout)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs["device"]
return self.__class__(
self.packed_weight.to(device),
self.scale_and_zero.to(device),
self.transposed,
self._layout,
)

def _apply_fn_to_data(self, fn):
# self.packed_weight = fn(self.packed_weight)
# self.scale_and_zero = fn(self.scale_and_zero)
# return self
return self.__class__(
fn(self.packed_weight),
fn(self.scale_and_zero),
self.transposed,
self._layout,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

if func is aten.t.default:
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
"""
transposed = Int4CPUAQTTensorImpl(
args[0].packed_weight,
args[0].scale_and_zero,
not args[0].transposed,
args[0]._layout,
)
return return_and_correct_aliasing(func, args, kwargs, transposed)

if func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
int_data, scale, zero_point = self.get_plain()
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# this is to handle padding
int_data = self._layout.post_process(int_data)
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return return_and_correct_aliasing(func, args, kwargs, sliced)
elif dim == 1:
int_data, scale, zero_point = self.get_plain()
assert step == 1, "Only step == 1 is supported in slicing right now"
data_len = int_data.shape[dim]
scale_len = scale.shape[dim]
ratio = data_len / scale_len
start_scale = int(start / ratio)
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# this is to handle padding
int_data = self._layout.post_process(int_data)
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
zero_point = aten.slice.Tensor(
zero_point, dim, start_scale, end_scale, step
)
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return sliced
else:
raise NotImplementedError(
f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)

raise NotImplementedError(
f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
quantize_affine,
)
from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)

cur_shape = self.shape
assert len(cur_shape) == 2
original_shape = (cur_shape[0], cur_shape[1] * 2)
eye_shape = original_shape[1]
groupsize = int(original_shape[1] / scale.shape[-2])
block_size = (1, groupsize)
device = self.device
original_dtype = torch.bfloat16
target_dtype = torch.int32
quant_min = 0
quant_max = 15
zero_point_domain = ZeroPointDomain.FLOAT
assert len(block_size) == 2 and block_size[0] == 1
dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
torch.eye(eye_shape, device=device, dtype=original_dtype),
self.packed_weight,
groupsize,
self.scale_and_zero,
)
dequantized = dequantized.t().contiguous()
# TODO: move this to `unpack_tinygemm_scales_and_zeros`?
scale = scale.reshape(scale.shape[:-1]).contiguous()
zero = zero.reshape(zero.shape[:-1]).contiguous()
int_data = quantize_affine(
dequantized,
block_size,
scale,
zero,
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
return int_data, scale, zero

def get_layout(self) -> Layout:
return self._layout

#####################################################
# torch functional and aten operator implementation #
#####################################################
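Note: `get_plain()` above recovers the unpacked int4 data by first dequantizing the packed weight against an identity matrix with the CPU tinygemm kernel and then re-quantizing with `quantize_affine`. The snippet below is a standalone sketch of that identity-matrix dequantization trick only; the shapes and random data are illustrative assumptions, and it requires a PyTorch build that ships the `_for_cpu` int4 ops used in this commit.

import torch
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

n, k, groupsize = 256, 128, 32

# Unpacked int4 values stored as int32, as _convert_weight_to_int4pack_for_cpu expects.
int_data = torch.randint(0, 16, (n, k), dtype=torch.int32)
packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(int_data, 1)

# Per-group scales and float-domain zero points, packed for the tinygemm kernel.
scale = torch.rand(n, k // groupsize, dtype=torch.bfloat16)
zero = torch.zeros(n, k // groupsize, dtype=torch.bfloat16)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero)

# Multiplying the packed weight by an identity matrix dequantizes every weight row;
# transposing gives an [n, k] bfloat16 view of the dequantized weight.
eye = torch.eye(k, dtype=torch.bfloat16)
dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
    eye, packed_weight, groupsize, scale_and_zero
).t().contiguous()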
25 changes: 18 additions & 7 deletions torchao/quantization/subclass.py
@@ -458,12 +458,20 @@ def _quantized_op(act_mat, w_qtensor, bias):
act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))

# matmul
y = aten._weight_int4pack_mm(
act_mat.contiguous(),
w_qtensor.int_data,
w_qtensor.groupsize,
w_qtensor.scales_and_zeros,
)
if act_mat.device == torch.device("cpu"):
y = aten._weight_int4pack_mm_for_cpu(
act_mat.contiguous(),
w_qtensor.int_data,
w_qtensor.groupsize,
w_qtensor.scales_and_zeros,
)
else:
y = aten._weight_int4pack_mm(
act_mat.contiguous(),
w_qtensor.int_data,
w_qtensor.groupsize,
w_qtensor.scales_and_zeros,
)

# remove out_feature padding
orig_out_features = (
@@ -609,5 +617,8 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
input_float, 4, groupsize, dtype=input_float.dtype
)
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
if input_float.device == torch.device("cpu"):
int_data = aten._convert_weight_to_int4pack_for_cpu(input_int4x8, inner_k_tiles)
else:
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles