
Commit

Final update hopefully
Giuseppe5 committed Aug 13, 2024
1 parent 5fef0b1 commit 5cc7648
Showing 6 changed files with 81 additions and 103 deletions.
40 changes: 16 additions & 24 deletions src/brevitas/core/stats/stats_op.py
@@ -572,6 +572,7 @@ def __init__(
self,
proxy_module,
hqo_init_op_scale,
keepdim: bool,
inner_stats_input_view_shape_impl: torch.nn.Module,
scaling_min_val: Optional[float] = None,
stats_reduce_dim: Optional[int] = None,
@@ -601,11 +602,11 @@ def __init__(
self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
else:
self.clamp_min_ste = Identity()
self.keepdim = keepdim

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
# candidate = self.input_view_shape_impl(candidate)
best_candidate = candidate
beta = self.beta
with torch.no_grad():
@@ -614,24 +615,23 @@ def parameter_search(self, xl, x):
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
loss = torch.abs(quant_tensor.value_ - x).mean()
loss = torch.abs(quant_tensor.value - x).mean()

best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - quant_tensor.value_, beta, self.lp_norm)
W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm)
zero_point = quant_tensor.zero_point
num = self.input_view_shape_impl(x - W_e).detach()
den = self.input_view_shape_impl(
torch.round(quant_tensor.value_ / quant_tensor.scale_) - zero_point).detach()
torch.round(quant_tensor.value / quant_tensor.scale) - zero_point).detach()
mask = (num != 0.) & (den != 0.)
if self.stats_reduce_dim is None:
candidate = masked_median(num / den, mask)
else:
candidate = masked_median(
num / den, mask, dim=self.stats_reduce_dim, keepdim=True)
# candidate = self.input_view_shape_impl(candidate)
num / den, mask, dim=self.stats_reduce_dim, keepdim=self.keepdim)
candidate = self.clamp_min_ste(candidate)
bit_width = self.msb_clamp_bit_width_impl()
int_threshold = self.int_scaling_impl(bit_width)
@@ -672,19 +672,17 @@ class HalfQuadraticOptimizerZeroPoint(torch.nn.Module):
def __init__(
self,
proxy_module,
keepdim: bool,
hqo_init_op_zp: torch.nn.Module,
inner_stats_input_view_shape_impl: torch.nn.Module,
stats_reduce_dim: Optional[int] = None,
inner_expanded_zero_point_shape=None,
reshaped_zero_point_shape=None,
hqo_beta_zp: float = 1e0,
hqo_kappa_zp: float = 1.01,
hqo_lp_norm_zp: float = .5,
hqo_iters_zp: int = 1000):
super(HalfQuadraticOptimizerZeroPoint, self).__init__()
self.hqo_init_op_zp = hqo_init_op_zp
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_module = proxy_module
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.set_quantize_zero_point = lambda enabled: _set_quantize_zero_point(
@@ -696,47 +694,41 @@ def __init__(
self.kappa = hqo_kappa_zp
self.lp_norm = hqo_lp_norm_zp
self.hqo_iters = hqo_iters_zp
self.inner_expanded_zero_point_shape = inner_expanded_zero_point_shape
self.reshaped_zero_point_shape = reshaped_zero_point_shape
self.keepdim = keepdim

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
candidate = self.input_view_shape_impl(candidate)
best_candidate = candidate
with torch.no_grad():
for i in range(0, self.hqo_iters):
self.internal_candidate = candidate
self.set_local_loss_mode(True)
prev_state = _set_quantize_zero_point(self.proxy_module, False)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
_restore_quantize_zero_point(self.proxy_module, prev_state)
loss = torch.abs(quant_tensor.value - x).mean()
qt_value = self.input_view_shape_impl(quant_tensor.value)
qt_scale = self.input_view_shape_impl(quant_tensor.scale)
qt_int = self.input_view_shape_impl(quant_tensor.int())
loss = torch.abs(qt_value - x).mean()
best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - quant_tensor.value, self.beta, self.lp_norm)
W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)

val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)

val = self.input_view_shape_impl((x - W_e) -
quant_tensor.int() * quant_tensor.scale)
if self.inner_expanded_zero_point_shape is not None:
val = val.reshape(self.inner_expanded_zero_point_shape)
if self.stats_reduce_dim is None:
candidate = torch.mean(val)
else:
candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=True)
candidate = self.input_view_shape_impl(candidate)
candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=self.keepdim)
self.beta *= self.kappa
return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.hqo_init_op_zp(x_view).detach()
if self.reshaped_zero_point_shape is not None:
x = x.reshape(self.reshaped_zero_point_shape)

best_candidate = self.parameter_search(init, x)

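For context, the parameter_search loops above implement half-quadratic optimization (HQO): alternate an L_p shrinkage step on the reconstruction error with a closed-form (masked median) update of the candidate, growing beta by kappa each round and keeping the candidate with the lowest L1 error. The following is a minimal, self-contained sketch of that idea for a scale search; the shrinkage form and the plain round-to-nearest quantizer are illustrative stand-ins, not the Brevitas proxy path used in the diff.

import torch

def shrink_lp(x, beta, p=0.5):
    # Generalized soft-thresholding for an L_p penalty (illustrative form).
    return torch.sign(x) * torch.relu(x.abs() - x.abs().pow(p - 1) / beta)

def hqo_scale_search(w, scale, n_bits=8, iters=100, beta=1.0, kappa=1.01):
    # Illustrative HQO loop for a per-tensor scale.
    qmax = 2 ** (n_bits - 1) - 1
    best_scale, best_loss = scale, float("inf")
    for _ in range(iters):
        q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
        loss = (w - q * scale).abs().mean().item()
        if loss >= best_loss:
            break
        best_loss, best_scale = loss, scale
        w_e = shrink_lp(w - q * scale, beta)        # shrinkage step
        num, den = w - w_e, q
        mask = (num != 0) & (den != 0)
        scale = (num[mask] / den[mask]).median()    # closed-form scale update
        beta *= kappa
    return best_scale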
23 changes: 2 additions & 21 deletions src/brevitas/quant/base.py
@@ -521,42 +521,23 @@ class MSEActZeroPoint(MSEZeroPoint):
class HQOZeroPoint(ExtendedInjector):

hqo_init_op_zp = NegativeMinOrZero
inner_stats_input_view_shape_impl = this.zero_point_stats_input_view_shape_impl
stats_impl_zp = HalfQuadraticOptimizerZeroPoint
zero_point_stats_input_view_shape_impl = nn.Identity()

@value
def zero_point_stats_impl():
return this.stats_impl_zp

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK


class HQOScale(ExtendedInjector):
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
scaling_stats_input_view_shape_impl = nn.Identity()

inner_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
stats_impl_scale = HalfQuadraticOptimizerScale

@value
def scaling_stats_impl():
return this.stats_impl_scale

@value
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK


class HQOAsymmetricScale(HQOScale):
hqo_init_op_scale = AbsMinMax
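The refactor above drops the scaling_per_output-based derivation and instead points inner_stats_input_view_shape_impl at the view impl the quantizer already declares. Below is a minimal sketch of the `this.` indirection it relies on (illustrative, not taken from the diff): the reference is resolved late against the final injector, so a subclass override of the target is picked up automatically.

from torch import nn

from brevitas.inject import ExtendedInjector
from brevitas.inject import this

class ExampleQuant(ExtendedInjector):
    # The inner view impl simply aliases whatever view impl is declared here.
    scaling_stats_input_view_shape_impl = nn.Identity()
    inner_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl

class ExampleQuantOverride(ExampleQuant):
    # The alias follows this override because `this.` is resolved lazily.
    scaling_stats_input_view_shape_impl = nn.Flatten()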
1 change: 1 addition & 0 deletions src/brevitas/quant/shifted_scaled_int.py
@@ -5,6 +5,7 @@
from brevitas.quant.base import HQOActZeroPoint
from brevitas.quant.base import HQOAsymmetricScale
from brevitas.quant.base import HQOZeroPoint
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
from brevitas.quant.solver.trunc import TruncQuantSolver
78 changes: 50 additions & 28 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -3,6 +3,7 @@

import torch

from brevitas.function.ops_ste import round_ste
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.quant_tensor.base_quant_tensor import GroupwisIntQuantTensorBase
from brevitas.quant_tensor.base_quant_tensor import QuantTensor
@@ -101,29 +102,29 @@ def zero_point(self):
new_value, new_scale, new_zp = self.expand()
return new_zp

@property
def _pre_round_float_value(self):
value, scale, zp = self.expand()
if self.scale.dtype == torch.bfloat16:
value = value.type(torch.float32)
scale = scale.type(torch.float32)
minifloat_value = value / scale
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
minifloat_value = minifloat_value / int_scale
return minifloat_value

@property
def is_valid(self):
with torch.no_grad():
pre_round_minifloat_value = self._pre_round_float_value
rounded_minifloat_value = torch.round(pre_round_minifloat_value)
max_abs_diff = torch.max(torch.abs(pre_round_minifloat_value - rounded_minifloat_value))
pre_round_int_value = self._pre_round_int_value
rounded_int_value = torch.round(pre_round_int_value)
max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
is_minifloat = max_abs_diff < atol
# We are missing the checks about self being contained between max and min value
# given by mantissa, exponent, inf, nan, and saturating
return is_minifloat
is_int = max_abs_diff < atol
if self.bit_width >= 2:
if self.signed:
is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all()
else:
is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all()
is_lower_b = (0. <= rounded_int_value).all()
return (is_int & is_upper_b & is_lower_b).item()
else: # binary case
unique_vals = rounded_int_value.unique(
sorted=False, return_counts=False, return_inverse=False)
is_binary = unique_vals.view(-1).size()[0] == 2
is_signed = (unique_vals < 0.).any().item()
sign_match = is_signed == self.signed
return is_int.item() and is_binary and sign_match

@property
def device(self):
@@ -139,17 +140,38 @@ def device(self):
raise RuntimeError("Value and metadata are on different devices")
return value_device

def minifloat(self, float_datatype=True):
# TODO: Check if OCP and cast to proper data-type if matching
assert float_datatype, "Minifloat quant returns only higher precision dtype"

@property
def _pre_round_int_value(self):
value = self.value
scale = self.scale
zero_point = self.zero_point
if self.scale.dtype == torch.bfloat16:
value = self.value.type(torch.float32)
scale = self.scale.type(torch.float32)
zero_point = self.zero_point.type(torch.float32)
int_value = value / scale
int_value = int_value + zero_point
return int_value

def int(self, float_datatype=False):
if self.is_valid:
fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
float_value = torch.round(self._pre_round_float_value) * int_scale
return float_value.type(self.scale.dtype)
int_value = round_ste(self._pre_round_int_value)
if float_datatype:
# Values at 8bit and lower can be represented exactly with float16 and bfloat16
# otherwise (e.g. Int16 bias), we upscale to float32
if self.bit_width <= 8.:
return int_value.type(self.scale.dtype)
else:
return int_value.type(torch.float32)
else:
if self.bit_width <= 8. and self.signed_t.item():
return int_value.to(torch.int8)
elif self.bit_width <= 8. and not self.signed_t.item():
return int_value.to(torch.uint8)
else:
return int_value.to(torch.int32)
else:
raise RuntimeError(f"FloatQuantTensor not valid.")
raise RuntimeError(f"IntQuantTensor not valid.")

@staticmethod
def check_input_type(tensor):
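The new _pre_round_int_value / is_valid pair above checks that value/scale + zero_point lands on a near-integer grid inside the representable signed or unsigned range before int() converts. A quick standalone way to reproduce that bound check (illustrative helper, not part of Brevitas; the tolerance is a placeholder):

import torch

def is_valid_int(value, scale, zero_point, bit_width, signed, atol=1e-1):
    # Mirror of the check above: the dequantized value must map back to a
    # near-integer grid that fits the signed/unsigned bit_width range.
    pre_round = value / scale + zero_point
    rounded = torch.round(pre_round)
    if (pre_round - rounded).abs().max() >= atol:
        return False
    if signed:
        lo, hi = -2.0 ** (bit_width - 1), 2.0 ** (bit_width - 1) - 1
    else:
        lo, hi = 0.0, 2.0 ** bit_width - 1
    return bool(((rounded >= lo) & (rounded <= hi)).all())

# Example: a valid signed int8 tensor and an out-of-range one.
scale = torch.tensor(0.02)
print(is_valid_int(torch.tensor([-128., 127.]) * scale, scale, torch.tensor(0.), 8, signed=True))  # True
print(is_valid_int(torch.tensor([200.]) * scale, scale, torch.tensor(0.), 8, signed=True))         # False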
27 changes: 5 additions & 22 deletions src/brevitas_examples/common/generative/quantizers.py
@@ -12,6 +12,7 @@
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
from brevitas.core.zero_point import StatsFromParameterZeroPoint
from brevitas.inject import ExtendedInjector
from brevitas.inject import this
from brevitas.inject import value
@@ -22,11 +23,13 @@
from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
from brevitas.proxy.groupwise_int_runtime_quant import GroupwiseActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.quant.base import HQOWeightZeroPoint
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat

@@ -63,34 +66,14 @@ class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat):
scaling_per_output_type = ScalingPerOutputType.GROUP


from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO


class ShiftedUintWeightAsymmetricGroupQuantHQO(Int8WeightPerChannelFloatHQO):
class ShiftedUintWeightAsymmetricGroupQuantHQO(HQOWeightZeroPoint,
ShiftedUint8WeightPerChannelFloat):
"""
Block / group / vector signed asymmetric weight quantizer with float scales and zero-points.
"""
proxy_class = GroupwiseWeightQuantProxyFromInjector
scaling_per_output_type = ScalingPerOutputType.GROUP

# zero_point_input_shape = this.scaling_input_shape
# reshaped_zero_point_shape = this.reshaped_scaling_shape
# zero_point_shape = this.scaling_shape
# # inner_expanded_zero_point_shape = this.expanded_scaling_shape
# # expanded_zero_point_shape = this.expanded_scaling_shape
# zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
# zero_point_stats_input_concat_dim = 0
# # zero_point_impl = ExpandReshapeZeroPointWrapper
# zero_point_stats_impl = HalfQuadraticOptimizerZeroPoint
# hqo_init_op_zp = NegativeMinOrZero
# scaling_stats_impl = AbsMinMax
# keepdim = True
# # zero-point is converted to a parameter right away
# zero_point_impl = ParameterFromStatsFromParameterZeroPoint
# quantize_zero_point = False
# signed = False
# inner_stats_input_view_shape_impl = torch.nn.Identity()


class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
"""
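Usage of the rewritten quantizer is unchanged; it is now composed from HQOWeightZeroPoint and ShiftedUint8WeightPerChannelFloat instead of carrying the commented-out ad-hoc overrides. A hedged sketch of how it would typically be attached to a layer (layer sizes and the weight_group_size value are placeholders; passing group size through a weight_-prefixed kwarg is assumed to follow the usual Brevitas injector convention):

from brevitas.nn import QuantLinear

from brevitas_examples.common.generative.quantizers import \
    ShiftedUintWeightAsymmetricGroupQuantHQO

# Placeholder shapes; group_size is assumed to be overridable through the
# weight_-prefixed kwarg like other injector attributes.
layer = QuantLinear(
    in_features=256,
    out_features=256,
    bias=False,
    weight_quant=ShiftedUintWeightAsymmetricGroupQuantHQO,
    weight_group_size=64)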
15 changes: 7 additions & 8 deletions src/brevitas_examples/llm/main.py
@@ -7,9 +7,9 @@
import re

import numpy as np
# from optimum.amd.brevitas.accelerate_utils import offload_model
# from optimum.amd.brevitas.accelerate_utils import remove_hooks
# from optimum.exporters.onnx import onnx_export_from_model
from optimum.amd.brevitas.accelerate_utils import offload_model
from optimum.amd.brevitas.accelerate_utils import remove_hooks
from optimum.exporters.onnx import onnx_export_from_model
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
@@ -302,17 +302,17 @@ def main():
if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
# model = offload_model(model)
model = offload_model(model)
apply_weight_equalization(model)
# remove_hooks(model)
remove_hooks(model)
print("Weight equalization applied.")

if args.act_equalization is not None:
# offload_model(model)
offload_model(model)
print("Apply act equalization (SmoothQuant)...")
apply_act_equalization(model, args.act_equalization, calibration_loader)
print("Act equalization applied.")
# remove_hooks(model)
remove_hooks(model)

if not args.no_quantize:
print("Applying model quantization...")
@@ -339,7 +339,6 @@
quantize_embedding=args.quantize_embedding)
# Tie back first/last layer weights in case they got untied
print("Model quantization applied.")

    # If any equalization has taken place, the embedding layer and the fully connected one are
    # no longer tied, and they need to be treated as standalone, separate layers.
    # In all other cases we can tie them back so as to preserve memory.
