diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7d20bc340..a9c492b97 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -57,7 +57,7 @@ nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d, nn.Identity, - nn.ReLU, + nn.ReLU, nn.LeakyReLU) _scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__)