Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 15, 2023
1 parent a0ed687 commit 89f5c13
Showing 1 changed file with 8 additions and 14 deletions.
22 changes: 8 additions & 14 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,7 @@
nn.ReLU,
nn.LeakyReLU)

_scale_invariant_op = (
torch.mul,
operator.mul,
operator.imul,
operator.__mul__,
operator.__imul__,
)
_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__)

_select_op = (operator.getitem, operator.__getitem__)

Expand Down Expand Up @@ -442,13 +436,13 @@ def _no_equalize():
return _no_equalize()

# Instead of clipping very low values, which would cause their reciprocal to be very large
# thus hindering quantization, we set them to one, which is the no-op equivalent for equalization
sinks_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON),
torch.tensor(1., dtype=dtype, device=device),
sinks_range)
srcs_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON),
torch.tensor(1., dtype=dtype, device=device),
srcs_range)
# thus hindering quantization, we set both sources and sinks to one,
# which is the no-op equivalent for equalization.
channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
sinks_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range)
srcs_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range)

srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
Expand Down

0 comments on commit 89f5c13

Please sign in to comment.