Skip to content

Commit

Permalink
Feat (llm): export to MatMulNBits
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jul 29, 2024
1 parent 55fd0ea commit ab4ce9b
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 4 deletions.
34 changes: 34 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,43 @@
import onnx
import torch
from torch.autograd import Function
from torch.onnx.symbolic_helper import _get_tensor_sizes

from brevitas.export.onnx import onnx_export_opset


class MatMulNBitsFn(Function):

@staticmethod
def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
ret = g.op(
'com.microsoft::MatMulNBits',
x,
int_weights,
scales,
zero_points,
K_i=K,
N_i=N,
bits_i=bits,
block_size_i=block_size)
output_size = _get_tensor_sizes(x)
output_size[-1] = N
ret.setType(x.type().with_sizes(output_size))
return ret

@staticmethod
def forward(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
dtype = x.dtype
device = x.device
shape = x.shape
out_shape = list(shape)
out_shape[-1] = N
# Only tensor metadata (shape, dtype, device) are preserved in the forward pass during
# tracing, not the correct value
out = torch.empty(out_shape, dtype=dtype, device=device)
return out


AXIS_OPSET = 13

DATATYPE_DICT = {
Expand Down
107 changes: 103 additions & 4 deletions src/brevitas_examples/llm/llm_quant/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import numpy as np
import torch
from torch.nn import Module
from torch.onnx import register_custom_op_symbolic

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.manager import _set_layer_export_handler
Expand All @@ -19,6 +21,8 @@
from brevitas.export.manager import BaseManager
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.standard.function import MatMulNBitsFn
from brevitas.nn import QuantLinear
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector

Expand Down Expand Up @@ -65,8 +69,11 @@ def export_scale(self, proxy_module, bit_width):
scaling_impl = self.scaling_impl(proxy_module)
int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
int_threshold = int_scaling_impl(bit_width)
threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
scaling_impl.wrapped_scaling_impl.parameter_list_stats())
if hasattr(scaling_impl, 'wrapped_scaling_impl'):
threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
scaling_impl.wrapped_scaling_impl.parameter_list_stats())
else:
threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats())
return threshold / int_threshold

def export_zero_point(self, proxy_module, scale, bit_width):
Expand Down Expand Up @@ -219,7 +226,7 @@ def prepare_for_export(self, module):
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
self.group_size = module.weight_quant.quant_injector.block_size
self.group_size = module.weight_quant.quant_injector.group_size
self.bit_width = int(self.bit_width.cpu().item())
self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach())

Expand All @@ -237,10 +244,12 @@ def set_export_handler(cls, module):
_set_proxy_export_handler(cls, module)


def block_quant_layer_level_manager(export_handlers):
def block_quant_layer_level_manager(export_handlers, target=None, custom_fns_to_register=None):

class BlockQuantLayerLevelManager(BaseManager):
handlers = export_handlers
target_name = '' if target is None else target
custom_fns = [] if custom_fns_to_register is None else custom_fns_to_register

@classmethod
def set_export_handler(cls, module):
Expand Down Expand Up @@ -281,3 +290,93 @@ def replace_call_fn_target(graph_model, src, target):
node.target = target
graph_model.graph.lint()
graph_model.recompile()


class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHandlerBase):
handled_layer = QuantLinear

def __init__(self):
super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__()
self.group_size = None

def pack_int_weights(self, bit_width, int_weights, zero_point):
assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input."
assert bit_width == 4, "Only 4 bit quantization export is supported at the moment"

is_symmetric = torch.sum(zero_point) == 0
zero_point = zero_point.to(torch.uint8)
rows, cols = int_weights.shape
group_size = self.group_size
blob_size = group_size // 2
k_blocks = (rows + group_size - 1) // group_size
padded_rows = k_blocks * group_size
pad_len = padded_rows - rows

# ONNX operator assumes implicit zp of 8 (largest negative number in Po2)
# If we are in a "symmetric" quantized scenario, we need to add this implicit zero point
# Otherwise it has already been added during the convesion to integer.
# This allows to pack weights always in unsigned integer.
zp = 0 if not int_weights.dtype == torch.int8 else 8
int_weights += zp
if pad_len > 0:
int_weights = torch.nn.functional(int_weights, (0, 0, 0, pad_len))
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
rows, cols = int_weights.shape
int_weights = int_weights.t()
for n in range(cols):
for k_id in range(0, rows, group_size):
blk_int0 = (int_weights[n, k_id:k_id + group_size:2].numpy()).astype("uint8")
blk_int1 = (int_weights[n, k_id + 1:k_id + group_size:2].numpy()).astype("uint8")
packed[n, k_id // group_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))

zero_point = zero_point.to(torch.uint8).flatten()

# The constant value 136 is derived from the source code in ORT test suite.
# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
base_zp = 136 if is_symmetric else 0
packed_zp = base_zp * torch.ones(
(zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8)

i = 0
for column in range(packed_zp.shape[0]):
for j in range(i, i + (8 // bit_width)):
shift_factor = (bit_width * (j - i))
packed_zp[column] |= zero_point[j] << shift_factor
i += 8 // bit_width
return torch.tensor(packed), packed_zp

def prepare_for_export(self, module):
self.bit_width = self.bit_width_impl(module.weight_quant)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_weight = module.quant_weight()
self.bias = module.bias
self.scale = self.export_scale(module.weight_quant, self.bit_width)
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(
module.weight_quant, self.scale, self.bit_width)
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
self.group_size = module.weight_quant.quant_injector.group_size
self.bit_width = int(self.bit_width.cpu().item())
self.int_weight, self.zero_point = self.pack_int_weights(self.bit_width, quant_weight.int().t().detach(), self.zero_point)
self.weight_shape = module.weight.shape

def symbolic_execution(self, x):
int_weights = self.int_weight
scale = self.scale
bit_width = self.bit_width
N, K = self.weight_shape
out = MatMulNBitsFn.apply(
x, int_weights, scale.flatten(), self.zero_point, K, N, bit_width, self.group_size)
return out


def export_packed_onnx(model, input, export_path):
export_class = block_quant_layer_level_manager(
export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd],
target='',
custom_fns_to_register=MatMulNBitsFn)

with torch.inference_mode(), brevitas_layer_export_mode(model, export_class):
torch.onnx.export(model, input, export_path)
3 changes: 3 additions & 0 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
choices=[
None,
'onnx_qcdq',
'packed_onnx',
'torch_qcdq',
'sharded_torchmlir_group_weight',
'sharded_packed_torchmlir_group_weight'],
Expand All @@ -190,6 +191,8 @@ def model_export(model, ref_input, args):
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import \
sharded_weight_group_export
sharded_weight_group_export(model, no_custom_packed_export=False)
elif args.export_target == 'packed_onnx':
export_packed_onnx(model, ref_input, export_path=f"{args.model.replace('/', '-')}.onnx")
elif args.export_target == 'onnx_qcdq':
if args.weight_quant_granularity == 'per_group':
export_manager = BlockQuantProxyLevelManager
Expand Down

0 comments on commit ab4ce9b

Please sign in to comment.