Skip to content

Commit

Permalink
Fix (core): bug in zero-point statistics with positive only values (#670)
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius authored Nov 10, 2023
1 parent 4d51f18 commit fc7ff8e
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def forward(self, x: Tensor) -> Tensor:
min_val = torch.min(x)
else:
min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
min_val = torch.where(
min_val <= self.zero().to(min_val.dtype), min_val, self.zero().to(min_val.dtype))
min_val = torch.clamp(min_val, max=self.zero())
return min_val


Expand Down Expand Up @@ -107,8 +106,7 @@ def forward(self, x: Tensor) -> Tensor:
# k is 1-indexed, so round away from zero
k = int(math.ceil(.01 * self.q * dim_slice.numel()))
result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
result = torch.where(
result <= self.zero().to(result.dtype), result, self.zero().to(result.dtype))
result = torch.clamp(result, max=self.zero())
return result


Expand All @@ -120,12 +118,15 @@ def __init__(
low_percentile_q,
high_percentile_q,
stats_reduce_dim: Optional[int] = None,
keepdim: bool = False) -> None:
keepdim: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> None:
super(PercentileInterval, self).__init__()
self.stats_reduce_dim = stats_reduce_dim
self.low_q = low_percentile_q
self.high_q = high_percentile_q
self.keepdim = keepdim
self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device))

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tensor:
Expand All @@ -145,6 +146,8 @@ def forward(self, x: Tensor) -> Tensor:
high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5))
low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
# We need to make sure the lower bound is not positive to align with zero-point statistics
low_result = torch.clamp(low_result, max=self.zero())
interval = high_result - low_result
abs_interval = torch.abs(interval)
return abs_interval
Expand All @@ -169,19 +172,28 @@ def forward(self, x: Tensor):
class AbsMinMax(brevitas.jit.ScriptModule):
    """Range statistic: absolute spread ``|max(x) - min(x)|`` of the input.

    Used to derive a quantization scale from observed activations/weights.
    The minimum is clamped to be non-positive so that, for positive-only
    inputs, the computed range still spans zero — keeping it consistent with
    the zero-point statistics fixed elsewhere in this commit.

    Args:
        stats_reduce_dim: dimension to reduce over; ``None`` reduces the
            whole tensor to a scalar.
        keepdim: whether the reduced dimension is kept (only relevant when
            ``stats_reduce_dim`` is not ``None``).
        dtype: dtype of the constant zero used for clamping.
        device: device of the constant zero used for clamping.
    """
    __constants__ = ['stats_reduce_dim', 'keepdim']

    def __init__(
            self,
            stats_reduce_dim: Optional[int] = None,
            keepdim: bool = False,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super(AbsMinMax, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.keepdim = keepdim
        # Constant zero wrapped in a StatelessBuffer: scriptable and kept out
        # of the state dict, matching the other stats ops in this file.
        self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device))

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            max_val = torch.max(x)
            min_val = torch.min(x)
        else:
            max_val = torch.max(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
            min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
        # We need to make sure the lower bound is not positive to align with zero-point statistics
        min_val = torch.clamp(min_val, max=self.zero())
        return torch.abs(max_val - min_val)


class AbsMaxAve(brevitas.jit.ScriptModule):
Expand Down

0 comments on commit fc7ff8e

Please sign in to comment.