Feat (examples/llm): add custom float support
volcacius committed Sep 21, 2023
1 parent c6b86cc commit 799349e
Showing 3 changed files with 167 additions and 70 deletions.
170 changes: 111 additions & 59 deletions src/brevitas_examples/llm/llm_quant/quantize.py
@@ -2,12 +2,17 @@
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
import re

from torch import nn

from brevitas import nn as qnn
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.graph.quantize import layerwise_quantize
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -26,6 +31,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.llm.llm_quant.quantizers import Fp8e4m3WeightSymmetricGroupQuant
from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat
from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat
from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat
@@ -37,62 +43,82 @@
from brevitas_examples.llm.llm_quant.quantizers import ShiftedUintWeightAsymmetricGroupQuant

# Removed in this commit:
WEIGHT_QUANT_MAP = {
    'float': {
        'stats': {
            'per_tensor': {
                'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
            'per_channel': {
                'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
            'per_group': {
                'sym': IntWeightSymmetricGroupQuant, 'asym': ShiftedUintWeightAsymmetricGroupQuant}},
        'mse': {
            'per_tensor': {
                'sym': Int8WeightPerTensorFloatMSE, 'asym': ShiftedUint8WeightPerTensorFloatMSE},
            'per_channel': {
                'sym': Int8WeightPerChannelFloatMSE, 'asym': ShiftedUint8WeightPerChannelFloatMSE}}},
    'po2': {
        'stats': {
            'per_tensor': {
                'sym': Int8WeightPerTensorFixedPoint},
            'per_channel': {
                'sym': Int8WeightPerChannelFixedPoint}},
        'mse': {
            'per_tensor': {
                'sym': Int8WeightPerTensorFixedPointMSE},
            'per_channel': {
                'sym': Int8WeightPerChannelFixedPointMSE}}}}

INPUT_QUANT_MAP = {
    'static': {
        'float': {
            'stats': {
                'per_tensor': {
                    'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
                'per_row': {
                    'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat}},
            'mse': {
                'per_tensor': {
                    'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
                'per_row': {
                    'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE}}},
        'po2': {
            'stats': {
                'per_tensor': {
                    'sym': Int8ActPerTensorFixedPoint}},
            'mse': {
                'per_tensor': {
                    'sym': Int8ActPerTensorFixedPointMSE}}}},
    'dynamic': {
        'float': {
            'stats': {
                'per_tensor': {
                    'sym': Int8ActDynamicPerTensorFloat},
                'per_row': {
                    'sym': Int8ActDynamicPerRowFloat},
                'per_group': {
                    'sym': Int8ActDynamicPerGroupFloat}}}}}

# Added in this commit:
WEIGHT_QUANT_MAP = {
    'int': {
        'float_scale': {
            'stats': {
                'per_tensor': {
                    'sym': Int8WeightPerTensorFloat, 'asym': ShiftedUint8WeightPerTensorFloat},
                'per_channel': {
                    'sym': Int8WeightPerChannelFloat, 'asym': ShiftedUint8WeightPerChannelFloat},
                'per_group': {
                    'sym': IntWeightSymmetricGroupQuant,
                    'asym': ShiftedUintWeightAsymmetricGroupQuant}},
            'mse': {
                'per_tensor': {
                    'sym': Int8WeightPerTensorFloatMSE,
                    'asym': ShiftedUint8WeightPerTensorFloatMSE},
                'per_channel': {
                    'sym': Int8WeightPerChannelFloatMSE,
                    'asym': ShiftedUint8WeightPerChannelFloatMSE}}},
        'po2_scale': {
            'stats': {
                'per_tensor': {
                    'sym': Int8WeightPerTensorFixedPoint},
                'per_channel': {
                    'sym': Int8WeightPerChannelFixedPoint}},
            'mse': {
                'per_tensor': {
                    'sym': Int8WeightPerTensorFixedPointMSE},
                'per_channel': {
                    'sym': Int8WeightPerChannelFixedPointMSE}}}},
    'float': {
        'float_scale': {
            'stats': {
                'per_tensor': {
                    'sym': Fp8e4m3WeightPerTensorFloat},
                'per_channel': {
                    'sym': Fp8e4m3WeightPerChannelFloat},
                'per_group': {
                    'sym': Fp8e4m3WeightSymmetricGroupQuant}}}}}

INPUT_QUANT_MAP = {
    'int': {
        'static': {
            'float_scale': {
                'stats': {
                    'per_tensor': {
                        'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
                    'per_row': {
                        'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat}},
                'mse': {
                    'per_tensor': {
                        'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
                    'per_row': {
                        'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE}}},
            'po2_scale': {
                'stats': {
                    'per_tensor': {
                        'sym': Int8ActPerTensorFixedPoint}},
                'mse': {
                    'per_tensor': {
                        'sym': Int8ActPerTensorFixedPointMSE}}}},
        'dynamic': {
            'float_scale': {
                'stats': {
                    'per_tensor': {
                        'sym': Int8ActDynamicPerTensorFloat},
                    'per_row': {
                        'sym': Int8ActDynamicPerRowFloat},
                    'per_group': {
                        'sym': Int8ActDynamicPerGroupFloat}}}}},
    'float': {
        'static': {
            'float_scale': {
                'stats': {
                    'per_tensor': {
                        'sym': Fp8e4m3ActPerTensorFloat}}}},
        'no_scale': {
            'sym': Fp8e4m3Act}}}

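A rough sketch of how the added maps above get indexed inside quantize_model below, using illustrative key values (the keys mirror the function's parameters; the resulting classes are the ones listed in the added maps):

# Illustrative lookups only; keys follow the added maps above.
weight_quant = WEIGHT_QUANT_MAP['int']['float_scale']['stats']['per_group']['sym']
# -> IntWeightSymmetricGroupQuant
input_quant = INPUT_QUANT_MAP['float']['static']['float_scale']['stats']['per_tensor']['sym']
# -> Fp8e4m3ActPerTensorFloat
no_scale_quant = INPUT_QUANT_MAP['float']['no_scale']['sym']
# -> Fp8e4m3Act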

def quantize_model(
@@ -105,7 +131,9 @@ def quantize_model(
weight_quant_granularity,
weight_group_size,
quantize_weight_zero_point,
weight_quant_format='int',
input_bit_width=None,
input_quant_format=None,
input_scale_precision=None,
input_scale_type=None,
input_param_method=None,
@@ -119,18 +147,38 @@
Replace float layers with quant layers in the target model
"""
# Retrieve base input and weight quantizers
weight_quant = WEIGHT_QUANT_MAP[weight_scale_precision][weight_param_method][
weight_quant_granularity][weight_quant_type]
if input_bit_width is not None:
input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][input_param_method][
input_quant_granularity][input_quant_type]

# match against custom float format
if re.compile(r'e[1-8]m[1-8]').match(weight_quant_format):
weight_float_format = {
'exponent_bit_width': int(weight_quant_format[1]),
'mantissa_bit_width': int(weight_quant_format[3])}
weight_quant_format = 'float'
else:
weight_float_format = {}
if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
input_float_format = {
'exponent_bit_width': int(input_quant_format[1]),
'mantissa_bit_width': int(input_quant_format[3])}
input_quant_format = 'float'
else:
input_float_format = {}
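
The custom format is passed as a short string such as 'e4m3'; the regex above splits it into exponent and mantissa bit widths and then reroutes the lookup to the 'float' branch of the maps. A minimal standalone sketch of the same parsing, where the 'e5m2' value is only an illustration:

import re

fmt = 'e5m2'  # hypothetical CLI value; 'int' or 'e4m3' would also be accepted
if re.compile(r'e[1-8]m[1-8]').match(fmt):
    float_format = {
        'exponent_bit_width': int(fmt[1]),  # 5
        'mantissa_bit_width': int(fmt[3])}  # 2
    fmt = 'float'  # select the 'float' branch of the quantizer maps
else:
    float_format = {}  # integer formats carry no extra keyword arguments
print(fmt, float_format)  # float {'exponent_bit_width': 5, 'mantissa_bit_width': 2}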

weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
weight_param_method][weight_quant_granularity][weight_quant_type]
if input_bit_width is not None and input_scale_type == 'no_scale':
input_quant = sym_input_quant = linear_2d_input_quant = INPUT_QUANT_MAP[input_quant_format][
input_scale_type][input_quant_type]
elif input_bit_width is not None:
input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][input_scale_precision][
input_param_method][input_quant_granularity][input_quant_type]
# Some activations in MHA should always be symmetric
sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
input_param_method][input_quant_granularity]['sym']
sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
input_scale_precision][input_param_method][input_quant_granularity]['sym']
# Linear layers with 2d input should always be per tensor or per group, as there is no row dimension
if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row':
linear_2d_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
input_param_method]['per_tensor'][input_quant_type]
linear_2d_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
input_scale_precision][input_param_method]['per_tensor'][input_quant_type]
else:
assert input_quant_granularity == 'per_group'
linear_2d_input_quant = input_quant
@@ -145,7 +193,8 @@ def quantize_model(
'bit_width': weight_bit_width,
'narrow_range': False,
'block_size': weight_group_size,
'quantize_zero_point': quantize_weight_zero_point})
'quantize_zero_point': quantize_weight_zero_point},
**weight_float_format)
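
Splatting weight_float_format into .let() is what makes the custom format take effect: for a float weight quantizer, the exponent and mantissa widths parsed earlier override the e4m3 defaults of the base class. A hedged sketch of the same idea in isolation, reusing the per-channel base imported at the top of the file; the override values are illustrative:

from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat

# Re-parameterize the e4m3 per-channel quantizer as e5m2 via keyword overrides,
# mirroring how quantize_model applies **weight_float_format above.
CustomFloatWeightQuant = Fp8e4m3WeightPerChannelFloat.let(
    exponent_bit_width=5, mantissa_bit_width=2)
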
# weight scale is converted to a standalone parameter
# This is done already by default in the per_group quantizer
if weight_quant_granularity != 'per_group':
@@ -161,7 +210,8 @@ def quantize_model(
**{
'bit_width': input_bit_width,
'quantize_zero_point': quantize_input_zero_point,
'dtype': dtype})
'dtype': dtype,},
**input_float_format)
if input_scale_type == 'static' and input_quant_granularity == 'per_row':
# QuantMHA internally always uses Seq, B, E
input_quant = input_quant.let(
@@ -188,7 +238,8 @@ def quantize_model(
**{
'bit_width': input_bit_width,
'quantize_zero_point': quantize_input_zero_point,
'dtype': dtype})
'dtype': dtype},
**input_float_format)
if input_scale_type == 'static' and input_quant_granularity == 'per_row':
q_scaled_quant = sym_input_quant.let(
**{
@@ -241,7 +292,8 @@ def quantize_model(
**{
'bit_width': input_bit_width,
'quantize_zero_point': quantize_input_zero_point,
'dtype': dtype})
'dtype': dtype},
**input_float_format)
if input_scale_type == 'dynamic':
# Note: this breaks if applied to 3d Linear inputs,
# in case standard MHA layers haven't been inserted
@@ -265,7 +317,7 @@ def quantize_model(
'in_proj_bias_quant': None,
'softmax_input_quant': None,
'attn_output_weights_quant': attn_output_weights_quant,
'attn_output_weights_signed': False,
'attn_output_weights_signed': input_quant_format == 'float',
'q_scaled_quant': q_scaled_quant,
'k_transposed_quant': k_transposed_quant,
'v_quant': v_quant,
27 changes: 21 additions & 6 deletions src/brevitas_examples/llm/llm_quant/quantizers.py
@@ -14,8 +14,10 @@
from brevitas.core.stats import NegativePercentileOrZero
from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.inject import ExtendedInjector
from brevitas.inject import this
from brevitas.inject import value
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloatMSE
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
@@ -25,11 +27,7 @@
from .quant_blocks import *


class IntWeightSymmetricGroupQuant(Int8WeightPerChannelFloat):
"""
Block / group / vector signed symmetric weight quantizer with float scales.
We inherit from a per-channel quantizer to re-use some underlying machinery.
"""
class WeightSymmetricGroupQuantMixin(ExtendedInjector):

@value
def expanded_scaling_shape(module, block_size):
@@ -69,6 +67,23 @@ def reshaped_scaling_shape(module):
block_size = None


class IntWeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin, Int8WeightPerChannelFloat):
"""
Block / group / vector signed symmetric int weight quantizer with float scales.
We inherit from a per-channel quantizer to re-use some underlying machinery.
"""
pass


class Fp8e4m3WeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin,
Fp8e4m3WeightPerChannelFloat):
"""
Block / group / vector signed symmetric e4m3 weight quantizer with float scales.
We inherit from a per-channel quantizer to re-use some underlying machinery.
"""
pass
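
Factoring the group-wise scaling machinery into WeightSymmetricGroupQuantMixin means a new numeric format only needs to be mixed into its matching per-channel base, as the two classes above do. A hypothetical further variant would follow the same pattern; the e5m2 per-channel base class named here is assumed, not something this file provides:

class Fp8e5m2WeightSymmetricGroupQuant(WeightSymmetricGroupQuantMixin,
                                       Fp8e5m2WeightPerChannelFloat):
    """
    Block / group / vector signed symmetric e5m2 weight quantizer with float scales.
    """
    pass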


class ShiftedUintWeightAsymmetricGroupQuant(IntWeightSymmetricGroupQuant):
"""
Block / group / vector signed asymmetric weight quantizer with float scales and zero-points.
@@ -125,7 +140,7 @@ class Int8ActDynamicPerRowFloat(Int8ActPerRowFloat):

class Int8ActDynamicPerGroupFloat(Int8ActPerRowFloat):
"""
Symmetric quantizer with per row dynamic scale.
Symmetric quantizer with per group scale.
"""
scaling_impl = RuntimeDynamicGroupStatsScaling
keepdim = True
