From 8e39b02eb5d8225370ec9c4e3de285347923b044 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 13 Dec 2023 13:31:15 +0000 Subject: [PATCH 01/12] Fix (graph/equalize): increase epsilon for float16 --- src/brevitas/graph/equalize.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index a9c492b97..e096171fb 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -28,6 +28,7 @@ __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] EPSILON = 1e-9 +FLOAT16_EPSILON = 1e-4 _supported_layers = ( nn.ConvTranspose1d, @@ -334,6 +335,7 @@ def _cross_layer_equalization( # Determine device and type of tensors device = next(sinks[0].parameters()).device dtype = next(sinks[0].parameters()).dtype + epsilon = FLOAT16_EPSILON if dtype == torch.float16 else EPSILON # If equalization criteria are not met, we return a scalar one to indicate that no equalization # has been performed @@ -398,7 +400,7 @@ def _no_equalize(): scale_fn = _select_scale_computation_fn(scale_computation_type) sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()] sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1)) - sinks_range = torch.clamp(sinks_range, EPSILON) + sinks_range = torch.clamp(sinks_range, epsilon) # Determine the srcs_range based on where we are performing activation equalization or # weight equalization @@ -434,7 +436,7 @@ def _no_equalize(): srcs_range = torch.pow(srcs_range, alpha) sinks_range = torch.pow(sinks_range, 1 - alpha) scaling_factors = srcs_range / sinks_range - scaling_factors = torch.clamp(scaling_factors, EPSILON) + scaling_factors = torch.clamp(scaling_factors, epsilon) inverse_scaling_factors = torch.reciprocal(scaling_factors) if list_of_act_val is not None and list_of_insert_mul_node_fn is not None: From 3c26d5419c6d608b2fcd54f0b9e4759612f64bde Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 13 Dec 2023 14:36:20 +0000 Subject: [PATCH 02/12] Followup --- src/brevitas/graph/equalize.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index e096171fb..0012e9065 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -28,7 +28,7 @@ __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] EPSILON = 1e-9 -FLOAT16_EPSILON = 1e-4 +FLOAT16_EPSILON = 2e-5 _supported_layers = ( nn.ConvTranspose1d, @@ -280,7 +280,8 @@ def _combine_weights_bias( weight = weight.data.reshape(weight.shape[0], -1) bias = bias.reshape(-1, 1) - weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight) + epsilon = FLOAT16_EPSILON if weight.dtype == torch.float16 else EPSILON + weight = torch.where(torch.abs(weight) < epsilon, torch.tensor(epsilon).type_as(weight), weight) factor = torch.abs(bias) / torch.abs(weight) # From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450 From 6beff1fb02e13b1b022bb1b801ca1691d109ab3e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 13:12:13 +0000 Subject: [PATCH 03/12] Follow up --- src/brevitas/graph/base.py | 9 ++++ src/brevitas/graph/equalize.py | 47 ++++++++++++------- src/brevitas/graph/standardize.py | 3 +- .../llm/llm_quant/equalize.py | 4 ++ src/brevitas_examples/llm/main.py | 7 ++- 5 files changed, 51 insertions(+), 19 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index bbbbc27bb..d594c851d 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -296,6 +296,15 @@ class FnToModule(CallableToModule): def match_node(self, node: Node) -> bool: return node.op == 'call_function' and node.target is self.old_callable + def move_node_args_to_kwargs(self, node: Node): + super().move_node_args_to_kwargs(node) + # Moving to stateful modules, we remove the 'training' argument if it is passed to the + # functional version of the layer since it is not needed anymore + kwargs = dict(node.kwargs) + if 'training' in kwargs: + del kwargs['training'] + node.kwargs = immutable_dict(kwargs) + class MethodToModule(CallableToModule): diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 0012e9065..f753eba8c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -27,8 +27,10 @@ __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] -EPSILON = 1e-9 -FLOAT16_EPSILON = 2e-5 +# TODO: if we are able to run activation equalization in GPU + float16, we could have two separate +# epsilon factors for float16 (2e-5) vs float32/bfloat16 (1e-9). At the moment we are tied to one +# single epsilon for both cases. +EPSILON = 2e-5 _supported_layers = ( nn.ConvTranspose1d, @@ -74,6 +76,8 @@ _batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) +_ignore_ops = (getattr, 'size', 'contiguous') + # Required for being hashable @dataclass(eq=True, frozen=True) @@ -280,8 +284,7 @@ def _combine_weights_bias( weight = weight.data.reshape(weight.shape[0], -1) bias = bias.reshape(-1, 1) - epsilon = FLOAT16_EPSILON if weight.dtype == torch.float16 else EPSILON - weight = torch.where(torch.abs(weight) < epsilon, torch.tensor(epsilon).type_as(weight), weight) + weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight) factor = torch.abs(bias) / torch.abs(weight) # From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450 @@ -336,7 +339,6 @@ def _cross_layer_equalization( # Determine device and type of tensors device = next(sinks[0].parameters()).device dtype = next(sinks[0].parameters()).dtype - epsilon = FLOAT16_EPSILON if dtype == torch.float16 else EPSILON # If equalization criteria are not met, we return a scalar one to indicate that no equalization # has been performed @@ -401,7 +403,6 @@ def _no_equalize(): scale_fn = _select_scale_computation_fn(scale_computation_type) sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()] sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1)) - sinks_range = torch.clamp(sinks_range, epsilon) # Determine the srcs_range based on where we are performing activation equalization or # weight equalization @@ -434,10 +435,16 @@ def _no_equalize(): "Detected source and sink with non compatible shapes, equalization is skipped") return _no_equalize() + # Instead of clipping very low values, which would cause their reciprocal to be very large + # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization + sinks_range = torch.where( + sinks_range > EPSILON, sinks_range, torch.tensor(1., dtype=dtype, device=device)) + srcs_range = torch.where( + srcs_range > EPSILON, srcs_range, torch.tensor(1., dtype=dtype, device=device)) srcs_range = torch.pow(srcs_range, alpha) + sinks_range = torch.pow(sinks_range, 1 - alpha) scaling_factors = srcs_range / sinks_range - scaling_factors = torch.clamp(scaling_factors, epsilon) inverse_scaling_factors = torch.reciprocal(scaling_factors) if list_of_act_val is not None and list_of_insert_mul_node_fn is not None: @@ -458,8 +465,8 @@ def _no_equalize(): torch.reshape(inverse_scaling_factors, src_broadcast_size)), attr='weight') for module, axis in sink_axes.items(): - src_broadcast_size = [1] * module.weight.ndim - src_broadcast_size[axis] = module.weight.size(axis) + sink_broadcast_size = [1] * module.weight.ndim + sink_broadcast_size[axis] = module.weight.size(axis) if isinstance(module, _batch_norm): # We re-compute the bias as function of running_mean and running_var to adjust the # additive factor for equalization. @@ -469,7 +476,7 @@ def _no_equalize(): module, module.bias.clone() + additive_factor * (scaling_factors - 1), attr='bias') _update_weights( module, - module.weight.clone() * torch.reshape(scaling_factors, src_broadcast_size), + module.weight.clone() * torch.reshape(scaling_factors, sink_broadcast_size), attr='weight') return scaling_factors @@ -578,6 +585,8 @@ def find_srcs(graph_model: GraphModule, starting_node: Node, node.op == 'call_function' and node.target in _residual_fns): find_srcs(graph_model, node, state) find_sinks(graph_model, node, state) + elif node.target in _ignore_ops: + continue else: # If we meet an unrecognized op, we add None to invalidate the region state.srcs.add(_UNSUPPORTED_OP) @@ -609,6 +618,8 @@ def find_sinks(graph_model: GraphModule, starting_node: Node, node.op == 'call_function' and node.target in _residual_fns): find_sinks(graph_model, node, state) find_srcs(graph_model, node, state) + elif node.target in _ignore_ops: + continue else: # If we meet an unrecognized op, we add None to invalidate the region state.sinks.add(_UNSUPPORTED_OP) @@ -764,16 +775,14 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k # Extra check for batch_dim if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N') - x = x.transpose(0, batch_dim) self.batch_dim_act_map[name] = batch_dim + input_scales = self.scale_fn(x, dim=batch_dim) if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) + self.float_act_map[name] = input_scales else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x], - dim=batch_dim) - self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) + self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales) def insert_mul_node(self, scale, shape, axis, region, batch_dim=0): broadcastable_shape = [1] * len(shape) @@ -910,10 +919,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k # Extra check for batch_dim if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N') - x = x.transpose(0, batch_dim) self.batch_dim_act_map[name] = batch_dim - if name not in self.float_act_map: self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) else: @@ -921,6 +928,12 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k dim=batch_dim) self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) + input_scales = self.scale_fn(x, dim=batch_dim) + if name not in self.float_act_map: + self.float_act_map[name] = input_scales + else: + self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales) + def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0): broadcastable_shape = [1] * len(shape) broadcastable_shape[axis] = shape[axis] diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 5ff3b6676..c3d667aee 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -107,7 +107,8 @@ class TorchFunctionalToModule(GraphTransform): nn.AvgPool1d), (F.avg_pool2d, nn.AvgPool2d), (F.avg_pool3d, nn.AvgPool3d), (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d), (F.adaptive_avg_pool2d, - nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, nn.AdaptiveAvgPool3d)) + nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, + nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout)) def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP): super().__init__() diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index f3e4c3b0d..f736971a9 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -10,6 +10,8 @@ from brevitas.fx.brevitas_tracer import value_trace from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import EqualizeGraph +from brevitas.graph.standardize import DuplicateSharedStatelessModule +from brevitas.graph.standardize import TorchFunctionalToModule from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 @@ -50,6 +52,8 @@ def apply_act_equalization( # So we have to cast to fp32 first, trace, apply equalization, and then cast back with cast_to_float32(model, dtype): graph_model = value_trace(model, value_args=ref_kwargs) + graph_model = TorchFunctionalToModule().apply(graph_model) + graph_model = DuplicateSharedStatelessModule().apply(graph_model) # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode # or an FX interpreter to run it on GPU warnings.warn( diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 4d8b2c3ef..05dc06981 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -303,9 +303,14 @@ def main(): quantize_embedding=args.quantize_embedding, seqlen=args.seqlen) # Tie back first/last layer weights in case they got untied - model.tie_weights() print("Model quantization applied.") + # If any equalization has taken places, the embedding layer and the fully connected one are + # not tied anymore, and they need to be treated as standalone, separate layers. + # In all other cases we can tie them back so to preserve memory. + if args.act_equalization is None and not args.weight_equalization: + model.tie_weights() + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader, args.nsamples) From 7d724f622f7d66d2255f2d3821bf94b258858b14 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 13:17:18 +0000 Subject: [PATCH 04/12] Fix --- src/brevitas/graph/equalize.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index f753eba8c..c28bce57d 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -727,7 +727,7 @@ def setup(self): for region in self.regions: batch_dim = 0 if hasattr(region, 'batch_first'): - batch_dim = 0 if region.batch_first == True else 1 + batch_dim = 0 if region.batch_first else 1 hook_fn = partial( self.forward_stats_hook, name=region, batch_dim=batch_dim, use_inp=True) @@ -844,7 +844,7 @@ def setup(self): for name in region.srcs + region.sinks: module = name_to_module[name] if hasattr(module, 'batch_first'): - batch_dim = 0 if module.batch_first == True else 1 + batch_dim = 0 if module.batch_first else 1 for name in region_to_search: act_module = name_to_module[name] use_inp = True if region_to_search == region.sinks else False @@ -920,14 +920,6 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N') - self.batch_dim_act_map[name] = batch_dim - if name not in self.float_act_map: - self.float_act_map[name] = self.scale_fn(x, dim=batch_dim) - else: - batch_data = torch.cat([self.float_act_map[name].unsqueeze(batch_dim), x], - dim=batch_dim) - self.float_act_map[name] = self.scale_fn(batch_data, dim=batch_dim) - input_scales = self.scale_fn(x, dim=batch_dim) if name not in self.float_act_map: self.float_act_map[name] = input_scales From 8f6d656fc473377fab73b855b68b18dc62448b2a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 13:25:14 +0000 Subject: [PATCH 05/12] Fix --- src/brevitas_examples/llm/llm_quant/equalize.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index f736971a9..d92ae64b6 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -28,6 +28,12 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha): return outs +def trace_and_standardize(model, ref_kwargs): + graph_model = value_trace(model, value_args=ref_kwargs) + graph_model = TorchFunctionalToModule().apply(graph_model) + graph_model = DuplicateSharedStatelessModule().apply(graph_model) + + @torch.no_grad() def apply_act_equalization( model, @@ -51,9 +57,7 @@ def apply_act_equalization( # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply equalization, and then cast back with cast_to_float32(model, dtype): - graph_model = value_trace(model, value_args=ref_kwargs) - graph_model = TorchFunctionalToModule().apply(graph_model) - graph_model = DuplicateSharedStatelessModule().apply(graph_model) + graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs) # TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode # or an FX interpreter to run it on GPU warnings.warn( @@ -74,5 +78,5 @@ def apply_weight_equalization(model, dtype, ref_kwargs, scale_computation_type=' # We can't do fp16 tracing on CPU as many kernels are not implemented # So we have to cast to fp32 first, trace, apply equalization, and then cast back with cast_to_float32(model, dtype): - graph_model = value_trace(model, value_args=ref_kwargs) + graph_model = trace_and_standardize(model, ref_kwargs=ref_kwargs) EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model) From 06bcaaef983b3631b1362d911d9c60a4dc083ab4 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 13:50:11 +0000 Subject: [PATCH 06/12] Fix --- src/brevitas/graph/equalize.py | 12 +++++++----- src/brevitas_examples/llm/llm_quant/equalize.py | 1 + 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index c28bce57d..04ae3904e 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -437,12 +437,14 @@ def _no_equalize(): # Instead of clipping very low values, which would cause their reciprocal to be very large # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization - sinks_range = torch.where( - sinks_range > EPSILON, sinks_range, torch.tensor(1., dtype=dtype, device=device)) - srcs_range = torch.where( - srcs_range > EPSILON, srcs_range, torch.tensor(1., dtype=dtype, device=device)) - srcs_range = torch.pow(srcs_range, alpha) + sinks_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON), + torch.tensor(1., dtype=dtype, device=device), + sinks_range) + srcs_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON), + torch.tensor(1., dtype=dtype, device=device), + srcs_range) + srcs_range = torch.pow(srcs_range, alpha) sinks_range = torch.pow(sinks_range, 1 - alpha) scaling_factors = srcs_range / sinks_range inverse_scaling_factors = torch.reciprocal(scaling_factors) diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index d92ae64b6..412fd7476 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -32,6 +32,7 @@ def trace_and_standardize(model, ref_kwargs): graph_model = value_trace(model, value_args=ref_kwargs) graph_model = TorchFunctionalToModule().apply(graph_model) graph_model = DuplicateSharedStatelessModule().apply(graph_model) + return graph_model @torch.no_grad() From 74a802f3129fc1ae80ba41a6778dc00cc3382cf5 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 16:09:02 +0000 Subject: [PATCH 07/12] Fix --- src/brevitas/graph/equalize.py | 2 ++ src/brevitas_examples/llm/llm_quant/equalize.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 04ae3904e..3e885dd22 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -922,6 +922,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k if hasattr(x, 'names') and 'N' in x.names: batch_dim = x.names.index('N') + self.batch_dim_act_map[name] = batch_dim + input_scales = self.scale_fn(x, dim=batch_dim) if name not in self.float_act_map: self.float_act_map[name] = input_scales diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index 412fd7476..040035250 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -31,7 +31,6 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha): def trace_and_standardize(model, ref_kwargs): graph_model = value_trace(model, value_args=ref_kwargs) graph_model = TorchFunctionalToModule().apply(graph_model) - graph_model = DuplicateSharedStatelessModule().apply(graph_model) return graph_model From 26fd708ba44a48127294405f7dbc71f4d4146896 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 16:12:57 +0000 Subject: [PATCH 08/12] remove useless import --- src/brevitas_examples/llm/llm_quant/equalize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas_examples/llm/llm_quant/equalize.py b/src/brevitas_examples/llm/llm_quant/equalize.py index 040035250..2dc80ccae 100644 --- a/src/brevitas_examples/llm/llm_quant/equalize.py +++ b/src/brevitas_examples/llm/llm_quant/equalize.py @@ -10,7 +10,6 @@ from brevitas.fx.brevitas_tracer import value_trace from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import EqualizeGraph -from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32 From 89e194e645971d4e388825dafa48cb5d6cdcf656 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 16:32:09 +0000 Subject: [PATCH 09/12] Refactor --- src/brevitas/graph/equalize.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 3e885dd22..77c50c9f4 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -63,10 +63,18 @@ nn.ReLU, nn.LeakyReLU) -_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__) +_scale_invariant_op = ( + torch.mul, + operator.mul, + operator.imul, + operator.__mul__, + operator.__imul__, +) _select_op = (operator.getitem, operator.__getitem__) +_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', torch.reshape, torch.flatten) + _scale_varying_activations = ( torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU) @@ -76,7 +84,7 @@ _batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) -_ignore_ops = (getattr, 'size', 'contiguous') +_ignore_ops = (getattr, 'size') # Required for being hashable @@ -559,9 +567,7 @@ def _is_scale_invariant_function(node: Node) -> bool: def _is_reshaping_op(node: Node) -> bool: - return ( - node.op == 'call_function' and node.target in [torch.flatten, torch.reshape] or - node.op == 'call_method' and node.target in ['view', 'reshape', 'flatten']) + return node.target in _reshaping_op def find_srcs(graph_model: GraphModule, starting_node: Node, From 7b3f1aaac374bd62ee71719976189f658f97dc3e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 17:08:29 +0000 Subject: [PATCH 10/12] Comment --- src/brevitas/graph/equalize.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 77c50c9f4..655fdb9c3 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -721,8 +721,6 @@ def find_module(self, model, regions: List): """ Iterate through the model looking at immediate children of every module to look for supported modules. This allows us to stop the search when we meet a top-level module that is supported. - Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its - Linear submodules. """ if isinstance(model, _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)): From a0ed6876a175ebdb49c2105209259c6779678088 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 18:22:36 +0000 Subject: [PATCH 11/12] Smaller epsilon --- src/brevitas/graph/equalize.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 655fdb9c3..2b8b98199 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -27,10 +27,7 @@ __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] -# TODO: if we are able to run activation equalization in GPU + float16, we could have two separate -# epsilon factors for float16 (2e-5) vs float32/bfloat16 (1e-9). At the moment we are tied to one -# single epsilon for both cases. -EPSILON = 2e-5 +EPSILON = 1e-9 _supported_layers = ( nn.ConvTranspose1d, @@ -292,7 +289,8 @@ def _combine_weights_bias( weight = weight.data.reshape(weight.shape[0], -1) bias = bias.reshape(-1, 1) - weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight) + weight = torch.where( + torch.abs(weight) <= EPSILON, torch.tensor(EPSILON).type_as(weight), weight) factor = torch.abs(bias) / torch.abs(weight) # From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450 @@ -445,10 +443,10 @@ def _no_equalize(): # Instead of clipping very low values, which would cause their reciprocal to be very large # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization - sinks_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON), + sinks_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON), torch.tensor(1., dtype=dtype, device=device), sinks_range) - srcs_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON), + srcs_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON), torch.tensor(1., dtype=dtype, device=device), srcs_range) From 89f5c13a5c4fde5393d19d47ecea166d156ed7a6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 15 Dec 2023 20:43:45 +0000 Subject: [PATCH 12/12] formatting --- src/brevitas/graph/equalize.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 2b8b98199..174552241 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -60,13 +60,7 @@ nn.ReLU, nn.LeakyReLU) -_scale_invariant_op = ( - torch.mul, - operator.mul, - operator.imul, - operator.__mul__, - operator.__imul__, -) +_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__) _select_op = (operator.getitem, operator.__getitem__) @@ -442,13 +436,13 @@ def _no_equalize(): return _no_equalize() # Instead of clipping very low values, which would cause their reciprocal to be very large - # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization - sinks_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON), - torch.tensor(1., dtype=dtype, device=device), - sinks_range) - srcs_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON), - torch.tensor(1., dtype=dtype, device=device), - srcs_range) + # thus hindering quantization, we set both sources and sinks to one, + # which is the no-op equivalent for equalization. + channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON) + sinks_range = torch.where( + channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range) + srcs_range = torch.where( + channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range) srcs_range = torch.pow(srcs_range, alpha) sinks_range = torch.pow(sinks_range, 1 - alpha)