Update int4 weight with serialized format
yanbing-j committed Oct 8, 2024
1 parent 32971d3 commit 222ec25
Showing 2 changed files with 12 additions and 6 deletions.
7 changes: 7 additions & 0 deletions generate.py

@@ -246,6 +246,13 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         apply_tp(model)

     model = model.to(device=device, dtype=precision)
+    if "int4" in str(checkpoint_path):
+        from quantize import WeightOnlyInt4Linear
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, WeightOnlyInt4Linear):
+                weight = mod.weight.data
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+                mod.weight = weight_int4pack
     return model.eval()

 def _get_model_size(model):
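With this change the checkpoint stores the device-portable uint8 form, and the kernel-specific packing is applied once when the model is loaded. Below is a minimal sketch of that load-time repack, assuming a recent PyTorch build in which aten._convert_weight_to_int4pack accepts the uint8-serialized weight (the op's accepted input dtype has varied across PyTorch versions, so treat the call as an assumption); the dimensions are illustrative.

    import torch

    # Illustrative dimensions; two int4 values are stored per uint8 byte,
    # so the serialized weight has shape (out_features, in_features // 2).
    out_features, in_features, inner_k_tiles = 32, 128, 8
    serialized = torch.randint(0, 256, (out_features, in_features // 2), dtype=torch.uint8)

    # Repack into the layout expected by the device's int4 matmul kernel.
    # The CUDA implementation requires a CUDA tensor, so move it first.
    if torch.cuda.is_available():
        packed = torch.ops.aten._convert_weight_to_int4pack(serialized.cuda(), inner_k_tiles)
        print(packed.shape, packed.dtype)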
11 changes: 5 additions & 6 deletions quantize.py

@@ -124,8 +124,8 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
         .to(torch.int32)
         .reshape_as(w)
     )
-
-    return w_int32
+    w_uint8 = (w_int32[::,::2] << 4 | w_int32[::,1::2]).to(torch.uint8)
+    return w_uint8


 def group_quantize_tensor(w, n_bit=4, groupsize=128):
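The serialized form packs two 4-bit values into each byte: even columns land in the high nibble, odd columns in the low nibble. A quick round-trip sketch of that packing, with an illustrative tensor name and shape:

    import torch

    # 4-bit values held in int32 storage, as produced by the quantizer above.
    w_int32 = torch.randint(0, 16, (4, 8), dtype=torch.int32)

    # Pack: even columns into the high nibble, odd columns into the low nibble.
    w_uint8 = (w_int32[:, ::2] << 4 | w_int32[:, 1::2]).to(torch.uint8)

    # Unpack to verify the round trip.
    high = (w_uint8 >> 4).to(torch.int32)
    low = (w_uint8 & 0xF).to(torch.int32)
    restored = torch.stack([high, low], dim=-1).reshape_as(w_int32)
    assert torch.equal(restored, w_int32)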
@@ -357,10 +357,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 ##### weight only int4 per channel groupwise quantized code ######

 def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
-    weight_int32, scales_and_zeros = group_quantize_tensor(
+    weight_int4pack, scales_and_zeros = group_quantize_tensor(
         weight_bf16, n_bit=4, groupsize=groupsize
     )
-    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros

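Taken together with the generate.py hunk, this splits the responsibilities: group_quantize_tensor now returns the serialized uint8 weight directly (hence the rename to weight_int4pack), and no device-specific packing happens at quantize time. A compact sketch of the resulting flow with a stubbed quantizer; all helper names here are illustrative, not from the repository:

    import torch

    def quantize_to_serialized(w_bf16, n_bit=4):
        # Stand-in for group_quantize_tensor: returns the uint8 nibble-packed
        # weight that now goes straight into the checkpoint.
        q = torch.randint(0, 2 ** n_bit, w_bf16.shape, dtype=torch.int32)
        return (q[:, ::2] << 4 | q[:, 1::2]).to(torch.uint8)

    weight_bf16 = torch.randn(32, 128, dtype=torch.bfloat16)
    serialized = quantize_to_serialized(weight_bf16)  # saved at quantize time
    # At load time, generate.py applies the kernel-specific repack once,
    # on the target device:
    # packed = torch.ops.aten._convert_weight_to_int4pack(serialized, inner_k_tiles)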
@@ -404,7 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):

     @torch.no_grad()
     def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
+        if use_cuda and torch.cuda.is_available():
             device="cuda"
         else:
             device="cpu"
@@ -507,7 +506,7 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty((out_features, in_features // 2), dtype=torch.uint8)
         )
         self.register_buffer(
             "scales_and_zeros",
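Both buffer layouts hold the same number of 4-bit values, so the storage cost is unchanged; only the layout becomes device-agnostic. A quick size check with illustrative dimensions (not taken from the commit):

    import math

    out_features, in_features, inner_k_tiles = 4096, 4096, 8

    # Old kernel-specific layout: a 4-D int32 tensor (4 bytes per element).
    old_bytes = 4 * math.prod((out_features // 8,
                               in_features // (inner_k_tiles * 16),
                               32,
                               inner_k_tiles // 2))

    # New serialized layout: two int4 values per uint8 byte.
    new_bytes = out_features * (in_features // 2)

    assert old_bytes == new_bytes == 8_388_608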
