diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index a9c492b97..e096171fb 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -28,6 +28,7 @@ __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] EPSILON = 1e-9 +FLOAT16_EPSILON = 1e-4 _supported_layers = ( nn.ConvTranspose1d, @@ -334,6 +335,7 @@ def _cross_layer_equalization( # Determine device and type of tensors device = next(sinks[0].parameters()).device dtype = next(sinks[0].parameters()).dtype + epsilon = FLOAT16_EPSILON if dtype == torch.float16 else EPSILON # If equalization criteria are not met, we return a scalar one to indicate that no equalization # has been performed @@ -398,7 +400,7 @@ def _no_equalize(): scale_fn = _select_scale_computation_fn(scale_computation_type) sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()] sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1)) - sinks_range = torch.clamp(sinks_range, EPSILON) + sinks_range = torch.clamp(sinks_range, epsilon) # Determine the srcs_range based on where we are performing activation equalization or # weight equalization @@ -434,7 +436,7 @@ def _no_equalize(): srcs_range = torch.pow(srcs_range, alpha) sinks_range = torch.pow(sinks_range, 1 - alpha) scaling_factors = srcs_range / sinks_range - scaling_factors = torch.clamp(scaling_factors, EPSILON) + scaling_factors = torch.clamp(scaling_factors, epsilon) inverse_scaling_factors = torch.reciprocal(scaling_factors) if list_of_act_val is not None and list_of_insert_mul_node_fn is not None: