Commit

Feat (quant): initial support for fp8 variants
volcacius committed Jul 26, 2023
1 parent 4d0852d commit 96e0fb5
Showing 6 changed files with 312 additions and 0 deletions.
85 changes: 85 additions & 0 deletions src/brevitas/core/quant/float.py
@@ -0,0 +1,85 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
import torch.nn as nn

import brevitas
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.scaling import ConstScaling
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_float
from brevitas.function.ops_ste import floor_ste


class FloatQuant(brevitas.jit.ScriptModule):
    __constants__ = ['signed']

    def __init__(
            self,
            bit_width: int,
            signed: bool,
            exponent_bit_width: int,
            mantissa_bit_width: int,
            exponent_bias: Optional[int] = None,
            scaling_impl: Optional[nn.Module] = None,
            float_scaling_impl: Optional[nn.Module] = None,
            float_to_int_impl: nn.Module = RoundSte(),
            device: Optional[str] = None,
            dtype: Optional[torch.dtype] = None):
        super(FloatQuant, self).__init__()
        if bit_width != exponent_bit_width + mantissa_bit_width + int(signed):
            raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.")
        if exponent_bias is None:
            # Assumed default when no bias is given: the standard bias for the exponent
            # width, 2 ** (e - 1) - 1, so the float() conversion below cannot receive None
            exponent_bias = 2 ** (exponent_bit_width - 1) - 1
        self.bit_width = StatelessBuffer(torch.tensor(float(bit_width), device=device, dtype=dtype))
        self.signed: bool = signed
        self.float_to_int_impl = float_to_int_impl
        self.exponent_bit_width = StatelessBuffer(
            torch.tensor(float(exponent_bit_width), device=device, dtype=dtype))
        self.mantissa_bit_width = StatelessBuffer(
            torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype))
        self.exponent_bias = StatelessBuffer(
            torch.tensor(float(exponent_bias), device=device, dtype=dtype))
        self.fp_max_val = StatelessBuffer(
            max_float(self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()))
        self.fp_internal_scale_min = StatelessBuffer(
            1. - self.exponent_bias() - self.mantissa_bit_width())
        if float_scaling_impl is None:
            float_scaling_impl = ConstScaling(1., device=device, dtype=dtype)
        if scaling_impl is None:
            scaling_impl = ConstScaling(1., device=device, dtype=dtype)
        # Zero-point is currently hardcoded to 0
        self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
        self.float_scaling_impl = float_scaling_impl
        self.scaling_impl = scaling_impl

    @brevitas.jit.script_method
    def internal_scale(self, x):
        internal_scale = floor_ste(torch.log2(torch.abs(x))) - self.mantissa_bit_width()
        internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min())
        internal_scale = torch.exp2(internal_scale)
        return internal_scale

    @brevitas.jit.script_method
    def quantize(self, x: torch.Tensor):
        scale = self.scaling_impl(x) / self.float_scaling_impl(x)
        scaled_x = x / scale
        internal_scale = self.internal_scale(scaled_x)
        val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
        if self.signed:
            val_fp_quant = torch.clip(val_fp_quant, -1. * self.fp_max_val(), self.fp_max_val())
        else:
            val_fp_quant = torch.clip(val_fp_quant, 0., self.fp_max_val())
        return val_fp_quant, scale

    @brevitas.jit.script_method
    def dequantize(self, y, scale):
        return y * scale

    @brevitas.jit.script_method
    def forward(self, x):
        y, scale = self.quantize(x)
        y = self.dequantize(y, scale)
        # This is to respect the current interface of proxies
        return y, scale, self.zero_point_impl(), self.bit_width()
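
Taken standalone, the new FloatQuant module is already runnable: quantize divides the input by the gathered scale, snaps each value onto a power-of-two grid whose step size tracks that value's exponent (internal_scale), and clips the result to the format's representable range; forward then returns the fake-quantized tensor along with scale, zero-point and bit-width to match the proxy interface. A minimal usage sketch, assuming the default constant scaling (illustrative only, not part of this commit):

import torch

from brevitas.core.quant.float import FloatQuant

# FP8 E4M3 with the standard bias of 7; scaling defaults to ConstScaling(1.)
quantizer = FloatQuant(
    bit_width=8, signed=True, exponent_bit_width=4, mantissa_bit_width=3, exponent_bias=7)
x = torch.randn(4, 4)
y, scale, zero_point, bit_width = quantizer(x)  # y is the fake-quantized tensor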
32 changes: 32 additions & 0 deletions src/brevitas/core/scaling/float_scaling.py
@@ -0,0 +1,32 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch
from torch import Tensor

import brevitas
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_float


class FloatScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            exponent_bit_width: int,
            mantissa_bit_width: int,
            exponent_bias: int,
            device: Optional[str] = None,
            dtype: Optional[torch.dtype] = None):
        super(FloatScaling, self).__init__()
        exponent_bit_width = torch.tensor(exponent_bit_width, device=device, dtype=dtype)
        mantissa_bit_width = torch.tensor(mantissa_bit_width, device=device, dtype=dtype)
        exponent_bias = torch.tensor(exponent_bias, device=device, dtype=dtype)
        self.max_float_val = StatelessBuffer(
            max_float(exponent_bit_width, mantissa_bit_width, exponent_bias))

    @brevitas.jit.script_method
    def forward(self, input: torch.Tensor) -> Tensor:
        return self.max_float_val()
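
FloatScaling ignores its input and always returns the largest magnitude the float format can represent. In FloatQuant.quantize it sits in the denominator of the scale, so the collected statistics range ends up mapped onto the full FP8 range. A small illustration, assuming the E4M3 parameters defined later in this commit:

import torch

from brevitas.core.scaling.float_scaling import FloatScaling

float_scaling = FloatScaling(exponent_bit_width=4, mantissa_bit_width=3, exponent_bias=7)
print(float_scaling(torch.empty(1)))  # tensor(480.), independent of the input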
14 changes: 14 additions & 0 deletions src/brevitas/function/ops.py
@@ -187,3 +187,17 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
    else:
        value = 0 * bit_width
    return value


@brevitas.jit.script
def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
    max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
    max_mantissa = torch.sum((
        2. ** torch.arange(
            0,
            -1. * mantissa_bit_width - 1.,
            -1.,
            dtype=mantissa_bit_width.dtype,
            device=mantissa_bit_width.device)))
    max_val = max_mantissa * (2 ** max_exponent)
    return max_val
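
For the E4M3 parameters used below (bias 7), this gives max_exponent = 2**4 - 1 - 7 = 8 and max_mantissa = 1 + 1/2 + 1/4 + 1/8 = 1.875, hence a maximum of 1.875 * 2**8 = 480. Note the formula spends the entire exponent and mantissa range on finite values; formats that reserve encodings for NaN/inf, such as OCP E4M3 with a maximum of 448, come out lower. A quick sanity check:

import torch

from brevitas.function.ops import max_float

print(max_float(torch.tensor(4.), torch.tensor(3.), torch.tensor(7.)))   # E4M3: tensor(480.)
print(max_float(torch.tensor(5.), torch.tensor(2.), torch.tensor(15.)))  # E5M2: tensor(114688.)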
Empty file.
126 changes: 126 additions & 0 deletions src/brevitas/quant/experimental/float.py
@@ -0,0 +1,126 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.experimental.float_base import FloatActBase
from brevitas.quant.experimental.float_base import FloatWeightBase
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from brevitas.quant.experimental.float_base import ScaledFloatActBase
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


class Fp8e4m3Weight(Fp8e4m3Mixin, FloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer.
    """
    pass


class Fp8e5m2Weight(Fp8e5m2Mixin, FloatWeightBase):
    """
    FP8 signed E5M2 weight quantizer.
    """
    pass


class Fp8e4m3Act(Fp8e4m3Mixin, FloatActBase):
    """
    FP8 signed E4M3 activation quantizer.
    """
    pass


class Fp8e5m2Act(Fp8e5m2Mixin, FloatActBase):
    """
    FP8 signed E5M2 activation quantizer.
    """
    pass


class Fp8e4m3WeightPerTensorFloat(Fp8e4m3Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2WeightPerTensorFloat(Fp8e5m2Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E5M2 weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3ActPerTensorFloat(Fp8e4m3Mixin, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-tensor static percentile-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2ActPerTensorFloat(Fp8e5m2Mixin, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-tensor static percentile-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3WeightPerChannelFloat(Fp8e4m3Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel absmax-based scaling.
    """
    scaling_per_output_channel = True


class Fp8e5m2WeightPerChannelFloat(Fp8e5m2Mixin, ScaledFloatWeightBase):
    """
    FP8 signed E5M2 weight quantizer with per-channel absmax-based scaling.
    """
    scaling_per_output_channel = True


class Fp8e4m3ActPerChannelFloat2d(Fp8e4m3Mixin, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-channel static percentile-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2ActPerChannelFloat2d(Fp8e5m2Mixin, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-channel static percentile-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3ActPerTensorFloatMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-tensor static MSE-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2ActPerTensorFloatMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-tensor static MSE-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3ActPerChannelFloat2dMSE(Fp8e4m3Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E4M3 activation quantizer with per-channel static MSE-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2ActPerChannelFloat2dMSE(Fp8e5m2Mixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E5M2 activation quantizer with per-channel static MSE-based scaling.
    """
    scaling_per_output_channel = True
    scaling_stats_permute_dims = (1, 0, 2, 3)
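
These classes are regular injector-based quantizers, so they attach to Brevitas layers in the usual way. A hedged sketch of wiring FP8 weight and activation quantization into a layer (standard brevitas.nn API; the layer and sizes are illustrative, not part of this diff):

import brevitas.nn as qnn

from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat

# Linear layer with FP8 E4M3 weights and FP8 E4M3 output activations
linear = qnn.QuantLinear(
    in_features=64,
    out_features=64,
    bias=False,
    weight_quant=Fp8e4m3WeightPerTensorFloat,
    output_quant=Fp8e4m3ActPerTensorFloat)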
55 changes: 55 additions & 0 deletions src/brevitas/quant/experimental/float_base.py
@@ -0,0 +1,55 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.function_wrapper import RoundSte
from brevitas.core.quant.float import FloatQuant
from brevitas.core.scaling.float_scaling import FloatScaling
from brevitas.inject import ExtendedInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.quant.solver import ActQuantSolver
from brevitas.quant.solver import WeightQuantSolver


class FloatWeightBase(ExtendedInjector):
    proxy_class = WeightQuantProxyFromInjector
    tensor_quant = FloatQuant
    signed = True
    float_to_int_impl = RoundSte


class FloatActBase(ExtendedInjector):
    proxy_class = ActQuantProxyFromInjector
    tensor_quant = FloatQuant
    signed = True
    float_to_int_impl = RoundSte


class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):
    scaling_stats_op = 'max'
    scaling_impl_type = 'stats'
    restrict_scaling_type = 'fp'
    float_scaling_impl = FloatScaling


class ScaledFloatActBase(FloatActBase, ActQuantSolver):
    scaling_stats_op = 'percentile'
    scaling_impl_type = 'parameter_from_stats'
    restrict_scaling_type = 'fp'
    high_percentile_q = 99.999
    collect_stats_steps = 300
    float_scaling_impl = FloatScaling


class Fp8e4m3Mixin(ExtendedInjector):
    bit_width = 8
    exponent_bit_width = 4
    mantissa_bit_width = 3
    exponent_bias = 7


class Fp8e5m2Mixin(ExtendedInjector):
    bit_width = 8
    exponent_bit_width = 5
    mantissa_bit_width = 2
    exponent_bias = 15
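
The two mixins are the only place the exponent/mantissa split is pinned down, so further FP8 layouts follow the same pattern. A hypothetical E3M4 variant, shown purely to illustrate the extension point (not part of this commit):

from brevitas.inject import ExtendedInjector
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


class Fp8e3m4Mixin(ExtendedInjector):
    bit_width = 8
    exponent_bit_width = 3
    mantissa_bit_width = 4
    exponent_bias = 3  # standard bias: 2 ** (3 - 1) - 1


class Fp8e3m4WeightPerTensorFloat(Fp8e3m4Mixin, ScaledFloatWeightBase):
    scaling_per_output_channel = False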
