Commit: Smaller epsilon

Giuseppe5 committed Dec 15, 2023
1 parent 7b3f1aa, commit a0ed687

Showing 1 changed file with 5 additions and 7 deletions.

src/brevitas/graph/equalize.py (12 changes: 5 additions & 7 deletions)
@@ -27,10 +27,7 @@

 __all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph']

-# TODO: if we are able to run activation equalization in GPU + float16, we could have two separate
-# epsilon factors for float16 (2e-5) vs float32/bfloat16 (1e-9). At the moment we are tied to one
-# single epsilon for both cases.
-EPSILON = 2e-5
+EPSILON = 1e-9

 _supported_layers = (
     nn.ConvTranspose1d,
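
The deleted TODO records why a single constant is delicate: 1e-9 is representable in float32/bfloat16 but underflows to zero in float16, and a zero threshold would make the epsilon guards further down useless against exactly-zero values. A minimal PyTorch sketch (illustrative only, not code from this repo) confirming the underflow:

```python
import torch

# float16's smallest positive subnormal is about 5.96e-8, so 1e-9 rounds to
# zero in half precision, while 2e-5 survives (as a nearby subnormal value).
print(torch.tensor(1e-9, dtype=torch.float16))  # tensor(0., dtype=torch.float16)
print(torch.tensor(2e-5, dtype=torch.float16))  # tensor(2.0027e-05, dtype=torch.float16)
print(torch.tensor(1e-9, dtype=torch.float32))  # tensor(1.0000e-09)
```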
@@ -292,7 +289,8 @@ def _combine_weights_bias(
     weight = weight.data.reshape(weight.shape[0], -1)
     bias = bias.reshape(-1, 1)

-    weight = torch.where(torch.abs(weight) < EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
+    weight = torch.where(
+        torch.abs(weight) <= EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
     factor = torch.abs(bias) / torch.abs(weight)

     # From https://github.com/Xilinx/Vitis-AI/blob/master/src/vai_quantizer/vai_q_pytorch/nndct_shared/optimization/commander.py#L450
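
As a standalone illustration (hypothetical tensors; the rest of `_combine_weights_bias` is omitted), the clamp floors near-zero weights at EPSILON so the bias/weight ratio on the next line stays finite:

```python
import torch

EPSILON = 1e-9

weight = torch.tensor([[0.0, 3e-10, 0.5]])  # includes an exactly-zero weight
bias = torch.tensor([[2.0]])

# Floor |weight| at EPSILON so the division below cannot produce inf.
weight = torch.where(
    torch.abs(weight) <= EPSILON, torch.tensor(EPSILON).type_as(weight), weight)
factor = torch.abs(bias) / torch.abs(weight)
print(factor)  # tensor([[2.0000e+09, 2.0000e+09, 4.0000e+00]])
```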
@@ -445,10 +443,10 @@ def _no_equalize():

     # Instead of clipping very low values, which would cause their reciprocal to be very large
     # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization
-    sinks_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON),
+    sinks_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON),
                               torch.tensor(1., dtype=dtype, device=device),
                               sinks_range)
-    srcs_range = torch.where((sinks_range < EPSILON) | (srcs_range < EPSILON),
+    srcs_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON),
                              torch.tensor(1., dtype=dtype, device=device),
                              srcs_range)

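A short sketch of the rationale in the comment above (hypothetical ranges; the mask is computed once here so both tensors see the same condition): a channel with a degenerate range gets 1.0 on both sides, so its scale factor comes out as exactly 1, a no-op, instead of a clipped value whose reciprocal would dominate the quantization range:

```python
import torch

EPSILON = 1e-9
dtype, device = torch.float32, 'cpu'

srcs_range = torch.tensor([0.0, 2.0], dtype=dtype)   # channel 0 has no dynamic range
sinks_range = torch.tensor([1.0, 4.0], dtype=dtype)

# Send degenerate channels to 1.0 on both sides: their scale becomes 1,
# rather than the huge reciprocal a clipped tiny range would produce.
degenerate = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
sinks_range = torch.where(degenerate, torch.tensor(1., dtype=dtype, device=device), sinks_range)
srcs_range = torch.where(degenerate, torch.tensor(1., dtype=dtype, device=device), srcs_range)

# Assumed scale rule for illustration only (balanced alpha = 0.5 form).
scale = torch.sqrt(srcs_range / sinks_range)
print(scale)  # tensor([1.0000, 0.7071])
```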