Skip to content

Commit

Permalink
HQO
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Apr 17, 2024
1 parent a106a6d commit 948154e
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 1 deletion.
215 changes: 215 additions & 0 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import brevitas
from brevitas import config
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.function_wrapper.ops_ste import ScalarClampMinSte
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
from brevitas.quant_tensor import _unpack_quant_tensor
Expand Down Expand Up @@ -544,3 +546,216 @@ def forward(self, x):
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
return self.internal_candidate


class HalfQuadraticOptimizerScale(torch.nn.Module):
# References:
# https://mobiusml.github.io/hqq_blog/
# https://github.com/mobiusml/hqq?tab=readme-ov-file

def __init__(
self,
proxy_module,
mse_init_op,
inner_stats_input_view_shape_impl: torch.nn.Module,
scaling_min_val: Optional[float] = None,
stats_reduce_dim: Optional[int] = None,
int_scaling_impl=None,
bit_width_impl=None,
beta: float = 1e6,
kappa: float = 1.07,
lp_norm: float = .7,
hqo_iters: int = 1000):
super(HalfQuadraticOptimizerScale, self).__init__()
self.mse_init_op = mse_init_op
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.hqo_iters = hqo_iters
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False

self.beta = beta
self.kappa = kappa
self.lp_norm = lp_norm

self.int_scaling_impl = int_scaling_impl
self.msb_clamp_bit_width_impl = bit_width_impl
if scaling_min_val is not None and scaling_min_val != 0:
self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
else:
self.clamp_min_ste = Identity()

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
candidate = self.input_view_shape_impl(candidate)
best_candidate = candidate
beta = self.beta
with torch.no_grad():
for i in range(0, self.hqo_iters):
self.internal_candidate = candidate
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
loss = torch.abs(quant_tensor.value - x).mean()

best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm)
zero_point = quant_tensor.zero_point
num = self.input_view_shape_impl(x - W_e).detach()
den = self.input_view_shape_impl(quant_tensor.int() - zero_point).detach()
mask = (num != 0.) & (den != 0.)
if self.stats_reduce_dim is None:
candidate = masked_median(num / den, mask)
else:
candidate = masked_median(num / den, mask, dim=self.stats_reduce_dim)
candidate = self.input_view_shape_impl(candidate)
candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)]
candidate = self.clamp_min_ste(candidate)
bit_width = self.msb_clamp_bit_width_impl()
int_threshold = self.int_scaling_impl(bit_width)
candidate = candidate * int_threshold
beta *= self.kappa

return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.mse_init_op(x_view).detach()
best_candidate = self.parameter_search(init, x)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
self.internal_candidate = best_candidate.detach()
torch.cuda.empty_cache()
return best_candidate

def forward(self, x):
if not self.local_loss_mode:
with torch.no_grad():
return self.optimize(x)
else:
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
return self.internal_candidate


class HalfQuadraticOptimizerZeroPoint(torch.nn.Module):
# References:
# https://mobiusml.github.io/hqq_blog/
# https://github.com/mobiusml/hqq?tab=readme-ov-file

def __init__(
self,
proxy_module,
mse_init_op: torch.nn.Module,
inner_stats_input_view_shape_impl: torch.nn.Module,
stats_reduce_dim: Optional[int] = None,
beta: float = 1.,
kappa: float = 1.01,
lp_norm: float = 1.,
hqo_iters: int = 1000):
super(HalfQuadraticOptimizerZeroPoint, self).__init__()
self.mse_init_op = mse_init_op
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False
self.beta = beta
self.kappa = kappa
self.lp_norm = lp_norm
self.hqo_iters = hqo_iters

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
candidate = self.input_view_shape_impl(candidate)
best_candidate = candidate
with torch.no_grad():
for i in range(0, self.hqo_iters):
self.internal_candidate = candidate
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
loss = torch.abs(quant_tensor.value - x).mean()
best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - quant_tensor.value, self.beta, self.lp_norm)

val = self.input_view_shape_impl((x - W_e) -
quant_tensor.int() * quant_tensor.scale)
if self.stats_reduce_dim is None:
candidate = torch.mean(val)
else:
candidate = torch.mean(val, dim=self.stats_reduce_dim)
candidate = self.input_view_shape_impl(candidate)
self.beta *= self.kappa
return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.mse_init_op(x_view).detach()
best_candidate = self.parameter_search(init, x)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
self.internal_candidate = best_candidate.detach()
torch.cuda.empty_cache()
return best_candidate

def forward(self, x):
if not self.local_loss_mode:
with torch.no_grad():
return self.optimize(x)
else:
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
return self.internal_candidate


def masked_median(x, mask, dim=None):
"""Compute the median of tensor x along dim, ignoring values where mask is False.
x and mask need to be broadcastable.
Args:
x (Tensor): Tensor to compute median of.
mask (BoolTensor): Same shape as x with True where x is valid and False
where x should be masked. Mask should not be all False in any column of
dimension dim to avoid NaNs from zero division.
dim (int, optional): Dimension to take median of. Defaults to 0.
Returns:
Tensor: Same shape as x, except dimension dim reduced.
"""
# uncomment this assert for safety but might impact performance
# assert (
# mask.sum(dim=dim).ne(0).all()
# ), "mask should not be all False in any column, causes zero division"
x_nan = x.float().masked_fill(~mask, float("nan"))
if dim is None:
x_median = x_nan.nanmedian()
else:
x_median, _ = x_nan.nanmedian(dim=dim)
return x_median


# Shrinking operator
def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
if lp_norm == 1:
return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
else:
return torch.sign(x) * torch.nn.functional.relu(
torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1))
52 changes: 52 additions & 0 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from brevitas.core.stats import MSE
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.stats import NegativePercentileOrZero
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerScale
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
from brevitas.core.utils import SingleArgStatelessBuffer
from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
Expand Down Expand Up @@ -502,3 +504,53 @@ class MSEWeightZeroPoint(MSEZeroPoint):

class MSEActZeroPoint(MSEZeroPoint):
zero_point_impl = ParameterFromRuntimeZeroPoint


class HQOZeroPoint(ExtendedInjector):
"""
We leverage a sub-injector to avoid a name clash between scale and zero-point.
"""

zero_point_impl = ParameterFromStatsFromParameterZeroPoint

# per_channel = this.scaling_per_output_channel
# proxy_module = this.proxy_module
mse_init_op = NegativeMinOrZero
stats_impl_lol = HalfQuadraticOptimizerZeroPoint
# stats_reduce_dim = this.stats_reduce_dim
zero_point_stats_input_view_shape_impl = nn.Identity()

@value
def zero_point_stats_impl():
return this.stats_impl_lol

@value
def inner_stats_input_view_shape_impl(scaling_per_output_channel):
if scaling_per_output_channel:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
else:
return StatsInputViewShapeImpl.OVER_TENSOR


class HQOSymmetricScale(ExtendedInjector):
"""
We leverage a sub-injector to avoid a name clash between scale and zero-point.
"""

scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
scaling_stats_input_view_shape_impl = nn.Identity()
# per_channel = this.scaling_per_output_channel
# proxy_module = this.proxy_module
mse_init_op = AbsMax
stats_impl = HalfQuadraticOptimizerScale
# stats_reduce_dim = this.stats_reduce_dim
@value
def scaling_stats_impl():
return this.stats_impl

@value
def inner_stats_input_view_shape_impl(scaling_per_output_channel):
if scaling_per_output_channel:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
else:
return StatsInputViewShapeImpl.OVER_TENSOR
25 changes: 25 additions & 0 deletions src/brevitas/quant/scaled_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from brevitas.core.function_wrapper import TensorClamp
from brevitas.quant.base import *
from brevitas.quant.base import HQOSymmetricScale
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
from brevitas.quant.solver.trunc import TruncQuantSolver
Expand Down Expand Up @@ -443,3 +444,27 @@ class Int8AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareZeroCenterWeight
>>> conv.quant_weight()
"""
bit_width = 8


class Int8WeightPerTensorFloatHQO(HQOSymmetricScale, Int8WeightPerTensorFloat):
"""
8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed
from HQO local loss.
Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerTensorFloatHQO)
"""
pass


class Int8WeightPerChannelFloatHQO(HQOSymmetricScale, Int8WeightPerChannelFloat):
"""
8-bit narrow per-tensor signed int weight quantizer with per-tensor floating-point scale factor computed
from HQO local loss.
Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=Int8WeightPerChannelFloatHQO)
"""
pass
25 changes: 25 additions & 0 deletions src/brevitas/quant/shifted_scaled_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.quant.base import *
from brevitas.quant.base import HQOZeroPoint
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
from brevitas.quant.solver.trunc import TruncQuantSolver
Expand Down Expand Up @@ -138,3 +139,27 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale,
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloat)
"""
pass


class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat):
"""
8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer
zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses.
Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerTensorFloatHQO)
"""
quantize_zero_point = False


class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat):
"""
8-bit per-tensor unsigned int weight quantizer with floating-point per-channel scale factor and integer
zero point. Both zero-point and scale factors are learned parameters initialized from HQO local losses.
Examples:
>>> from brevitas.nn import QuantLinear
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO)
"""
quantize_zero_point = False
Loading

0 comments on commit 948154e

Please sign in to comment.