Feat (llm): export to MatMulNBits #733

Open · wants to merge 1 commit into base: dev
34 changes: 34 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
@@ -4,9 +4,43 @@
import onnx
import torch
from torch.autograd import Function
from torch.onnx.symbolic_helper import _get_tensor_sizes

from brevitas.export.onnx import onnx_export_opset


class MatMulNBitsFn(Function):

@staticmethod
def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
ret = g.op(
'com.microsoft::MatMulNBits',
x,
int_weights,
scales,
zero_points,
K_i=K,
N_i=N,
bits_i=bits,
block_size_i=block_size)
output_size = _get_tensor_sizes(x)
output_size[-1] = N
ret.setType(x.type().with_sizes(output_size))
return ret

@staticmethod
    def forward(ctx, x, int_weights, scales, zero_points, K, N, bits, block_size):
dtype = x.dtype
device = x.device
shape = x.shape
out_shape = list(shape)
out_shape[-1] = N
        # Only tensor metadata (shape, dtype, device) is preserved through the forward pass
        # during tracing, not the actual values
out = torch.empty(out_shape, dtype=dtype, device=device)
return out


AXIS_OPSET = 13

DATATYPE_DICT = {
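For context, here is a minimal sketch (not part of the diff) of how the new symbolic might be exercised: a throwaway module calls MatMulNBitsFn.apply and is traced with torch.onnx.export, which should emit a com.microsoft::MatMulNBits node consumable by onnxruntime. The shapes, names and output path below are illustrative assumptions.

```python
import torch

from brevitas.export.onnx.standard.function import MatMulNBitsFn

# Hypothetical blockwise 4-bit layout: K input features, N output features,
# two 4-bit weights per packed byte, one scale per (column, block) pair.
K, N, bits, block_size = 64, 32, 4, 32
k_blocks = K // block_size
blob_size = block_size // 2

x = torch.randn(1, K)
int_weights = torch.randint(0, 255, (N, k_blocks, blob_size), dtype=torch.uint8)
scales = torch.randn(N * k_blocks)
zero_points = torch.full(((N * k_blocks + 1) // 2,), 136, dtype=torch.uint8)


class Wrapper(torch.nn.Module):

    def forward(self, inp):
        return MatMulNBitsFn.apply(
            inp, int_weights, scales, zero_points, K, N, bits, block_size)


# Eager execution only produces an empty tensor of the right shape (see forward above);
# the symbolic method is what places the contrib op in the exported graph.
torch.onnx.export(Wrapper(), (x,), 'matmulnbits_node.onnx')
```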
1 change: 0 additions & 1 deletion src/brevitas/nn/quant_avg_pool.py
@@ -117,7 +117,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(_unpack_quant_tensor(x))
self._set_global_is_quant_layer(False)
return out

if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled:
2 changes: 0 additions & 2 deletions src/brevitas/nn/quant_eltwise.py
@@ -37,7 +37,6 @@ def forward(self, input: Union[Tensor, QuantTensor],
if self.export_mode:
assert self.cache_quant_io_metadata_only, "Can't cache multiple inputs"
out = self.export_handler(inp=input.value, other=other.value)
self._set_global_is_quant_layer(False)
return out
quant_input = self.input_quant(input)
quant_other = self.input_quant(other)
@@ -70,7 +69,6 @@ def forward(self,
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler([qt.value for qt in quant_tensor_list])
self._set_global_is_quant_layer(False)
return out
quant_tensor_list = [self.input_quant(qt) for qt in quant_tensor_list]
# trigger an assert if scale factors and bit widths are None or different
2 changes: 0 additions & 2 deletions src/brevitas/nn/quant_layer.py
@@ -47,7 +47,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(quant_input)
self._set_global_is_quant_layer(False)
return out
out = self.act_quant(quant_input)
out = self.pack_output(out)
@@ -139,7 +138,6 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(inp)
self._set_global_is_quant_layer(False)
return out

quant_input = self.input_quant(inp)
3 changes: 0 additions & 3 deletions src/brevitas/nn/quant_upsample.py
@@ -40,7 +40,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)
if self.export_mode:
out = self.export_handler(x.value)
self._set_global_is_quant_layer(False)
return out
y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
if self.mode != 'nearest':
@@ -69,7 +68,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)
if self.export_mode:
out = self.export_handler(x.value)
self._set_global_is_quant_layer(False)
return out
y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
# round interpolated values to scale
@@ -97,7 +95,6 @@ def forward(self, input: Union[Tensor, QuantTensor]):
x = self.unpack_input(input)
if self.export_mode:
out = self.export_handler(x.value)
self._set_global_is_quant_layer(False)
return out
y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners)
y = x.set(value=y_value)
151 changes: 109 additions & 42 deletions src/brevitas_examples/llm/llm_quant/export.py
@@ -10,16 +10,21 @@

import numpy as np
import torch
from torch.nn import Module
from torch.onnx import register_custom_op_symbolic

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.manager import _set_layer_export_handler
from brevitas.export.manager import _set_layer_export_mode
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import BaseManager
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.standard.function import MatMulNBitsFn
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.nn import QuantLinear
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector


@@ -52,27 +57,6 @@ def __init__(self):
self.bit_width = None
self.dtype = None

def scaling_impl(self, proxy_module):
return proxy_module.tensor_quant.scaling_impl

def zero_point_impl(self, proxy_module):
return proxy_module.tensor_quant.zero_point_impl

def bit_width_impl(self, proxy_module):
return proxy_module.tensor_quant.msb_clamp_bit_width_impl

def export_scale(self, proxy_module, bit_width):
scaling_impl = self.scaling_impl(proxy_module)
int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
int_threshold = int_scaling_impl(bit_width)
threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl(
scaling_impl.wrapped_scaling_impl.parameter_list_stats())
return threshold / int_threshold

def export_zero_point(self, proxy_module, scale, bit_width):
zero_point_impl = self.zero_point_impl(proxy_module)
return zero_point_impl.unexpanded_zero_point(scale, bit_width)

@abstractmethod
def prepare_for_export(self, module):
pass
@@ -83,6 +67,7 @@ def forward(self, x):


class WeightBlockQuantProxyHandler(WeightBlockQuantHandlerBase):
handled_layer = GroupwiseWeightQuantProxyFromInjector

def __init__(self):
super().__init__()
@@ -93,20 +78,18 @@ def __init__(self):

def prepare_for_export(self, module):
assert len(module.tracked_module_list) == 1, "Shared quantizers not supported."
self.bit_width = self.bit_width_impl(module)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_layer = module.tracked_module_list[0]
quant_weight = quant_layer.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
signed = module.is_signed
self.int_dtype = torch.int8 if signed else torch.uint8
self.dtype = quant_weight.value.dtype
self.scale = self.export_scale(module, self.bit_width).detach()
self.expanded_groupwise_shape = self.scaling_impl(module).expanded_groupwise_shape
self.reshaped_groupwise_shape = self.scaling_impl(module).reshaped_groupwise_shape
self.scale = quant_weight.scale_
self.expanded_scaling_shape = quant_weight.value_.shape
self.reshaped_scaling_shape = quant_weight.value.shape
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach()
self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape
self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape
self.zero_point = quant_weight.zero_point_
else:
self.zero_point = None

@@ -131,15 +114,9 @@ def forward(self, x):
x = (x.type(self.dtype) - zero_point) * scale

# Fix shape post quantization
scale = scale.expand(self.expanded_groupwise_shape).contiguous().view(
self.reshaped_groupwise_shape)
# If zero_point is not defined, propagate same shape as scale
if self.zero_point is None:
zero_point = torch.zeros_like(scale).type(self.int_dtype)
else:
zero_point = zero_point.expand(self.expanded_zero_point_shape).contiguous().view(
self.reshaped_zero_point_shape).type(self.int_dtype)
x = x.view(self.reshaped_groupwise_shape)

return x, scale, zero_point, bit_width

@@ -208,18 +185,17 @@ def lcm(x, y):
raise ValueError(f"Bit width {bit_width} not supported.")

def prepare_for_export(self, module):
self.bit_width = self.bit_width_impl(module.weight_quant)()
assert self.bit_width <= 8., "Only 8b or lower is supported."
quant_weight = module.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
self.bias = module.bias
self.scale = self.export_scale(module.weight_quant, self.bit_width)
self.scale = quant_weight.scale_
if (quant_weight.zero_point != 0.).any():
self.zero_point = self.export_zero_point(
module.weight_quant, self.scale, self.bit_width)
self.zero_point = quant_weight.zero_point_
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
self.group_size = module.weight_quant.quant_injector.block_size
self.group_size = quant_weight.group_size
self.bit_width = int(self.bit_width.cpu().item())
self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach())

@@ -237,10 +213,12 @@ def set_export_handler(cls, module):
_set_proxy_export_handler(cls, module)


def block_quant_layer_level_manager(export_handlers):
def block_quant_layer_level_manager(export_handlers, target=None, custom_fns_to_register=None):

class BlockQuantLayerLevelManager(BaseManager):
handlers = export_handlers
target_name = '' if target is None else target
custom_fns = [] if custom_fns_to_register is None else custom_fns_to_register

@classmethod
def set_export_handler(cls, module):
@@ -281,3 +259,92 @@ def replace_call_fn_target(graph_model, src, target):
node.target = target
graph_model.graph.lint()
graph_model.recompile()


class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHandlerBase):
handled_layer = QuantLinear

def __init__(self):
super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__()
self.group_size = None

def pack_int_weights(self, bit_width, int_weights, zero_point):
assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input."
assert bit_width == 4, "Only 4 bit quantization export is supported at the moment"

is_symmetric = torch.sum(zero_point) == 0
zero_point = zero_point.to(torch.uint8)
rows, cols = int_weights.shape
group_size = self.group_size
blob_size = group_size // 2
k_blocks = (rows + group_size - 1) // group_size
padded_rows = k_blocks * group_size
pad_len = padded_rows - rows

        # The ONNX operator assumes an implicit zero point of 8 (i.e. 2^(bits - 1)).
        # In the "symmetric" quantized scenario we need to add this implicit zero point here;
        # otherwise it has already been added during the conversion to integer.
        # This allows us to always pack the weights as unsigned integers.
        zp = 8 if int_weights.dtype == torch.int8 else 0
        int_weights += zp
        if pad_len > 0:
            int_weights = torch.nn.functional.pad(int_weights, (0, 0, 0, pad_len))
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
rows, cols = int_weights.shape
int_weights = int_weights.t()
for n in range(cols):
for k_id in range(0, rows, group_size):
blk_int0 = (int_weights[n, k_id:k_id + group_size:2].numpy()).astype("uint8")
blk_int1 = (int_weights[n, k_id + 1:k_id + group_size:2].numpy()).astype("uint8")
packed[n, k_id // group_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))

zero_point = zero_point.to(torch.uint8).flatten()

# The constant value 136 is derived from the source code in ORT test suite.
# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
base_zp = 136 if is_symmetric else 0
packed_zp = base_zp * torch.ones(
(zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8)

i = 0
for column in range(packed_zp.shape[0]):
for j in range(i, i + (8 // bit_width)):
shift_factor = (bit_width * (j - i))
packed_zp[column] |= zero_point[j] << shift_factor
i += 8 // bit_width
return torch.tensor(packed), packed_zp
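As an aside (illustration only, nothing below is part of the diff), the packing convention handled above, two 4-bit values per byte with an implicit zero point of 8 in the symmetric case, can be sanity-checked with a few lines of NumPy; the sample values are arbitrary:

```python
import numpy as np

# A symmetric signed 4-bit column; adding the implicit zero point of 8 makes it unsigned.
col = np.array([-3, 5, 0, -8], dtype=np.int8)
unsigned = (col.astype(np.int16) + 8).astype(np.uint8)  # [ 5, 13,  8,  0]

# Even elements go in the low nibble, odd elements in the high nibble.
packed = unsigned[0::2] | (unsigned[1::2] << 4)
print([hex(v) for v in packed])  # ['0xd5', '0x8']

# In the symmetric case, the packed zero point is two nibbles of 8 per byte: 0x88 == 136.
print(hex(8 | (8 << 4)))  # 0x88
```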

def prepare_for_export(self, module):
quant_weight = module.quant_weight()
self.bit_width = quant_weight.bit_width
assert self.bit_width <= 8., "Only 8b or lower is supported."
self.bias = module.bias
self.scale = quant_weight.scale_
if (quant_weight.zero_point != 0.).any():
self.zero_point = quant_weight.zero_point_
else:
# if there is no zero-point, export zeroes in the shape of scale
self.zero_point = torch.zeros_like(self.scale)
self.group_size = module.weight_quant.quant_injector.group_size
self.bit_width = int(self.bit_width.cpu().item())
self.int_weight, self.zero_point = self.pack_int_weights(self.bit_width, quant_weight.int().t().detach(), self.zero_point)
self.weight_shape = module.weight.shape

def symbolic_execution(self, x):
int_weights = self.int_weight
scale = self.scale
bit_width = self.bit_width
N, K = self.weight_shape
out = MatMulNBitsFn.apply(
x, int_weights, scale.flatten(), self.zero_point, K, N, bit_width, self.group_size)
return out


def export_packed_onnx(model, input, export_path):
export_class = block_quant_layer_level_manager(
export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd],
target='',
        custom_fns_to_register=[MatMulNBitsFn])

with torch.inference_mode(), brevitas_layer_export_mode(model, export_class):
torch.onnx.export(model, input, export_path)
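A hedged usage sketch of the new entry point (not from the PR): the model is assumed to come out of the surrounding llm_quant flow with 4-bit, groupwise weight quantization already applied, and the exported graph targets onnxruntime, since com.microsoft::MatMulNBits is a contrib op. The model, sample input and path below are placeholders.

```python
import torch

from brevitas_examples.llm.llm_quant.export import export_packed_onnx

# `quantized_model` is assumed to be produced elsewhere by the llm_quant pipeline with
# 4-bit, groupwise weight quantization; it is a placeholder here, not a real object.
quantized_model = ...
sample_input = torch.randint(0, 32000, (1, 128))  # hypothetical token ids

export_packed_onnx(quantized_model, sample_input, export_path='llm_matmulnbits.onnx')
```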