diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 2b8b98199..174552241 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -60,13 +60,7 @@ nn.ReLU, nn.LeakyReLU) -_scale_invariant_op = ( - torch.mul, - operator.mul, - operator.imul, - operator.__mul__, - operator.__imul__, -) +_scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__) _select_op = (operator.getitem, operator.__getitem__) @@ -442,13 +436,13 @@ def _no_equalize(): return _no_equalize() # Instead of clipping very low values, which would cause their reciprocal to be very large - # thus hindering quantization, we set them to one, which is the no-op equivalent for equalization - sinks_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON), - torch.tensor(1., dtype=dtype, device=device), - sinks_range) - srcs_range = torch.where((sinks_range <= EPSILON) | (srcs_range <= EPSILON), - torch.tensor(1., dtype=dtype, device=device), - srcs_range) + # thus hindering quantization, we set both sources and sinks to one, + # which is the no-op equivalent for equalization. + channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON) + sinks_range = torch.where( + channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range) + srcs_range = torch.where( + channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range) srcs_range = torch.pow(srcs_range, alpha) sinks_range = torch.pow(sinks_range, 1 - alpha)