Feat (minifloat): support for FNUZ variants
Giuseppe5 committed Jun 18, 2024
1 parent d90f876 commit 335e514
Showing 8 changed files with 261 additions and 41 deletions.
19 changes: 13 additions & 6 deletions src/brevitas/export/common/handler/base.py
@@ -124,13 +124,20 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor):
 class FloatZeroPointHandlerMixin(ABC):
 
     @classmethod
-    def zero_point_with_dtype(cls, exponent_bit_width, mantissa_bit_width, zero_point):
-        if exponent_bit_width == 4 and mantissa_bit_width == 3:
-            return zero_point.type(torch.float8_e4m3fn)
-        elif exponent_bit_width == 5 and mantissa_bit_width == 2:
-            return zero_point.type(torch.float8_e5m2)
+    def zero_point_with_dtype(
+            cls, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zero_point):
+        if is_ocp:
+            if exponent_bit_width == 4 and mantissa_bit_width == 3:
+                return zero_point.type(torch.float8_e4m3fn)
+            elif exponent_bit_width == 5 and mantissa_bit_width == 2:
+                return zero_point.type(torch.float8_e5m2)
+        elif is_fnuz:
+            if exponent_bit_width == 4 and mantissa_bit_width == 3:
+                return zero_point.type(torch.float8_e4m3fnuz)
+            elif exponent_bit_width == 5 and mantissa_bit_width == 2:
+                return zero_point.type(torch.float8_e5m2fnuz)
         else:
-            return zero_point.type(torch.float32)
+            raise NotImplementedError
 
 
 class ZeroPointHandlerMixin(ABC):
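Note: the new dispatch targets the four FP8 dtypes PyTorch ships natively. A minimal standalone sketch of the same mapping (not library code; assumes PyTorch >= 2.1, which introduced the fnuz dtypes):

import torch

# (exponent_bit_width, mantissa_bit_width, family) -> torch dtype
FP8_DTYPES = {
    (4, 3, 'ocp'): torch.float8_e4m3fn,
    (5, 2, 'ocp'): torch.float8_e5m2,
    (4, 3, 'fnuz'): torch.float8_e4m3fnuz,
    (5, 2, 'fnuz'): torch.float8_e5m2fnuz,
}

zero_point = torch.zeros(3)
print(zero_point.type(FP8_DTYPES[(4, 3, 'fnuz')]).dtype)  # torch.float8_e4m3fnuz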
53 changes: 35 additions & 18 deletions src/brevitas/export/common/handler/qcdq.py
@@ -78,15 +78,21 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
         pass
 
     @classmethod
-    def signed_dtype(cls, exponent_bit_width, mantissa_bit_width):
+    def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz):
         if exponent_bit_width is None or mantissa_bit_width is None:
             return None
-        if exponent_bit_width == 4 and mantissa_bit_width == 3:
-            dtype = torch.float8_e4m3fn
-        elif exponent_bit_width == 5 and mantissa_bit_width == 2:
-            dtype = torch.float8_e5m2
+        if is_ocp:
+            if exponent_bit_width == 4 and mantissa_bit_width == 3:
+                dtype = torch.float8_e4m3fn
+            elif exponent_bit_width == 5 and mantissa_bit_width == 2:
+                dtype = torch.float8_e5m2
+        elif is_fnuz:
+            if exponent_bit_width == 4 and mantissa_bit_width == 3:
+                dtype = torch.float8_e4m3fnuz
+            elif exponent_bit_width == 5 and mantissa_bit_width == 2:
+                dtype = torch.float8_e5m2fnuz
         else:
-            dtype = torch.float32
+            raise NotImplementedError
         return dtype
 

@@ -140,7 +146,8 @@ class FloatCDQCastProxyHandlerMixin(QuantAxisMixin,
                                      CDQCastMixin,
                                      ABC):
 
-    def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width):
+    def dequantize_symbolic_kwargs(
+            cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz):
         scale_orig_shape = scale.shape
         axis = cls.quant_axis(scale)
         if cls.flatten_dequantize_params:
@@ -150,7 +157,7 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width):
             zero_point = zero_point.flatten()
         zp = to_0dim_if_scalar(zero_point)
         zp = zp.expand_as(scale)
-        zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp)
+        zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zp)
         return {
             'scale': scale,
             'zero_point': zp,
@@ -187,18 +194,19 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
 class FloatQCDQCastWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin):
     handled_layer = WeightFloatQuantProxyFromInjector
 
-    def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width):
+    def quantize_symbolic_kwargs(
+            cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz):
         # compute axis before redefining scale
         axis = cls.quant_axis(scale)
         scale = to_0dim_if_scalar(scale.flatten())
         zp = to_0dim_if_scalar(zero_point.flatten())
         # expand_as must go after 0-dim check
         zp = zp.expand_as(scale)
-        zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp)
+        zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zp)
         if cls.itemize_quantize_scalar_params:
             scale = to_item_if_0dim(scale)
             zp = to_item_if_0dim(zp)
-        dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width)
+        dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz)
         return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}
 
     def prepare_quantize_from_floating_point(self, module):
@@ -211,7 +219,9 @@ def prepare_quantize_from_floating_point(self, module):
             scale,
             quant_weight.zero_point,
             quant_weight.exponent_bit_width,
-            quant_weight.mantissa_bit_width)
+            quant_weight.mantissa_bit_width,
+            module.is_ocp,
+            module.is_fnuz)
 
     def prepare_quantize_from_minifloat(self, module):
         raise NotImplementedError
@@ -249,7 +259,9 @@ def prepare_for_export(self, module):
                 scale,
                 quant_weight.zero_point,
                 quant_weight.exponent_bit_width,
-                quant_weight.mantissa_bit_width)
+                quant_weight.mantissa_bit_width,
+                module.is_ocp,
+                module.is_fnuz)
         else:
             self.symbolic_kwargs = None
 
@@ -421,18 +433,19 @@ def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed
 class FloatQCDQCastActQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin, ABC):
     handled_layer = ActFloatQuantProxyFromInjector
 
-    def quantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantissa_bit_width):
+    def quantize_symbolic_kwargs(
+            cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz):
         # compute axis before redefining scale
         axis = cls.quant_axis(scale)
         scale = to_0dim_if_scalar(scale.flatten())
         zp = to_0dim_if_scalar(zero_point.flatten())
         # expand_as must go after 0-dim check
         zp = zp.expand_as(scale)
-        zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, zp)
+        zp = cls.zero_point_with_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz, zp)
         if cls.itemize_quantize_scalar_params:
             scale = to_item_if_0dim(scale)
             zp = to_item_if_0dim(zp)
-        dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width)
+        dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz)
         return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}
 
     def prepare_for_export(self, module):
@@ -457,12 +470,16 @@ def prepare_for_export(self, module):
                 scale,
                 module.zero_point(),
                 module.exponent_bit_width(),
-                module.mantissa_bit_width())
+                module.mantissa_bit_width(),
+                module.is_ocp,
+                module.is_fnuz)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
                 scale,
                 module.zero_point(),
                 module.exponent_bit_width(),
-                module.mantissa_bit_width())
+                module.mantissa_bit_width(),
+                module.is_ocp,
+                module.is_fnuz)
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs(
                 module.is_narrow_range,
                 module.is_signed,
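Note: besides threading the two flags through every call site, the fallback behavior changes: a format that is neither OCP nor FNUZ no longer silently maps to torch.float32 but raises. A simplified standalone mirror of signed_dtype (a sketch, not the actual class method):

import torch

def signed_fp8_dtype(exponent_bit_width, mantissa_bit_width, is_ocp, is_fnuz):
    # None widths mean "not a minifloat quantizer".
    if exponent_bit_width is None or mantissa_bit_width is None:
        return None
    family = 'ocp' if is_ocp else 'fnuz' if is_fnuz else None
    if family is None:
        raise NotImplementedError  # no silent float32 fallback anymore
    return {
        (4, 3, 'ocp'): torch.float8_e4m3fn,
        (5, 2, 'ocp'): torch.float8_e5m2,
        (4, 3, 'fnuz'): torch.float8_e4m3fnuz,
        (5, 2, 'fnuz'): torch.float8_e5m2fnuz}[
            (exponent_bit_width, mantissa_bit_width, family)]

print(signed_fp8_dtype(5, 2, is_ocp=False, is_fnuz=True))  # torch.float8_e5m2fnuz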
14 changes: 1 addition & 13 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -53,20 +53,8 @@ def validate(self, module):

 class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC):
 
-    def is_ocp(self, module):
-        is_e4m3 = module.mantissa_bit_width() == 3 and module.exponent_bit_width() == 4
-
-        is_ocp_e4m3 = is_e4m3 and module.inf_values() is None and module.nan_values() == (('111',))
-
-        is_e5m2 = module.mantissa_bit_width() == 5 and module.exponent_bit_width() == 2
-
-        is_ocp_e5m2 = is_e5m2 and module.inf_values() == (
-            ('00',)) and module.nan_values() == ('01', '11', '10')
-
-        return is_ocp_e4m3 or is_ocp_e5m2
-
     def validate(self, module):
-        assert self.is_ocp(module), 'Only OCP Standard is supported for FP8 export'
+        assert module.is_ocp or module.is_fnuz, 'Only OCP/FNUZ standards are supported for FP8 export'
 
 
 class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC):
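Note: the export check is relaxed accordingly; both families have ONNX counterparts (FLOAT8E4M3FN/FLOAT8E5M2 and FLOAT8E4M3FNUZ/FLOAT8E5M2FNUZ, available since opset 19), so either is exportable. The check in isolation (hypothetical helper, not library code):

def can_export_fp8(module) -> bool:
    # Either FP8 family is accepted; anything else still fails the assert.
    return module.is_ocp or module.is_fnuz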
22 changes: 22 additions & 0 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -62,6 +62,28 @@ def nan_values(self):
         nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values
         return nan_values
 
+    @property
+    def is_ocp(self):
+        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4
+        is_ocp_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values() == (('111',))
+
+        is_e5m2 = self.mantissa_bit_width() == 2 and self.exponent_bit_width() == 5
+        is_ocp_e5m2 = is_e5m2 and self.inf_values() == (
+            ('00',)) and self.nan_values() == ('01', '11', '10')
+
+        return is_ocp_e4m3 or is_ocp_e5m2
+
+    @property
+    def is_fnuz(self):
+        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4
+        is_fnuz_e4m3 = (
+            is_e4m3 and self.inf_values() is None and self.nan_values() is None and
+            self.exponent_bias() == 8)
+
+        is_e5m2 = self.mantissa_bit_width() == 2 and self.exponent_bit_width() == 5
+        is_fnuz_e5m2 = (
+            is_e5m2 and self.inf_values() is None and self.nan_values() is None and
+            self.exponent_bias() == 16)
+        return is_fnuz_e4m3 or is_fnuz_e5m2
+
     def forward(self, x: torch.Tensor) -> Union[Tensor, FloatQuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
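Note: the two properties encode the spec-level differences between the families. OCP e4m3 uses bias 7, no inf, and the single NaN mantissa pattern '111'; OCP e5m2 uses bias 15 with IEEE-like inf/NaN codes; the FNUZ variants reserve no inf or NaN codes (only the negative-zero bit pattern is NaN) and shift the bias to 8 and 16 respectively. A standalone sketch of the same classification on plain values (not library code):

def classify_fp8(exp_bits, mant_bits, inf_values, nan_values, exponent_bias):
    # Mirrors is_ocp / is_fnuz above, on plain values instead of a proxy module.
    if (exp_bits, mant_bits) == (4, 3):
        if inf_values is None and nan_values == ('111',):
            return 'ocp'
        if inf_values is None and nan_values is None and exponent_bias == 8:
            return 'fnuz'
    elif (exp_bits, mant_bits) == (5, 2):
        if inf_values == ('00',) and nan_values == ('01', '11', '10'):
            return 'ocp'
        if inf_values is None and nan_values is None and exponent_bias == 16:
            return 'fnuz'
    return 'other'

print(classify_fp8(4, 3, None, None, 8))      # fnuz
print(classify_fp8(4, 3, None, ('111',), 7))  # ocp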
22 changes: 22 additions & 0 deletions src/brevitas/proxy/float_runtime_quant.py
@@ -37,6 +37,28 @@ def inf_values(self, force_eval=True):
     def nan_values(self, force_eval=True):
         return self.retrieve_attribute('nan_values', force_eval)
 
+    @property
+    def is_ocp(self):
+        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4
+        is_ocp_e4m3 = is_e4m3 and self.inf_values() is None and self.nan_values() == (('111',))
+
+        is_e5m2 = self.mantissa_bit_width() == 2 and self.exponent_bit_width() == 5
+        is_ocp_e5m2 = is_e5m2 and self.inf_values() == (
+            ('00',)) and self.nan_values() == ('01', '11', '10')
+
+        return is_ocp_e4m3 or is_ocp_e5m2
+
+    @property
+    def is_fnuz(self):
+        is_e4m3 = self.mantissa_bit_width() == 3 and self.exponent_bit_width() == 4
+        is_fnuz_e4m3 = (
+            is_e4m3 and self.inf_values() is None and self.nan_values() is None and
+            self.exponent_bias() == 8)
+
+        is_e5m2 = self.mantissa_bit_width() == 2 and self.exponent_bit_width() == 5
+        is_fnuz_e5m2 = (
+            is_e5m2 and self.inf_values() is None and self.nan_values() is None and
+            self.exponent_bias() == 16)
+        return is_fnuz_e4m3 or is_fnuz_e5m2
+
     def forward(self, x: Union[Tensor, FloatQuantTensor]) -> Union[Tensor, FloatQuantTensor]:
         out = x
         if self.fused_activation_quant_proxy is not None:
(diffs for the remaining 3 changed files did not load)
