Commit

update
Giuseppe5 committed Apr 24, 2024
1 parent 948154e commit bbee578
Showing 10 changed files with 167 additions and 79 deletions.
67 changes: 38 additions & 29 deletions src/brevitas/core/stats/stats_op.py
@@ -556,29 +556,29 @@ class HalfQuadraticOptimizerScale(torch.nn.Module):
def __init__(
self,
proxy_module,
mse_init_op,
hqo_init_op_scale,
inner_stats_input_view_shape_impl: torch.nn.Module,
scaling_min_val: Optional[float] = None,
stats_reduce_dim: Optional[int] = None,
int_scaling_impl=None,
bit_width_impl=None,
beta: float = 1e6,
kappa: float = 1.07,
lp_norm: float = .7,
hqo_iters: int = 1000):
hqo_beta_scale: float = 1e5,
hqo_kappa_scale: float = 1.01,
hqo_lp_norm_scale: float = .7,
hqo_iters_scale: int = 1000):
super(HalfQuadraticOptimizerScale, self).__init__()
self.mse_init_op = mse_init_op
self.hqo_init_op = hqo_init_op_scale
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.hqo_iters = hqo_iters
self.hqo_iters = hqo_iters_scale
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False

self.beta = beta
self.kappa = kappa
self.lp_norm = lp_norm
self.beta = hqo_beta_scale
self.kappa = hqo_kappa_scale
self.lp_norm = hqo_lp_norm_scale

self.int_scaling_impl = int_scaling_impl
self.msb_clamp_bit_width_impl = bit_width_impl
@@ -615,19 +615,19 @@ def parameter_search(self, xl, x):
else:
candidate = masked_median(num / den, mask, dim=self.stats_reduce_dim)
candidate = self.input_view_shape_impl(candidate)
candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)]
candidate = self.clamp_min_ste(candidate)
bit_width = self.msb_clamp_bit_width_impl()
int_threshold = self.int_scaling_impl(bit_width)
candidate = candidate * int_threshold
candidate[torch.isnan(candidate)] = self.internal_candidate[torch.isnan(candidate)]
candidate[torch.isinf(candidate)] = self.internal_candidate[torch.isinf(candidate)]
beta *= self.kappa

return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.mse_init_op(x_view).detach()
init = self.hqo_init_op(x_view).detach()
best_candidate = self.parameter_search(init, x)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
@@ -643,7 +643,7 @@ def forward(self, x):
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
self.internal_candidate = self.hqo_init_op(x).detach()
return self.internal_candidate


@@ -655,25 +655,29 @@ class HalfQuadraticOptimizerZeroPoint(torch.nn.Module):
def __init__(
self,
proxy_module,
mse_init_op: torch.nn.Module,
hqo_init_op_zp: torch.nn.Module,
inner_stats_input_view_shape_impl: torch.nn.Module,
stats_reduce_dim: Optional[int] = None,
beta: float = 1.,
kappa: float = 1.01,
lp_norm: float = 1.,
hqo_iters: int = 1000):
inner_expanded_zero_point_shape=None,
reshaped_zero_point_shape=None,
hqo_beta_zp: float = 1e0,
hqo_kappa_zp: float = 1.01,
hqo_lp_norm_zp: float = .7,
hqo_iters_zp: int = 1000):
super(HalfQuadraticOptimizerZeroPoint, self).__init__()
self.mse_init_op = mse_init_op
self.hqo_init_op_zp = hqo_init_op_zp
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.internal_candidate = None
self.stats_reduce_dim = stats_reduce_dim
self.local_loss_mode: bool = False
self.beta = beta
self.kappa = kappa
self.lp_norm = lp_norm
self.hqo_iters = hqo_iters
self.beta = hqo_beta_zp
self.kappa = hqo_kappa_zp
self.lp_norm = hqo_lp_norm_zp
self.hqo_iters = hqo_iters_zp
self.inner_expanded_zero_point_shape = inner_expanded_zero_point_shape
self.reshaped_zero_point_shape = reshaped_zero_point_shape

def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
@@ -695,18 +699,23 @@ def parameter_search(self, xl, x):

val = self.input_view_shape_impl((x - W_e) -
quant_tensor.int() * quant_tensor.scale)
if self.inner_expanded_zero_point_shape is not None:
val = val.reshape(self.inner_expanded_zero_point_shape)
if self.stats_reduce_dim is None:
candidate = torch.mean(val)
else:
candidate = torch.mean(val, dim=self.stats_reduce_dim)
candidate = torch.mean(val, dim=self.stats_reduce_dim, keepdim=True)
candidate = self.input_view_shape_impl(candidate)
self.beta *= self.kappa
return best_candidate

def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.mse_init_op(x_view).detach()
init = self.hqo_init_op_zp(x_view).detach()
if self.reshaped_zero_point_shape is not None:
x = x.reshape(self.reshaped_zero_point_shape)

best_candidate = self.parameter_search(init, x)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
@@ -722,11 +731,11 @@ def forward(self, x):
# This is invoked for the zero-point whenever scale is being optimized first
if self.internal_candidate is None:
x = self.input_view_shape_impl(x)
self.internal_candidate = self.mse_init_op(x).detach()
self.internal_candidate = self.hqo_init_op_zp(x).detach()
return self.internal_candidate
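
Both optimizer classes above implement half-quadratic optimization (HQO): each iteration shrinks the quantization residual with a generalized soft-thresholding step (the W_e tensor, computed in lines elided from this hunk), refits the candidate from statistics of the result (a masked median for the scale, a mean for the zero point), and anneals beta by kappa. Below is a minimal per-tensor sketch of the zero-point loop; shrink_lp, hqo_zero_point and the inline affine quantizer are illustrative stand-ins, not Brevitas APIs.

import torch


def shrink_lp(u: torch.Tensor, beta: float, p: float) -> torch.Tensor:
    # Generalized soft-thresholding: proximal step of an l_p penalty with p < 1.
    return torch.sign(u) * torch.relu(torch.abs(u) - torch.abs(u).pow(p - 1) / beta)


def hqo_zero_point(x: torch.Tensor,
                   scale: float,
                   n_bits: int = 8,
                   iters: int = 200,
                   beta: float = 1.0,
                   kappa: float = 1.01,
                   p: float = 0.7) -> torch.Tensor:
    # Toy per-tensor loop: shrink the quantization residual, refit the (float)
    # zero point from the mean, then anneal beta, mirroring parameter_search above.
    zp = (-x.min()).clamp(min=0.0)  # initial offset so the most negative input maps near zero
    for _ in range(iters):
        q_int = torch.clamp(torch.round((x + zp) / scale), 0, 2 ** n_bits - 1)
        dequant = q_int * scale - zp
        w_e = shrink_lp(x - dequant, beta, p)       # residual absorbed by the l_p penalty
        zp = torch.mean(q_int * scale - (x - w_e))  # refit so dequant tracks x - w_e
        beta *= kappa
    return zp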


def masked_median(x, mask, dim=None):
def masked_median(x, mask, dim=None, keepdim=False):
"""Compute the median of tensor x along dim, ignoring values where mask is False.
x and mask need to be broadcastable.
@@ -748,7 +757,7 @@ def masked_median(x, mask, dim=None):
if dim is None:
x_median = x_nan.nanmedian()
else:
x_median, _ = x_nan.nanmedian(dim=dim)
x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)
return x_median
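
A quick usage sketch of masked_median with the new keepdim flag, assuming (as the x_nan name and the docstring suggest) that the elided lines fill masked-out entries with NaN:

import torch

from brevitas.core.stats.stats_op import masked_median

x = torch.tensor([[1.0, 5.0, 3.0], [2.0, 8.0, 4.0]])
mask = torch.tensor([[True, False, True], [True, True, False]])
med = masked_median(x, mask, dim=1, keepdim=True)
# tensor([[1.], [2.]]): per-row median over unmasked values only
# (nanmedian picks the lower of the two middle values for even counts),
# with the reduced dimension kept.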


2 changes: 1 addition & 1 deletion src/brevitas/core/zero_point.py
@@ -281,7 +281,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
output_dict = super(ParameterFromStatsFromParameterZeroPoint, self).state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars)
# Avoid saving the init value
if not self.init_done:
if not self.init_done and not config._FULL_STATE_DICT:
del output_dict[prefix + 'value']
return output_dict

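The new guard keeps the zero-point init value in the state dict when full state-dict export is requested. A hedged sketch of the effect (zp_module stands for any ParameterFromStatsFromParameterZeroPoint instance whose init has not yet run; the flag is assumed to default to False):

import brevitas.config as config

config._FULL_STATE_DICT = True
sd = zp_module.state_dict()      # 'value' entry is retained even though init_done is False
config._FULL_STATE_DICT = False  # default behaviour: 'value' is dropped until init_done
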
1 change: 0 additions & 1 deletion src/brevitas/graph/gptq.py
@@ -256,7 +256,6 @@ def single_layer_update(self, percdamp=.01):
return
finally:
del self.H

for i1 in range(0, self.columns, self.blocksize):
i2 = min(i1 + self.blocksize, self.columns)
count = i2 - i1
50 changes: 28 additions & 22 deletions src/brevitas/quant/base.py
@@ -445,6 +445,8 @@ class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
mse_init_op = AbsMinMax
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
dtype = (this << 1).dtype


class MSEZeroPointSubInjector(MSESubInjectorBase):
@@ -455,6 +457,8 @@ class MSEZeroPointSubInjector(MSESubInjectorBase):
mse_search_method = 'grid'
stats_impl = MSE
stats_reduce_dim = (this << 1).stats_reduce_dim
device = (this << 1).device
dtype = (this << 1).dtype


class MSEAsymmetricScale(ExtendedInjector):
Expand Down Expand Up @@ -507,22 +511,14 @@ class MSEActZeroPoint(MSEZeroPoint):


class HQOZeroPoint(ExtendedInjector):
"""
We leverage a sub-injector to avoid a name clash between scale and zero-point.
"""

zero_point_impl = ParameterFromStatsFromParameterZeroPoint

# per_channel = this.scaling_per_output_channel
# proxy_module = this.proxy_module
mse_init_op = NegativeMinOrZero
stats_impl_lol = HalfQuadraticOptimizerZeroPoint
# stats_reduce_dim = this.stats_reduce_dim
hqo_init_op_zp = NegativeMinOrZero
stats_impl_zp = HalfQuadraticOptimizerZeroPoint
zero_point_stats_input_view_shape_impl = nn.Identity()

@value
def zero_point_stats_impl():
return this.stats_impl_lol
return this.stats_impl_zp

@value
def inner_stats_input_view_shape_impl(scaling_per_output_channel):
@@ -532,25 +528,35 @@ def inner_stats_input_view_shape_impl(scaling_per_output_channel):
return StatsInputViewShapeImpl.OVER_TENSOR


class HQOSymmetricScale(ExtendedInjector):
"""
We leverage a sub-injector to avoid a name clash between scale and zero-point.
"""

class HQOScale(ExtendedInjector):
scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
scaling_stats_input_view_shape_impl = nn.Identity()
# per_channel = this.scaling_per_output_channel
# proxy_module = this.proxy_module
mse_init_op = AbsMax
stats_impl = HalfQuadraticOptimizerScale
# stats_reduce_dim = this.stats_reduce_dim

stats_impl_scale = HalfQuadraticOptimizerScale

@value
def scaling_stats_impl():
return this.stats_impl
return this.stats_impl_scale

@value
def inner_stats_input_view_shape_impl(scaling_per_output_channel):
if scaling_per_output_channel:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
else:
return StatsInputViewShapeImpl.OVER_TENSOR


class HQOAsymmetricScale(HQOScale):
hqo_init_op_scale = AbsMinMax


class HQOSymmetricScale(HQOScale):
hqo_init_op_scale = AbsMax


class HQOActZeroPoint(HQOZeroPoint):
zero_point_impl = ParameterFromRuntimeZeroPoint


class HQOWeightZeroPoint(HQOZeroPoint):
zero_point_impl = ParameterFromStatsFromParameterZeroPoint
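
The *_scale / *_zp suffixes introduced in stats_op.py let the scale and zero-point optimizers receive distinct hyperparameters from the same quantizer namespace, replacing the name-clash workaround mentioned in the removed sub-injector docstrings. An illustrative override, assuming the usual injector attribute resolution and using ShiftedUint8WeightPerChannelFloatHQO from shifted_scaled_int.py below:

from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO


class FastHQOWeightQuant(ShiftedUint8WeightPerChannelFloatHQO):
    # Hypothetical tuning of the zero-point search only; attribute names follow
    # the new HalfQuadraticOptimizerZeroPoint keyword arguments.
    hqo_iters_zp = 200    # default is 1000
    hqo_lp_norm_zp = 1.0  # default is 0.7
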
18 changes: 17 additions & 1 deletion src/brevitas/quant/shifted_scaled_int.py
@@ -2,6 +2,8 @@
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.quant.base import *
from brevitas.quant.base import HQOActZeroPoint
from brevitas.quant.base import HQOAsymmetricScale
from brevitas.quant.base import HQOZeroPoint
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.bias import BiasQuantSolver
@@ -16,7 +18,8 @@
'ShiftedUint8ActPerTensorFixedPointMSE',
'ShiftedUint8ActPerTensorFloatMSE',
'ShiftedUint8WeightPerTensorFloatMSE',
'ShiftedUint8WeightPerChannelFloatMSE']
'ShiftedUint8WeightPerChannelFloatMSE',
'ShiftedUint8ActPerTensorFloatHQO']


class ShiftedUint8ActPerTensorFixedPoint(ShiftedParamFromPercentileUintQuant,
@@ -163,3 +166,16 @@ class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerCh
>>> fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO)
"""
quantize_zero_point = False


class ShiftedUint8ActPerTensorFloatHQO(HQOActZeroPoint, ShiftedUint8ActPerTensorFloat):
"""
8-bit per-tensor unsigned int activations quantizer with floating-point scale factor and
integer zero point. Both zero-point and scale factors are learned parameters initialized from
HQO local loss.
Examples:
>>> from brevitas.nn import QuantReLU
>>> act = QuantReLU(act_quant=ShiftedUint8ActPerTensorFloatHQO)
"""
quantize_zero_point = False
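
Combining the two new HQO quantizers, following the docstring examples above (layer sizes and input shape are illustrative):

import torch

from brevitas.nn import QuantLinear
from brevitas.nn import QuantReLU
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloatHQO
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatHQO

fc = QuantLinear(10, 5, bias=False, weight_quant=ShiftedUint8WeightPerChannelFloatHQO)
act = QuantReLU(act_quant=ShiftedUint8ActPerTensorFloatHQO)
y = act(fc(torch.randn(2, 10)))  # zero points come from the HQO search on first use
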
47 changes: 35 additions & 12 deletions src/brevitas_examples/common/generative/quant_blocks.py
@@ -10,9 +10,12 @@
import torch.nn as nn

import brevitas
import brevitas.config as config
from brevitas.core.function_wrapper.shape import PermuteDims
from brevitas.core.utils import inplace_tensor_add
from brevitas.core.utils import SliceTensor
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad


@@ -58,12 +61,22 @@ class ExpandReshapeZeroPointWrapper(brevitas.jit.ScriptModule):
__constants__ = ['expanded_zero_point_shape', 'reshaped_zero_point_shape']

def __init__(
self, wrapped_zero_point_impl, expanded_zero_point_shape, reshaped_zero_point_shape):
self,
wrapped_zero_point_impl,
expanded_zero_point_shape,
reshaped_zero_point_shape,
dtype,
device):
super(ExpandReshapeZeroPointWrapper, self).__init__()
assert isinstance(wrapped_zero_point_impl, ParameterFromStatsFromParameterZeroPoint)
self.wrapped_zero_point_impl = wrapped_zero_point_impl
self.expanded_zero_point_shape = expanded_zero_point_shape
self.reshaped_zero_point_shape = reshaped_zero_point_shape
self.slice_tensor = SliceTensor()
self.init_done = False
self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool)
self.wrapped_zero_point_impl.value = torch.nn.Parameter(
torch.full(reshaped_zero_point_shape, 0.0, dtype=dtype, device=device))

def unexpanded_zero_point(self, unexpanded_scale, bit_width):
"""
@@ -76,17 +89,27 @@ def unexpanded_zero_point(self, unexpanded_scale, bit_width):

@brevitas.jit.script_method
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor):
# We have to break into wrapped_zero_point_impl since we need to expand and reshape
# Before we call into scale_shift_zero_point
zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats()
zero_point_stats = zero_point_stats.expand(self.expanded_zero_point_shape).contiguous()
# contiguous() above is to avoid an unsafe_view below
zero_point_stats = zero_point_stats.reshape(self.reshaped_zero_point_shape)
# slice tensor when required by partial quantization
zero_point_stats = self.slice_tensor(zero_point_stats)
zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point(
-zero_point_stats, scale, bit_width)
return zero_point
if self.init_done:
value = self.slice_tensor(-self.wrapped_zero_point_impl.value)
return self.wrapped_zero_point_impl.scale_shift_zero_point(value, scale, bit_width)
else:
# We have to break into wrapped_zero_point_impl since we need to expand and reshape
# Before we call into scale_shift_zero_point
zero_point_stats = self.wrapped_zero_point_impl.parameter_list_stats()
zero_point_stats = zero_point_stats.expand(self.expanded_zero_point_shape).contiguous()
# contiguous() above is to avoid an unsafe_view below
zero_point_stats = zero_point_stats.reshape(self.reshaped_zero_point_shape)
# slice tensor when required by partial quantization
zero_point_stats = self.slice_tensor(zero_point_stats)
if self.local_loss_mode:
return self.wrapped_zero_point_impl.scale_shift_zero_point(
-zero_point_stats, scale, bit_width)
inplace_tensor_add(self.wrapped_zero_point_impl.value.detach(), zero_point_stats)
# self.wrapped_zero_point_impl.value.data = zero_point_stats
zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point(
-zero_point_stats, scale, bit_width)
self.init_done = True
return zero_point


# TODO: restore JIT compatibility
