Add per tensor/row/group dynamic scale support, some dtype improvements
volcacius committed Jul 4, 2023
1 parent c795363 commit 5e2d00a
Showing 8 changed files with 251 additions and 117 deletions.
7 changes: 4 additions & 3 deletions src/brevitas_examples/llm/llm_quant/equalize.py
@@ -29,6 +29,7 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
@torch.no_grad()
def apply_act_equalization(
model,
dtype,
act_equalization_type,
dataloader,
nsamples,
@@ -47,7 +48,7 @@ def apply_act_equalization(
assert ref_kwargs is not None, "Ref kwargs required to perform tracing and lift the model into FX."
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model):
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
# TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
# or an FX interpreter to run it on GPU
@@ -65,9 +66,9 @@


@torch.no_grad()
def apply_weight_equalization(model, ref_kwargs, scale_computation_type='range'):
def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='range'):
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model):
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
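
Note (editor's illustration, not part of the commit): both equalization entry points now take the model's original dtype so that cast_to_float32 can restore it after fp32 tracing. A minimal sketch of that pattern, not the actual Brevitas helper, could look like this:

    from contextlib import contextmanager

    import torch
    from torch import nn


    @contextmanager
    def cast_to_float32(model: nn.Module, dtype: torch.dtype):
        # Illustrative sketch only: many fp16 kernels are missing on CPU,
        # so the model is traced in fp32 and restored afterwards.
        model.to(dtype=torch.float32)
        try:
            yield model
        finally:
            # Restore the dtype the model was originally loaded with.
            model.to(dtype=dtype)
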
4 changes: 2 additions & 2 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -84,9 +84,9 @@ def merge_layernorm_affine_params(graph_model):


@torch.no_grad()
def apply_layernorm_affine_merge(model, ref_kwargs):
def apply_layernorm_affine_merge(model, dtype, ref_kwargs):
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply merging, and then cast back
with cast_to_float32(model):
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
merge_layernorm_affine_params(graph_model)
101 changes: 48 additions & 53 deletions src/brevitas_examples/llm/llm_quant/quant_blocks.py
@@ -12,7 +12,6 @@
import brevitas
from brevitas.core.function_wrapper.shape import PermuteDims
from brevitas.core.utils import SliceTensor
from brevitas.core.utils import StatelessBuffer


class OverSubChannelBlockView(brevitas.jit.ScriptModule):
@@ -33,58 +32,6 @@ def forward(self, x: torch.Tensor):
return y


class AbsMaxKeepDim(brevitas.jit.ScriptModule):
__constants__ = ['stats_reduce_dim']

def __init__(self, stats_reduce_dim) -> None:
super(AbsMaxKeepDim, self).__init__()
self.stats_reduce_dim = stats_reduce_dim

@brevitas.jit.script_method
def forward(self, x: Tensor):
if self.stats_reduce_dim is not None:
y = torch.max(torch.abs(x), dim=self.stats_reduce_dim, keepdim=True)[0]
else:
y = torch.max(torch.abs(x))
return y


class AbsMinMaxKeepDim(brevitas.jit.ScriptModule):
__constants__ = ['stats_reduce_dim']

def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
super(AbsMinMaxKeepDim, self).__init__()
self.stats_reduce_dim = stats_reduce_dim

@brevitas.jit.script_method
def forward(self, x: Tensor):
if self.stats_reduce_dim is None:
return torch.abs(torch.max(x) - torch.min(x))
else:
max_val = torch.max(x, dim=self.stats_reduce_dim, keepdim=True)[0]
min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=True)[0]
return torch.abs(max_val - min_val)


class NegativeMinOrZeroKeepDim(brevitas.jit.ScriptModule):
__constants__ = ['stats_reduce_dim']

def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
super(NegativeMinOrZeroKeepDim, self).__init__()
self.stats_reduce_dim = stats_reduce_dim
self.zero = StatelessBuffer(torch.tensor(0.0))

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tensor:
if self.stats_reduce_dim is None:
min_val = torch.min(x, keepdim=True)
else:
min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=True)[0]
min_val = torch.where(
min_val <= self.zero().to(min_val.dtype), min_val, self.zero().to(min_val.dtype))
return min_val


class ExpandReshapeScalingWrapper(brevitas.jit.ScriptModule):
__constants__ = ['expanded_scaling_shape', 'reshaped_scaling_shape']

@@ -138,3 +85,51 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor):
zero_point = self.wrapped_zero_point_impl.scale_shift_zero_point(
-zero_point_stats, scale, bit_width)
return zero_point


class RuntimeDynamicStatsScaling(brevitas.jit.ScriptModule):
__constants__ = ['dynamic_scaling_broadcastable_shape']

def __init__(
self,
scaling_stats_impl: nn.Module,
dynamic_scaling_broadcastable_shape: Tuple[int, ...],
scaling_stats_input_view_shape_impl: nn.Module) -> None:
super(RuntimeDynamicStatsScaling, self).__init__()
self.scaling_stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
self.stats_impl = scaling_stats_impl
self.dynamic_scaling_broadcastable_shape = dynamic_scaling_broadcastable_shape

@brevitas.jit.script_method
def forward(self, x) -> Tensor:
x = self.scaling_stats_input_view_shape_impl(x)
x = self.stats_impl(x)
x = x.view(self.dynamic_scaling_broadcastable_shape)
return x
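
Note (editor's illustration, not part of the commit): RuntimeDynamicStatsScaling computes activation statistics at runtime and reshapes them into a broadcastable scale. A minimal per-row usage sketch, assuming the class above is importable and Brevitas JIT scripting is disabled; AbsMaxReduce is a stand-in for the injected stats implementation:

    import torch
    from torch import nn


    class AbsMaxReduce(nn.Module):
        # Stand-in stats impl: abs-max over the last dimension, keeping dims.
        def forward(self, x):
            return torch.max(torch.abs(x), dim=-1, keepdim=True)[0]


    seqlen, batch, embed = 4, 2, 8
    scaling = RuntimeDynamicStatsScaling(
        scaling_stats_impl=AbsMaxReduce(),
        dynamic_scaling_broadcastable_shape=(seqlen, batch, 1),
        scaling_stats_input_view_shape_impl=nn.Identity())

    x = torch.randn(seqlen, batch, embed)
    scale_stats = scaling(x)  # one abs-max per (seq, batch) row, shape (4, 2, 1)
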


class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

def __init__(self, group_size: int, group_dim: int, scaling_stats_impl: nn.Module) -> None:
super(RuntimeDynamicGroupStatsScaling, self).__init__()
self.group_size = group_size
self.group_dim = group_dim
self.scaling_stats_impl = scaling_stats_impl

@brevitas.jit.script_method
def group_scaling_reshape(self, stats_input):
tensor_shape = stats_input.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
tensor_shape_list.insert(self.group_dim + 1, self.group_size)
stats_input = stats_input.view(tensor_shape_list)
return stats_input

@brevitas.jit.script_method
def forward(self, stats_input) -> Tensor:
stats_input_reshaped = self.group_scaling_reshape(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped)
out = torch.clamp_min(out, min=torch.tensor(1e-6, device=out.device, dtype=out.dtype))
out = out.expand(stats_input_reshaped.shape)
out = out.reshape(stats_input.shape)
return out
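
Note (editor's illustration, not part of the commit): group_scaling_reshape splits the grouped dimension into (num_groups, group_size) blocks so the statistic is computed per group and then broadcast back over each block. A plain-PyTorch sketch of the same computation for group_dim=2, assuming the grouped dimension divides evenly by group_size:

    import torch

    x = torch.randn(2, 4, 16)  # (batch, seq, embed)
    group_size = 8
    b, s, e = x.shape

    # (batch, seq, embed) -> (batch, seq, num_groups, group_size)
    grouped = x.view(b, s, e // group_size, group_size)

    # Per-group abs-max, clamped away from zero, then broadcast back to x's shape.
    scale = torch.abs(grouped).amax(dim=-1, keepdim=True)
    scale = torch.clamp_min(scale, 1e-6)
    scale = scale.expand(grouped.shape).reshape(x.shape)
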
148 changes: 113 additions & 35 deletions src/brevitas_examples/llm/llm_quant/quantize.py
@@ -26,6 +26,9 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloatMSE
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerGroupFloat
from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerRowFloat
from brevitas_examples.llm.llm_quant.quantizers import Int8ActDynamicPerTensorFloat
from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloat
from brevitas_examples.llm.llm_quant.quantizers import Int8ActPerRowFloatMSE
from brevitas_examples.llm.llm_quant.quantizers import IntWeightSymmetricGroupQuant
@@ -62,62 +65,78 @@
'sym': Int8WeightPerChannelFixedPointMSE},},}}

INPUT_QUANT_MAP = {
'float': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
'per_row': {
'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
'per_row': {
'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},},
'po2': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPoint},},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPointMSE},},}}
'static': {
'float': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFloat, 'asym': ShiftedUint8ActPerTensorFloat},
'per_row': {
'sym': Int8ActPerRowFloat, 'asym': ShiftedUint8ActPerRowFloat},},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFloatMSE, 'asym': ShiftedUint8ActPerTensorFloatMSE},
'per_row': {
'sym': Int8ActPerRowFloatMSE, 'asym': ShiftedUint8ActPerRowFloatMSE},},},
'po2': {
'stats': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPoint},},
'mse': {
'per_tensor': {
'sym': Int8ActPerTensorFixedPointMSE},},}},
'dynamic': {
'float': {
'stats': {
'per_tensor': {
'sym': Int8ActDynamicPerTensorFloat},
'per_row': {
'sym': Int8ActDynamicPerRowFloat},
'per_group': {
'sym': Int8ActDynamicPerGroupFloat},}}}}
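
Note (editor's illustration, not part of the commit): the updated map adds a top-level 'static'/'dynamic' key ahead of the scale precision, so activation quantizer classes are now resolved as, for example:

    # Static, float scale, stats-based, per-row, asymmetric activation quantizer:
    INPUT_QUANT_MAP['static']['float']['stats']['per_row']['asym']    # ShiftedUint8ActPerRowFloat

    # Dynamic scales are currently float, stats-based and symmetric only:
    INPUT_QUANT_MAP['dynamic']['float']['stats']['per_group']['sym']  # Int8ActDynamicPerGroupFloat
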


def quantize_model(
model,
dtype,
weight_bit_width,
weight_param_method,
weight_scale_type,
weight_scale_precision,
weight_quant_type,
weight_quant_granularity,
weight_group_size,
quantize_weight_zero_point,
input_bit_width=None,
input_scale_precision=None,
input_scale_type=None,
input_param_method=None,
input_quant_type=None,
input_quant_granularity=None,
input_group_size=None,
quantize_input_zero_point=False,
seqlen=None):
"""
Replace float layers with quant layers in the target model
"""
# Retrieve base input and weight quantizers
weight_quant = WEIGHT_QUANT_MAP[weight_scale_type][weight_param_method][
weight_quant = WEIGHT_QUANT_MAP[weight_scale_precision][weight_param_method][
weight_quant_granularity][weight_quant_type]
if input_bit_width is not None:
input_quant = INPUT_QUANT_MAP[input_scale_type][input_param_method][
input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][input_param_method][
input_quant_granularity][input_quant_type]
# Some activations in MHA should always be symmetric
sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_param_method][
input_quant_granularity]['sym']
# Linear layers with 2d input should always be per tensor
per_tensor_input_quant = INPUT_QUANT_MAP[input_scale_type][input_param_method][
'per_tensor'][input_quant_type]
sym_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
input_param_method][input_quant_granularity]['sym']
# Linear layers with 2d input should always be per tensor or per group, as there is no row dimension
if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row':
linear_2d_input_quant = INPUT_QUANT_MAP[input_scale_type][input_scale_precision][
input_param_method]['per_tensor'][input_quant_type]
else:
assert input_quant_granularity == 'per_group'
linear_2d_input_quant = input_quant
else:
input_quant = None
sym_input_quant = None
per_tensor_input_quant = None
linear_2d_input_quant = None

# Modify the weight quantizer based on the arguments passed in
weight_quant = weight_quant.let(
@@ -129,7 +148,7 @@ def quantize_model(
# weight scale is converted to a standalone parameter
# This is done already by default in the per_group quantizer
if weight_quant_granularity != 'per_group':
weight_quant = weight_quant.let(weight_scale_impl_type='parameter_from_stats')
weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')
# weight zero-point is converted to a standalone parameter
# This is done already by default in the per_group quantizer
if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group':
@@ -142,20 +161,34 @@
'bit_width': input_bit_width,
'quantize_zero_point': quantize_input_zero_point,
'dtype': dtype})
if input_quant_granularity == 'per_row':
if input_scale_type == 'static' and input_quant_granularity == 'per_row':
# QuantMHA internally always uses Seq, B, E
input_quant = input_quant.let(
**{
'channel_dim': 0,
'per_channel_broadcastable_shape': (seqlen, 1, 1),
'scaling_stats_permute_dims': (0, 1, 2)})
elif input_scale_type == 'dynamic':
if input_quant_granularity == 'per_tensor':
input_quant = input_quant.let(
**{
'dynamic_scaling_broadcastable_shape': (1, -1, 1),
'permute_dims': (1, 0, 2),
'stats_reduce_dim': 1})
elif input_quant_granularity == 'per_row':
input_quant = input_quant.let(
**{
'dynamic_scaling_broadcastable_shape': (seqlen, -1, 1),
'permute_dims': (1, 0, 2),
'stats_reduce_dim': 2})
elif input_quant_granularity == 'per_group':
input_quant = input_quant.let(**{'group_dim': 2, 'group_size': input_group_size})
if sym_input_quant is not None:
sym_input_quant = sym_input_quant.let(
**{
'bit_width': input_bit_width,
'quantize_zero_point': quantize_input_zero_point,
'dtype': dtype})
if input_quant_granularity == 'per_row':
if input_scale_type == 'static' and input_quant_granularity == 'per_row':
q_scaled_quant = sym_input_quant.let(
**{
'per_channel_broadcastable_shape': (1, seqlen, 1),
@@ -166,19 +199,64 @@
'scaling_stats_permute_dims': (2, 0, 1)})
v_quant = q_scaled_quant
attn_output_weights_quant = q_scaled_quant
elif input_scale_type == 'dynamic':
if input_quant_granularity == 'per_tensor':
q_scaled_quant = sym_input_quant.let(
**{
'dynamic_scaling_broadcastable_shape': (-1, 1, 1),
'permute_dims': None,
'stats_reduce_dim': 1})
k_transposed_quant = sym_input_quant.let(
**{
'dynamic_scaling_broadcastable_shape': (-1, 1, 1),
'permute_dims': None,
'stats_reduce_dim': 1})
elif input_quant_granularity == 'per_row':
q_scaled_quant = sym_input_quant.let(
**{
'dynamic_scaling_broadcastable_shape': (-1, seqlen, 1),
'permute_dims': None,
'stats_reduce_dim': 2})
k_transposed_quant = sym_input_quant.let(
**{
'dynamic_scaling_broadcastable_shape': (-1, 1, seqlen),
'permute_dims': None,
'stats_reduce_dim': 1})
elif input_quant_granularity == 'per_group':
q_scaled_quant = sym_input_quant.let(
**{
'group_dim': 2, 'group_size': input_group_size})
k_transposed_quant = sym_input_quant.let(
**{
'group_dim': 1, 'group_size': input_group_size})
v_quant = q_scaled_quant
attn_output_weights_quant = q_scaled_quant
else:
q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = sym_input_quant
else:
q_scaled_quant = v_quant = k_transposed_quant = attn_output_weights_quant = None
if per_tensor_input_quant is not None:
per_tensor_input_quant = per_tensor_input_quant.let(
if linear_2d_input_quant is not None:
linear_2d_input_quant = linear_2d_input_quant.let(
**{
'bit_width': input_bit_width,
'quantize_zero_point': quantize_input_zero_point,
'dtype': dtype})
if input_scale_type == 'dynamic':
# Note: this breaks if applied to 3d Linear inputs,
# in case standard MHA layers haven't been inserted
if input_quant_granularity == 'per_tensor' or input_quant_granularity == 'per_row':
linear_2d_input_quant = linear_2d_input_quant.let(
**{
'dynamic_scaling_broadcastable_shape': (-1, 1),
'permute_dims': None,
'stats_reduce_dim': 1})
elif input_quant_granularity == 'per_group':
linear_2d_input_quant = linear_2d_input_quant.let(
**{
'group_dim': 1, 'group_size': input_group_size})

quant_linear_kwargs = {
'input_quant': per_tensor_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
'input_quant': linear_2d_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}

quant_mha_kwargs = {
'in_proj_input_quant': input_quant,
Expand All @@ -190,7 +268,7 @@ def quantize_model(
'q_scaled_quant': q_scaled_quant,
'k_transposed_quant': k_transposed_quant,
'v_quant': v_quant,
'out_proj_input_quant': per_tensor_input_quant,
'out_proj_input_quant': linear_2d_input_quant,
'out_proj_weight_quant': weight_quant,
'out_proj_bias_quant': None,
'out_proj_output_quant': None,
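
Note (editor's illustration, not part of the commit): putting the new options together, a call to quantize_model enabling dynamic per-group input scales might look like the sketch below. The keyword names follow the updated signature above; the model placeholder and the concrete values are illustrative rather than a recommended configuration:

    import torch

    quantize_model(
        model,  # a float LLM, e.g. loaded via transformers (placeholder)
        dtype=torch.float16,
        weight_bit_width=4,
        weight_param_method='stats',
        weight_scale_precision='float',
        weight_quant_type='sym',
        weight_quant_granularity='per_group',
        weight_group_size=128,
        quantize_weight_zero_point=False,
        input_bit_width=8,
        input_scale_precision='float',
        input_scale_type='dynamic',
        input_param_method='stats',
        input_quant_type='sym',
        input_quant_granularity='per_group',
        input_group_size=64,
        quantize_input_zero_point=False,
        seqlen=2048)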