Skip to content

Commit

Permalink
Feat (graph/equalize): improvements for llm equalization (#784)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Dec 15, 2023
1 parent 52daf86 commit 7c490d6
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 27 deletions.
9 changes: 9 additions & 0 deletions src/brevitas/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ class FnToModule(CallableToModule):
def match_node(self, node: Node) -> bool:
return node.op == 'call_function' and node.target is self.old_callable

def move_node_args_to_kwargs(self, node: Node):
super().move_node_args_to_kwargs(node)
# Moving to stateful modules, we remove the 'training' argument if it is passed to the
# functional version of the layer since it is not needed anymore
kwargs = dict(node.kwargs)
if 'training' in kwargs:
del kwargs['training']
node.kwargs = immutable_dict(kwargs)


class MethodToModule(CallableToModule):

Expand Down
54 changes: 31 additions & 23 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@

_select_op = (operator.getitem, operator.__getitem__)

_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', torch.reshape, torch.flatten)

_scale_varying_activations = (
torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU)

Expand All @@ -73,6 +75,8 @@

_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

_ignore_ops = (getattr, 'size')


# Required for being hashable
@dataclass(eq=True, frozen=True)
Expand Down Expand Up @@ -279,7 +283,8 @@ def _combine_weights_bias(
weight = weight.data.reshape(weight.shape[0], -1)
bias = bias.reshape(-1, 1)

weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
weight = torch.where(
torch.abs(weight) <= EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
factor = torch.abs(bias) / torch.abs(weight)

# From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450
Expand Down Expand Up @@ -398,7 +403,6 @@ def _no_equalize():
scale_fn = _select_scale_computation_fn(scale_computation_type)
sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
sinks_range = torch.clamp(sinks_range, EPSILON)

# Determine the srcs_range based on where we are performing activation equalization or
# weight equalization
Expand Down Expand Up @@ -431,10 +435,18 @@ def _no_equalize():
"Detected source and sink with non compatible shapes, equalization is skipped")
return _no_equalize()

# Instead of clipping very low values, which would cause their reciprocal to be very large
# thus hindering quantization, we set both sources and sinks to one,
# which is the no-op equivalent for equalization.
channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
sinks_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range)
srcs_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range)

srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
scaling_factors = srcs_range / sinks_range
scaling_factors = torch.clamp(scaling_factors, EPSILON)
inverse_scaling_factors = torch.reciprocal(scaling_factors)

if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
Expand All @@ -455,8 +467,8 @@ def _no_equalize():
torch.reshape(inverse_scaling_factors, src_broadcast_size)),
attr='weight')
for module, axis in sink_axes.items():
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
if isinstance(module, _batch_norm):
# We re-compute the bias as function of running_mean and running_var to adjust the
# additive factor for equalization.
Expand All @@ -466,7 +478,7 @@ def _no_equalize():
module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias')
_update_weights(
module,
module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size),
module.weight.clone() * torch.reshape(scaling_factors, sink_broadcast_size),
attr='weight')

return scaling_factors
Expand Down Expand Up @@ -547,9 +559,7 @@ def _is_scale_invariant_function(node: Node) -> bool:


def _is_reshaping_op(node: Node) -> bool:
return (
node.op == 'call_function' and node.target in [torch.flatten, torch.reshape] or
node.op == 'call_method' and node.target in ['view', 'reshape', 'flatten'])
return node.target in _reshaping_op


def find_srcs(graph_model: GraphModule, starting_node: Node,
Expand All @@ -575,6 +585,8 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
node.op == 'call_function' and node.target in _residual_fns):
find_srcs(graph_model, node, state)
find_sinks(graph_model, node, state)
elif node.target in _ignore_ops:
continue
else:
# If we meet an unrecognized op, we add None to invalidate the region
state.srcs.add(_UNSUPPORTED_OP)
Expand Down Expand Up @@ -606,6 +618,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
node.op == 'call_function' and node.target in _residual_fns):
find_sinks(graph_model, node, state)
find_srcs(graph_model, node, state)
elif node.target in _ignore_ops:
continue
else:
# If we meet an unrecognized op, we add None to invalidate the region
state.sinks.add(_UNSUPPORTED_OP)
Expand Down Expand Up @@ -699,8 +713,6 @@ def find_module(self, model, regions: List):
"""
Iterate through the model looking at immediate children of every module to look for supported modules.
This allows us to stop the search when we meet a top-level module that is supported.
Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its
Linear submodules.
"""
if isinstance(model,
_supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
Expand All @@ -713,7 +725,7 @@ def setup(self):
for region in self.regions:
batch_dim = 0
if hasattr(region, 'batch_first'):
batch_dim = 0 if region.batch_first == True else 1
batch_dim = 0 if region.batch_first else 1

hook_fn = partial(
self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True)
Expand Down Expand Up @@ -761,16 +773,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')
x = x.transpose(0, batch_dim)

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
self.float_act_map[name] = input_scales
else:
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x],
dim=batch_dim)
self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
broadcastable_shape = [1] * len(shape)
Expand Down Expand Up @@ -832,7 +842,7 @@ def setup(self):
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first == True else 1
batch_dim = 0 if module.batch_first else 1
for name in region_to_search:
act_module = name_to_module[name]
use_inp = True if region_to_search == region.sinks else False
Expand Down Expand Up @@ -907,16 +917,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')
x = x.transpose(0, batch_dim)

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
self.float_act_map[name] = input_scales
else:
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x],
dim=batch_dim)
self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0):
broadcastable_shape = [1] * len(shape)
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas/graph/standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class TorchFunctionalToModule(GraphTransform):
nn.AvgPool1d), (F.avg_pool2d, nn.AvgPool2d),
(F.avg_pool3d, nn.AvgPool3d), (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d),
(F.adaptive_avg_pool2d,
nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, nn.AdaptiveAvgPool3d))
nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d,
nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout))

def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP):
super().__init__()
Expand Down
11 changes: 9 additions & 2 deletions src/brevitas_examples/llm/llm_quant/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brevitas.fx.brevitas_tracer import value_trace
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32

Expand All @@ -26,6 +27,12 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
return outs


def trace_and_standardize(model, ref_kwargs):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = TorchFunctionalToModule().apply(graph_model)
return graph_model


@torch.no_grad()
def apply_act_equalization(
model,
Expand All @@ -49,7 +56,7 @@ def apply_act_equalization(
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
# TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
# or an FX interpreter to run it on GPU
warnings.warn(
Expand All @@ -70,5 +77,5 @@ def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
7 changes: 6 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,14 @@ def main():
quantize_embedding=args.quantize_embedding,
seqlen=args.seqlen)
# Tie back first/last layer weights in case they got untied
model.tie_weights()
print("Model quantization applied.")

# If any equalization has taken places, the embedding layer and the fully connected one are
# not tied anymore, and they need to be treated as standalone, separate layers.
# In all other cases we can tie them back so to preserve memory.
if args.act_equalization is None and not args.weight_equalization:
model.tie_weights()

if args.act_calibration:
print("Apply act calibration...")
apply_calibration(model, calibration_loader, args.nsamples)
Expand Down

0 comments on commit 7c490d6

Please sign in to comment.