Feat (examples/llm): add custom float support #708

Merged · 2 commits · Sep 22, 2023
src/brevitas/quant/experimental/float_base.py (3 changes: 1 addition & 2 deletions)
@@ -33,10 +33,9 @@ class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):


 class ScaledFloatActBase(FloatActBase, ActQuantSolver):
-    scaling_stats_op = 'percentile'
+    scaling_stats_op = 'max'
     scaling_impl_type = 'parameter_from_stats'
     restrict_scaling_type = 'fp'
-    high_percentile_q = 99.999
     collect_stats_steps = 300
     float_scaling_impl = FloatScaling

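Note on this change: 'percentile' seeded the activation scale from the 99.999th percentile of the magnitudes collected over the calibration steps, whereas 'max' tracks the running maximum, which is why high_percentile_q is dropped as dead configuration. A plain-PyTorch sketch of how the two statistics differ (illustrative only, not the Brevitas implementation):

import torch

x = torch.randn(100000)
max_stat = x.abs().max()                     # 'max': spans the full observed range
pct_stat = torch.quantile(x.abs(), 0.99999)  # 'percentile': clips rare outliers
print(max_stat >= pct_stat)                  # True: 'max' never saturates observed values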
src/brevitas_examples/llm/llm_quant/quantize.py (170 changes: 111 additions & 59 deletions)
@@ -2,12 +2,17 @@
 Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 """
+import re

 from torch import nn

 from brevitas import nn as qnn
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.graph.quantize import layerwise_quantize
+from brevitas.quant.experimental.float import Fp8e4m3Act
+from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
 from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
 from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -26,6 +31,7 @@
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
+from brevitas_examples.llm.llm_quant.quantizers import Fp8e4m3WeightSymmetricGroupQuant
 from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat
 from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat
 from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat
@@ -37,62 +43,82 @@
 from brevitas_examples.llm.llm_quant.quantizers import ShiftedUintWeightAsymmetricGroupQuant

 WEIGHT_QUANT_MAP = {
-    'float': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
-            'per_group': {
-                'sym': IntWeightSymmetricGroupQuant, 'asym': ShiftedUintWeightAsymmetricGroupQuant},
-        },
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFloatMSE, 'asym': ShiftedUint8WeightPerChannelFloatMSE},
-        },},
-    'po2': {
-        'stats': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPoint},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPoint},},
-        'mse': {
-            'per_tensor': {
-                'sym': Int8WeightPerTensorFixedPointMSE},
-            'per_channel': {
-                'sym': Int8WeightPerChannelFixedPointMSE},},}}
+    'int': {
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
+                'per_group': {
+                    'sym': IntWeightSymmetricGroupQuant,
+                    'asym': ShiftedUintWeightAsymmetricGroupQuant},},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFloatMSE,
+                    'asym': ShiftedUint8WeightPerTensorFloatMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFloatMSE,
+                    'asym': ShiftedUint8WeightPerChannelFloatMSE},},},
+        'po2_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFixedPoint},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPoint},},
+            'mse': {
+                'per_tensor': {
+                    'sym': Int8WeightPerTensorFixedPointMSE},
+                'per_channel': {
+                    'sym': Int8WeightPerChannelFixedPointMSE},},}},
+    'float': {
+        'float_scale': {
+            'stats': {
+                'per_tensor': {
+                    'sym': Fp8e4m3WeightPerTensorFloat},
+                'per_channel': {
+                    'sym': Fp8e4m3WeightPerChannelFloat},
+                'per_group': {
+                    'sym': Fp8e4m3WeightSymmetricGroupQuant}},}}}

 INPUT_QUANT_MAP = {
-    'static': {
-        'float': {
-            'stats': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
-                'per_row': {
-                    'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},},
-            'mse': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
-                'per_row': {
-                    'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},},
-        'po2': {
-            'stats': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPoint},},
-            'mse': {
-                'per_tensor': {
-                    'sym': Int8ActPerTensorFixedPointMSE},},}},
-    'dynamic': {
-        'float': {
-            'stats': {
-                'per_tensor': {
-                    'sym': Int8ActDynamicPerTensorFloat},
-                'per_row': {
-                    'sym': Int8ActDynamicPerRowFloat},
-                'per_group': {
-                    'sym': Int8ActDynamicPerGroupFloat},}}}}
+    'int': {
+        'static': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
+                    'per_row': {
+                        'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
+                    'per_row': {
+                        'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},},
+            'po2_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPoint},},
+                'mse': {
+                    'per_tensor': {
+                        'sym': Int8ActPerTensorFixedPointMSE},},}},
+        'dynamic': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Int8ActDynamicPerTensorFloat},
+                    'per_row': {
+                        'sym': Int8ActDynamicPerRowFloat},
+                    'per_group': {
+                        'sym': Int8ActDynamicPerGroupFloat},}}}},
+    'float': {
+        'static': {
+            'float_scale': {
+                'stats': {
+                    'per_tensor': {
+                        'sym': Fp8e4m3ActPerTensorFloat},}}},
+        'no_scale': {
+            'sym': Fp8e4m3Act,}}}


 def quantize_model(
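Both maps are now keyed on the number format first ('int' or 'float'), the scale-precision keys are renamed from 'float'/'po2' to 'float_scale'/'po2_scale', and the fp8 activation path gains a 'no_scale' branch that skips the scale-precision and param-method levels. A few lookups, using only keys present above, to illustrate the new layout:

# Int8 weights, float scale, MSE-tuned, per-channel, symmetric:
WEIGHT_QUANT_MAP['int']['float_scale']['mse']['per_channel']['sym']    # Int8WeightPerChannelFloatMSE
# Fp8 e4m3 weights with per-group scales:
WEIGHT_QUANT_MAP['float']['float_scale']['stats']['per_group']['sym']  # Fp8e4m3WeightSymmetricGroupQuant
# Dynamically scaled int8 activations, per-row:
INPUT_QUANT_MAP['int']['dynamic']['float_scale']['stats']['per_row']['sym']  # Int8ActDynamicPerRowFloat
# Scale-less fp8 activations bypass two levels entirely:
INPUT_QUANT_MAP['float']['no_scale']['sym']                            # Fp8e4m3Act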
@@ -105,7 +131,9 @@
         weight_quant_granularity,
         weight_group_size,
         quantize_weight_zero_point,
+        weight_quant_format='int',
         input_bit_width=None,
+        input_quant_format=None,
         input_scale_precision=None,
         input_scale_type=None,
         input_param_method=None,
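A hypothetical call exercising the two new keyword arguments (the leading positional arguments are not shown in this hunk and are elided here; all values are placeholders):

quantize_model(
    ...,  # model, dtype, and the other leading arguments elided
    weight_quant_granularity='per_group',
    weight_group_size=128,
    quantize_weight_zero_point=False,
    weight_quant_format='e4m3',  # 'int' (default) or an e<x>m<y> minifloat string
    input_bit_width=8,
    input_quant_format='e4m3',
    input_scale_precision='float_scale',
    input_scale_type='static',
    input_param_method='stats')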
@@ -119,18 +147,38 @@
     Replace float layers with quant layers in the target model
     """
     # Retrive base input and weight quantizers
-    weight_quant = WEIGHT_QUANT_MAP[weight_scale_precision][weight_param_method][
-        weight_quant_granularity][weight_quant_type]
-    if input_bit_width is not None:
-        input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][input_param_method][
-            input_quant_granularity][input_quant_type]
+
+    # match against custom float format
+    if re.compile(r'e[1-8]m[1-8]').match(weight_quant_format):
+        weight_float_format = {
+            'exponent_bit_width': int(weight_quant_format[1]),
+            'mantissa_bit_width': int(weight_quant_format[3])}
+        weight_quant_format = 'float'
+    else:
+        weight_float_format = {}
+    if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
+        input_float_format = {
+            'exponent_bit_width': int(input_quant_format[1]),
+            'mantissa_bit_width': int(input_quant_format[3])}
+        input_quant_format = 'float'
+    else:
+        input_float_format = {}
+
+    weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
+        weight_param_method][weight_quant_granularity][weight_quant_type]
+    if input_bit_width is not None and input_scale_type == 'no_scale':
+        input_quant = sym_input_quant = linear_2d_input_quant = INPUT_QUANT_MAP[input_quant_format][
+            input_scale_type][input_quant_type]
+    elif input_bit_width is not None:
+        input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][input_scale_precision][
+            input_param_method][input_quant_granularity][input_quant_type]
         # Some activations in MHA should always be symmetric
-        sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
-            input_param_method][input_quant_granularity]['sym']
+        sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+            input_scale_precision][input_param_method][input_quant_granularity]['sym']
         # Linear layers with 2d input should always be per tensor or per group, as there is no row dimension
         if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row':
-            linear_2d_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
-                input_param_method]['per_tensor'][input_quant_type]
+            linear_2d_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+                input_scale_precision][input_param_method]['per_tensor'][input_quant_type]
         else:
             assert input_quant_granularity == 'per_group'
             linear_2d_input_quant = input_quant
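A worked example of the matching above: re.compile(r'e[1-8]m[1-8]').match anchors at the start of the string, and the two bit widths are read positionally from characters 1 and 3, so 'e4m3' yields 4 exponent and 3 mantissa bits. A standalone sketch (parse_quant_format is a hypothetical helper, not part of this diff):

import re

def parse_quant_format(fmt):
    # Same logic as the diff: 'eXmY' strings select the float path.
    if re.compile(r'e[1-8]m[1-8]').match(fmt):
        return 'float', {'exponent_bit_width': int(fmt[1]), 'mantissa_bit_width': int(fmt[3])}
    return fmt, {}

print(parse_quant_format('e4m3'))  # ('float', {'exponent_bit_width': 4, 'mantissa_bit_width': 3})
print(parse_quant_format('int'))   # ('int', {})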
@@ -145,7 +193,8 @@
             'bit_width': weight_bit_width,
             'narrow_range': False,
             'block_size': weight_group_size,
-            'quantize_zero_point': quantize_weight_zero_point})
+            'quantize_zero_point': quantize_weight_zero_point},
+        **weight_float_format)
     # weight scale is converted to a standalone parameter
     # This is done already by default in the per_group quantizer
     if weight_quant_granularity != 'per_group':
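The trailing **weight_float_format splat is what makes the custom format take effect: on the float path it injects exponent_bit_width and mantissa_bit_width into the retrieved quantizer, overriding the e4m3 defaults. A sketch of the equivalent direct override (assuming the same .let() semantics used throughout this file):

# Hypothetical: weight_quant_format='e5m2' produces
# weight_float_format == {'exponent_bit_width': 5, 'mantissa_bit_width': 2},
# rebinding the e4m3 quantizer to an e5m2 layout.
custom_weight_quant = Fp8e4m3WeightPerTensorFloat.let(
    exponent_bit_width=5, mantissa_bit_width=2)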
@@ -161,7 +210,8 @@
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype})
+                'dtype': dtype,},
+            **input_float_format)
         if input_scale_type == 'static' and input_quant_granularity == 'per_row':
             # QuantMHA internally always uses Seq, B, E
             input_quant = input_quant.let(
@@ -188,7 +238,8 @@
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype})
+                'dtype': dtype},
+            **input_float_format)
         if input_scale_type == 'static' and input_quant_granularity == 'per_row':
             q_scaled_quant = sym_input_quant.let(
                 **{
@@ -241,7 +292,8 @@
             **{
                 'bit_width': input_bit_width,
                 'quantize_zero_point': quantize_input_zero_point,
-                'dtype': dtype})
+                'dtype': dtype},
+            **input_float_format)
         if input_scale_type == 'dynamic':
             # Note: this breaks if applied to 3d Linear inputs,
             # in case standard MHA layers haven't been inserted
@@ -265,7 +317,7 @@
             'in_proj_bias_quant': None,
             'softmax_input_quant': None,
             'attn_output_weights_quant': attn_output_weights_quant,
-            'attn_output_weights_signed': False,
+            'attn_output_weights_signed': input_quant_format == 'float',
             'q_scaled_quant': q_scaled_quant,
             'k_transposed_quant': k_transposed_quant,
             'v_quant': v_quant,
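One subtlety in the last hunk: softmax probabilities are non-negative, so the int path can keep the attention output weights unsigned, but a minifloat format always spends a bit on sign, hence attn_output_weights_signed is now derived from the format. A quick check of the e4m3 value layout (plain Python; an OCP-style exponent bias of 7 is assumed):

def e4m3_value(sign_bit, exponent_field, mantissa_field, bias=7):
    # value = (-1)^s * 2^(e - bias) * (1 + m / 2^3) for normal numbers
    return (-1) ** sign_bit * 2 ** (exponent_field - bias) * (1 + mantissa_field / 8)

print(e4m3_value(0, 8, 4))  # 3.0
print(e4m3_value(1, 8, 4))  # -3.0: the sign bit is there whether or not it is needed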
src/brevitas_examples/llm/llm_quant/quantizers.py (27 changes: 21 additions & 6 deletions)
@@ -14,8 +14,10 @@
 from brevitas.core.stats import NegativePercentileOrZero
 from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
+from brevitas.inject import ExtendedInjector
 from brevitas.inject import this
 from brevitas.inject import value
+from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE
 from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
@@ -25,11 +27,7 @@
 from .quant_blocks import *


-class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat):
-    """
-    Block / group / vector signed symmetric weight quantizer with float scales.
-    We inherit from a per-channel quantizer to re-use some underlying machinery.
-    """
+class WeightSymmetricGroupQuantMixin(ExtendedInjector):

     @value
     def expanded_scaling_shape(module, block_size):
@@ -69,6 +67,23 @@ def reshaped_scaling_shape(module):
     block_size = None


+class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat):
+    """
+    Block / group / vector signed symmetric int weight quantizer with float scales.
+    We inherit from a per-channel quantizer to re-use some underlying machinery.
+    """
+    pass
+
+
+class Fp8e4m3WeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin,
+                                       Fp8e4m3WeightPerChannelFloat):
+    """
+    Block / group / vector signed symmetric e4m3 weight quantizer with float scales.
+    We inherit from a per-channel quantizer to re-use some underlying machinery.
+    """
+    pass
+
+
 class ShiftedUintWeightAsymmetricGroupQuant(IntWeightSymmetricGroupQuant):
     """
     Block / group / vector signed asymmetric weight quantizer with float scales and zero-points.
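The point of the refactor: the group-quant machinery (expanded/reshaped scaling shapes, block_size plumbing) now lives in WeightSymmetricGroupQuantMixin and is composed with either the int8 or the fp8 per-channel base. A self-contained sketch of the composition pattern with plain-Python stand-ins (not the Brevitas classes):

class GroupQuantMixin:
    block_size = None  # shared group-quant plumbing lives here

class Int8PerChannelBase:
    fmt = 'int8'

class Fp8PerChannelBase:
    fmt = 'e4m3'

class IntGroupQuant(GroupQuantMixin, Int8PerChannelBase):
    pass

class Fp8GroupQuant(GroupQuantMixin, Fp8PerChannelBase):
    pass

print(Fp8GroupQuant.fmt)  # 'e4m3': format from the base, machinery from the mixin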
@@ -125,7 +140,7 @@ class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):

 class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat):
     """
-    Symmetric quantizer with per row dynamic scale.
+    Symmetric quantizer with per group scale.
     """
     scaling_impl = RuntimeDynamicGroupStatsScaling
     keepdim = True
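For intuition on the corrected docstring: dynamic per-group scaling recomputes one scale per group of the last dimension on every forward pass, and keepdim=True keeps the reduced axis so the scale broadcasts back over its group. A plain-PyTorch sketch (illustrative, not RuntimeDynamicGroupStatsScaling itself):

import torch

x = torch.randn(4, 4096)
group_size = 64
xg = x.view(4, 4096 // group_size, group_size)
# One scale per group, recomputed at runtime from the current tensor;
# clamp avoids division by zero for all-zero groups.
scale = xg.abs().max(dim=-1, keepdim=True).values.clamp_min(1e-8)
x_scaled = xg / scale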