Skip to content

Commit

Permalink
Fix (graph/equalize): increase epsilon for float16
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 13, 2023
1 parent 84f4225 commit 7c13021
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']

EPSILON = 1e-9
FLOAT16_EPSILON = 1e-4

_supported_layers = (
nn.ConvTranspose1d,
Expand Down Expand Up @@ -334,6 +335,7 @@ def _cross_layer_equalization(
# Determine device and type of tensors
device = next(sinks[0].parameters()).device
dtype = next(sinks[0].parameters()).dtype
epsilon = FLOAT16_EPSILON if dtype == torch.float16 else EPSILON

# If equalization criteria are not met, we return a scalar one to indicate that no equalization
# has been performed
Expand Down Expand Up @@ -398,7 +400,7 @@ def _no_equalize():
scale_fn = _select_scale_computation_fn(scale_computation_type)
sink_weights = [transpose(m, axis) for m, axis in sink_axes.items()]
sinks_range = scale_fn(torch.cat([w.reshape(w.size(0), -1) for w in sink_weights], 1))
sinks_range = torch.clamp(sinks_range, EPSILON)
sinks_range = torch.clamp(sinks_range, epsilon)

# Determine the srcs_range based on where we are performing activation equalization or
# weight equalization
Expand Down Expand Up @@ -434,7 +436,7 @@ def _no_equalize():
srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
scaling_factors = srcs_range / sinks_range
scaling_factors = torch.clamp(scaling_factors, EPSILON)
scaling_factors = torch.clamp(scaling_factors, epsilon)
inverse_scaling_factors = torch.reciprocal(scaling_factors)

if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
Expand Down

0 comments on commit 7c13021

Please sign in to comment.