Feat (core): add keepdim to min/max/percentile stats
volcacius committed Jul 6, 2023
1 parent 8d5035b commit cdc6abc
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions src/brevitas/core/stats/stats_op.py
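For context before the diff: keepdim controls whether a reduced dimension is dropped or retained as size 1. A minimal sketch of the difference (not part of the commit; shapes assumed):

import torch

x = torch.randn(4, 8)

# Default: the reduced dimension is dropped.
print(torch.min(x, dim=1)[0].shape)                # torch.Size([4])

# keepdim=True retains it as size 1, so the statistic
# broadcasts directly against the original tensor.
print(torch.min(x, dim=1, keepdim=True)[0].shape)  # torch.Size([4, 1])

Keeping the dimension avoids an explicit unsqueeze when a per-channel statistic is later applied back to the tensor it was computed from.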
@@ -19,39 +19,46 @@


class NegativeMinOrZero(brevitas.jit.ScriptModule):
-    __constants__ = ['stats_reduce_dim']
+    __constants__ = ['stats_reduce_dim', 'keepdim']

    def __init__(
            self,
            stats_reduce_dim: Optional[int] = None,
            dtype: Optional[torch.dtype] = None,
-            device: Optional[torch.device] = None) -> None:
+            device: Optional[torch.device] = None,
+            keepdim: bool = False) -> None:
        super(NegativeMinOrZero, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device))
+        self.keepdim = keepdim

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        if self.stats_reduce_dim is None:
            min_val = torch.min(x)
        else:
-            min_val = torch.min(x, dim=self.stats_reduce_dim)[0]
+            min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
        min_val = torch.where(
            min_val <= self.zero().to(min_val.dtype), min_val, self.zero().to(min_val.dtype))
        return min_val
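A quick illustration (mine, not from the repo) of what this module computes with keepdim=True: the per-row minimum is kept only where it is negative, and the reduced dimension survives as size 1:

import torch

x = torch.tensor([[1.0, 2.0], [-3.0, 4.0]])
zero = torch.tensor(0.0)

min_val = torch.min(x, dim=1, keepdim=True)[0]        # [[1.0], [-3.0]]
result = torch.where(min_val <= zero, min_val, zero)  # [[0.0], [-3.0]]
print(result.shape)                                   # torch.Size([2, 1])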


class AbsPercentile(brevitas.jit.ScriptModule):
-    __constants__ = ['q', 'stats_reduce_dim']
+    __constants__ = ['q', 'stats_reduce_dim', 'keepdim']

    def __init__(
-            self, high_percentile_q: float, stats_reduce_dim: Optional[int], percentile_q=None):
+            self,
+            high_percentile_q: float,
+            stats_reduce_dim: Optional[int],
+            percentile_q=None,
+            keepdim: bool = False):
        super(AbsPercentile, self).__init__()
        if percentile_q is not None:
            raise RuntimeError("percentile_q is deprecated, please pass high_percentile_q.")
        assert high_percentile_q <= 100, "q has to be a percentage"
        self.q = high_percentile_q
        self.stats_reduce_dim = stats_reduce_dim
+        self.keepdim = keepdim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
@@ -66,23 +73,25 @@ def forward(self, x: Tensor):
            dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
            # k is 1-indexed, so round away from zero
            k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5))
-            result = x.abs().kthvalue(k, dim=self.stats_reduce_dim).values
+            result = x.abs().kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
        return result
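For reference, a small sketch (q and shapes assumed, not from the commit) of the rank computation above and the effect of keepdim on kthvalue:

import math
import torch

q = 99.9                   # high_percentile_q
x = torch.randn(16, 1000)  # stats_reduce_dim=1, so 1000 samples per row

# k is 1-indexed, so round away from zero: 0.01 * 99.9 * 1000 + 0.5 = 999.5 -> 999
k = int(math.floor(.01 * q * x.shape[1] + 0.5))

values = x.abs().kthvalue(k, dim=1, keepdim=True).values
print(k, values.shape)     # 999 torch.Size([16, 1])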


class NegativePercentileOrZero(brevitas.jit.ScriptModule):
-    __constants__ = ['stats_reduce_dim', 'q']
+    __constants__ = ['stats_reduce_dim', 'q', 'keepdim']

    def __init__(
            self,
            low_percentile_q,
            stats_reduce_dim: Optional[int] = None,
            dtype: Optional[torch.dtype] = None,
-            device: Optional[torch.device] = None) -> None:
+            device: Optional[torch.device] = None,
+            keepdim: bool = False) -> None:
        super(NegativePercentileOrZero, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.q = low_percentile_q
        self.zero = StatelessBuffer(torch.tensor(0.0, dtype=dtype, device=device))
+        self.keepdim = keepdim

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
@@ -97,24 +106,26 @@ def forward(self, x: Tensor) -> Tensor:
            dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
            # k is 1-indexed, so round away from zero
            k = int(math.ceil(.01 * self.q * dim_slice.numel()))
-            result = x.kthvalue(k, dim=self.stats_reduce_dim).values
+            result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
        result = torch.where(
            result <= self.zero().to(result.dtype), result, self.zero().to(result.dtype))
        return result


class PercentileInterval(brevitas.jit.ScriptModule):
-    __constants__ = ['stats_reduce_dim', 'low_q', 'high_q']
+    __constants__ = ['stats_reduce_dim', 'low_q', 'high_q', 'keepdim']

    def __init__(
            self,
            low_percentile_q,
            high_percentile_q,
-            stats_reduce_dim: Optional[int] = None) -> None:
+            stats_reduce_dim: Optional[int] = None,
+            keepdim: bool = False) -> None:
        super(PercentileInterval, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
        self.low_q = low_percentile_q
        self.high_q = high_percentile_q
+        self.keepdim = keepdim

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
@@ -132,8 +143,8 @@ def forward(self, x: Tensor) -> Tensor:
            low_k = int(math.ceil(.01 * self.low_q * dim_slice.numel()))
            # k is 1-indexed, so round away from zero
            high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5))
-            low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim).values
-            high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim).values
+            low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
+            high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
        interval = high_result - low_result
        abs_interval = torch.abs(interval)
        return abs_interval
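Side note on the two rounding rules above, with assumed numbers (not from the commit): the low rank rounds up with ceil, while the high rank rounds away from zero via floor(x + 0.5):

import math

n = 1000                  # elements along the reduce dimension
low_q, high_q = 0.1, 99.9

low_k = int(math.ceil(.01 * low_q * n))           # ceil(1.0)    -> 1
high_k = int(math.floor(.01 * high_q * n + 0.5))  # floor(999.5) -> 999
print(low_k, high_k)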
@@ -142,32 +153,34 @@ def forward(self, x: Tensor) -> Tensor:
class AbsMax(brevitas.jit.ScriptModule):
    __constants__ = ['stats_reduce_dim']

-    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
+    def __init__(self, stats_reduce_dim: Optional[int] = None, keepdim: bool = False) -> None:
        super(AbsMax, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
+        self.keepdim = keepdim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            return torch.max(torch.abs(x))
        else:
-            return torch.max(torch.abs(x), dim=self.stats_reduce_dim)[0]
+            return torch.max(torch.abs(x), dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]


class AbsMinMax(brevitas.jit.ScriptModule):
-    __constants__ = ['stats_reduce_dim']
+    __constants__ = ['stats_reduce_dim', 'keepdim']

-    def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
+    def __init__(self, stats_reduce_dim: Optional[int] = None, keepdim: bool = False) -> None:
        super(AbsMinMax, self).__init__()
        self.stats_reduce_dim = stats_reduce_dim
+        self.keepdim = keepdim

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        if self.stats_reduce_dim is None:
            return torch.abs(torch.max(x) - torch.min(x))
        else:
-            max_val = torch.max(x, dim=self.stats_reduce_dim)[0]
-            min_val = torch.min(x, dim=self.stats_reduce_dim)[0]
+            max_val = torch.max(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
+            min_val = torch.min(x, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
            return torch.abs(max_val - min_val)
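To close, a tiny sketch (values assumed, not from the repo) contrasting the two reductions with keepdim=True — AbsMax takes the largest magnitude, AbsMinMax the width of the value range:

import torch

x = torch.tensor([[-4.0, 1.0], [2.0, 3.0]])

abs_max = torch.max(torch.abs(x), dim=1, keepdim=True)[0]  # [[4.0], [3.0]]

max_val = torch.max(x, dim=1, keepdim=True)[0]             # [[1.0], [3.0]]
min_val = torch.min(x, dim=1, keepdim=True)[0]             # [[-4.0], [2.0]]
abs_min_max = torch.abs(max_val - min_val)                 # [[5.0], [1.0]]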

