Skip to content

Commit

Permalink
Fix (llm): small fixes to LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 27, 2024
1 parent b28ac0f commit b08e0ac
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
# This one is a bit tricky but we could end up here:
# - If we quantize the zero point, which will already have expanded shape matching the scale (although no padding, but we don't need the padding)
# - Groupwise HQO quantization, where weight will already have been padded and expanded
if len(x.shape) == len(self.expanded_groupwise_shape):
return x
y = torch.nn.functional.pad(
x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
y = y.view(self.expanded_groupwise_shape)
Expand Down
6 changes: 5 additions & 1 deletion src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,14 +692,18 @@ def parameter_search(self, xl, x):
self.set_local_loss_mode(False)
qt_value = self.input_view_shape_impl(quant_tensor.value)
qt_scale = self.input_view_shape_impl(quant_tensor.scale)
qt_int = self.input_view_shape_impl(quant_tensor.int())
qt_zp = self.input_view_shape_impl(quant_tensor.zero_point)
qt_int = qt_value / qt_scale + qt_zp
loss = torch.abs(qt_value - x).mean()
best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)

# Compared to the original formulation, the value we're looking for is:
# - scaled by qt_scale
# - opposite sign
val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)

if self.stats_reduce_dim is None:
Expand Down
9 changes: 6 additions & 3 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,16 @@
'sym': Int8WeightPerChannelFixedPoint},
'per_group': {
'sym': MXInt8Weight, 'asym': ShiftedMXUInt8Weight}},
'hqo': {
'per_group': {
'asym': MXHQO}},
'mse': {
'per_tensor': {
'sym': Int8WeightPerTensorFixedPointMSE},
'per_channel': {
'sym': Int8WeightPerChannelFixedPointMSE}},
'per_group': {
'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}},
'sym': Int8WeightPerChannelFixedPointMSE},
'per_group': {
'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}}},
'float': {
'float_scale': {
'stats': {
Expand Down

0 comments on commit b08e0ac

Please sign in to comment.