From 222ec25266ba51d6f7e3183ae68e5449fbdf8062 Mon Sep 17 00:00:00 2001 From: "Jiang, Yanbing" Date: Mon, 1 Jul 2024 05:21:07 -0400 Subject: [PATCH] Update int4 weight with serialized format --- generate.py | 7 +++++++ quantize.py | 11 +++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/generate.py b/generate.py index b7a4c113..18ed3e1c 100644 --- a/generate.py +++ b/generate.py @@ -246,6 +246,13 @@ def _load_model(checkpoint_path, device, precision, use_tp): apply_tp(model) model = model.to(device=device, dtype=precision) + if "int4" in str(checkpoint_path): + from quantize import WeightOnlyInt4Linear + for fqn, mod in model.named_modules(): + if isinstance(mod, WeightOnlyInt4Linear): + weight = mod.weight.data + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles) + mod.weight = weight_int4pack return model.eval() def _get_model_size(model): diff --git a/quantize.py b/quantize.py index fb566421..f73aec4b 100644 --- a/quantize.py +++ b/quantize.py @@ -124,8 +124,8 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128) .to(torch.int32) .reshape_as(w) ) - - return w_int32 + w_uint8 = (w_int32[::,::2] << 4 | w_int32[::,1::2]).to(torch.uint8) + return w_uint8 def group_quantize_tensor(w, n_bit=4, groupsize=128): @@ -357,10 +357,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ##### weight only int4 per channel groupwise quantized code ###### def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): - weight_int32, scales_and_zeros = group_quantize_tensor( + weight_int4pack, scales_and_zeros = group_quantize_tensor( weight_bf16, n_bit=4, groupsize=groupsize ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) return weight_int4pack, scales_and_zeros @@ -404,7 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): @torch.no_grad() def create_quantized_state_dict(self, use_cuda = True): - if use_cuda: + if use_cuda and torch.cuda.is_available(): device="cuda" else: device="cpu" @@ -507,7 +506,7 @@ def __init__( assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" self.register_buffer( "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + torch.empty((out_features, in_features // 2), dtype=torch.uint8) ) self.register_buffer( "scales_and_zeros",