Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (graph/equalize): improvements for llm equalization #784

Merged
merged 12 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/brevitas/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ class FnToModule(CallableToModule):
def match_node(self, node: Node) -> bool:
return node.op == 'call_function' and node.target is self.old_callable

def move_node_args_to_kwargs(self, node: Node):
super().move_node_args_to_kwargs(node)
# Moving to stateful modules, we remove the 'training' argument if it is passed to the
# functional version of the layer since it is not needed anymore
kwargs = dict(node.kwargs)
if 'training' in kwargs:
del kwargs['training']
node.kwargs = immutable_dict(kwargs)


class MethodToModule(CallableToModule):

Expand Down
54 changes: 31 additions & 23 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@

_select_op = (operator.getitem, operator.__getitem__)

_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', torch.reshape, torch.flatten)

_scale_varying_activations = (
torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU)

Expand All @@ -73,6 +75,8 @@

_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

_ignore_ops = (getattr, 'size')


# Required for being hashable
@dataclass(eq=True, frozen=True)
Expand Down Expand Up @@ -279,7 +283,8 @@ def _combine_weights_bias(
weight = weight.data.reshape(weight.shape[0], -1)
bias = bias.reshape(-1, 1)

weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
weight = torch.where(
torch.abs(weight) <= EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
factor = torch.abs(bias) / torch.abs(weight)

# From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450
Expand Down Expand Up @@ -398,7 +403,6 @@ def _no_equalize():
scale_fn = _select_scale_computation_fn(scale_computation_type)
sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
sinks_range = torch.clamp(sinks_range, EPSILON)

# Determine the srcs_range based on where we are performing activation equalization or
# weight equalization
Expand Down Expand Up @@ -431,10 +435,18 @@ def _no_equalize():
"Detected source and sink with non compatible shapes, equalization is skipped")
return _no_equalize()

# Instead of clipping very low values, which would cause their reciprocal to be very large
# thus hindering quantization, we set both sources and sinks to one,
# which is the no-op equivalent for equalization.
channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
sinks_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range)
srcs_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range)

srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
scaling_factors = srcs_range / sinks_range
scaling_factors = torch.clamp(scaling_factors, EPSILON)
inverse_scaling_factors = torch.reciprocal(scaling_factors)

if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
Expand All @@ -455,8 +467,8 @@ def _no_equalize():
torch.reshape(inverse_scaling_factors, src_broadcast_size)),
attr='weight')
for module, axis in sink_axes.items():
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
if isinstance(module, _batch_norm):
# We re-compute the bias as function of running_mean and running_var to adjust the
# additive factor for equalization.
Expand All @@ -466,7 +478,7 @@ def _no_equalize():
module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias')
_update_weights(
module,
module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size),
module.weight.clone() * torch.reshape(scaling_factors, sink_broadcast_size),
attr='weight')

return scaling_factors
Expand Down Expand Up @@ -547,9 +559,7 @@ def _is_scale_invariant_function(node: Node) -> bool:


def _is_reshaping_op(node: Node) -> bool:
return (
node.op == 'call_function' and node.target in [torch.flatten, torch.reshape] or
node.op == 'call_method' and node.target in ['view', 'reshape', 'flatten'])
return node.target in _reshaping_op


def find_srcs(graph_model: GraphModule, starting_node: Node,
Expand All @@ -575,6 +585,8 @@ def find_srcs(graph_model: GraphModule, starting_node: Node,
node.op == 'call_function' and node.target in _residual_fns):
find_srcs(graph_model, node, state)
find_sinks(graph_model, node, state)
elif node.target in _ignore_ops:
continue
else:
# If we meet an unrecognized op, we add None to invalidate the region
state.srcs.add(_UNSUPPORTED_OP)
Expand Down Expand Up @@ -606,6 +618,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
node.op == 'call_function' and node.target in _residual_fns):
find_sinks(graph_model, node, state)
find_srcs(graph_model, node, state)
elif node.target in _ignore_ops:
continue
else:
# If we meet an unrecognized op, we add None to invalidate the region
state.sinks.add(_UNSUPPORTED_OP)
Expand Down Expand Up @@ -699,8 +713,6 @@ def find_module(self, model, regions: List):
"""
Iterate through the model looking at immediate children of every module to look for supported modules.
This allows us to stop the search when we meet a top-level module that is supported.
Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its
Linear submodules.
"""
if isinstance(model,
_supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
Expand All @@ -713,7 +725,7 @@ def setup(self):
for region in self.regions:
batch_dim = 0
if hasattr(region, 'batch_first'):
batch_dim = 0 if region.batch_first == True else 1
batch_dim = 0 if region.batch_first else 1

hook_fn = partial(
self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True)
Expand Down Expand Up @@ -761,16 +773,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')
x = x.transpose(0, batch_dim)

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
self.float_act_map[name] = input_scales
else:
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x],
dim=batch_dim)
self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
broadcastable_shape = [1] * len(shape)
Expand Down Expand Up @@ -832,7 +842,7 @@ def setup(self):
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first == True else 1
batch_dim = 0 if module.batch_first else 1
for name in region_to_search:
act_module = name_to_module[name]
use_inp = True if region_to_search == region.sinks else False
Expand Down Expand Up @@ -907,16 +917,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')
x = x.transpose(0, batch_dim)

self.batch_dim_act_map[name] = batch_dim
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = self.scale_fn(x, dim=batch_dim)
self.float_act_map[name] = input_scales
else:
batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x],
dim=batch_dim)
self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim)
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0):
broadcastable_shape = [1] * len(shape)
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas/graph/standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class TorchFunctionalToModule(GraphTransform):
nn.AvgPool1d), (F.avg_pool2d, nn.AvgPool2d),
(F.avg_pool3d, nn.AvgPool3d), (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d),
(F.adaptive_avg_pool2d,
nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, nn.AdaptiveAvgPool3d))
nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d,
nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout))

def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP):
super().__init__()
Expand Down
11 changes: 9 additions & 2 deletions src/brevitas_examples/llm/llm_quant/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brevitas.fx.brevitas_tracer import value_trace
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32

Expand All @@ -26,6 +27,12 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):
return outs


def trace_and_standardize(model, ref_kwargs):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = TorchFunctionalToModule().apply(graph_model)
return graph_model


@torch.no_grad()
def apply_act_equalization(
model,
Expand All @@ -49,7 +56,7 @@ def apply_act_equalization(
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
# TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
# or an FX interpreter to run it on GPU
warnings.warn(
Expand All @@ -70,5 +77,5 @@ def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type='
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model, dtype):
graph_model = value_trace(model, value_args=ref_kwargs)
graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs)
EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
7 changes: 6 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,14 @@ def main():
quantize_embedding=args.quantize_embedding,
seqlen=args.seqlen)
# Tie back first/last layer weights in case they got untied
model.tie_weights()
print("Model quantization applied.")

# If any equalization has taken places, the embedding layer and the fully connected one are
# not tied anymore, and they need to be treated as standalone, separate layers.
# In all other cases we can tie them back so to preserve memory.
if args.act_equalization is None and not args.weight_equalization:
model.tie_weights()

if args.act_calibration:
print("Apply act calibration...")
apply_calibration(model, calibration_loader, args.nsamples)
Expand Down
Loading