Skip to content

Commit

Permalink
feat: static threshold estimator (#136)
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 authored Feb 8, 2023
1 parent d3488c9 commit 2ac1c2f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 4 deletions.
4 changes: 2 additions & 2 deletions numalogic/config/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
TransformerAE,
SparseTransformerAE,
)
from numalogic.models.threshold import StdDevThreshold
from numalogic.models.threshold import StdDevThreshold, StaticThreshold
from numalogic.postprocess import TanhNorm
from numalogic.preprocess import LogTransformer, StaticPowerTransformer
from numalogic.tools.exceptions import UnknownConfigArgsError
Expand Down Expand Up @@ -62,7 +62,7 @@ class PostprocessFactory(_ObjectFactory):


class ThresholdFactory(_ObjectFactory):
_CLS_MAP = {"StdDevThreshold": StdDevThreshold}
_CLS_MAP = {"StdDevThreshold": StdDevThreshold, "StaticThreshold": StaticThreshold}


class ModelFactory(_ObjectFactory):
Expand Down
3 changes: 2 additions & 1 deletion numalogic/models/threshold/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from numalogic.models.threshold._std import StdDevThreshold
from numalogic.models.threshold._static import StaticThreshold

__all__ = ["StdDevThreshold"]
__all__ = ["StdDevThreshold", "StaticThreshold"]
67 changes: 67 additions & 0 deletions numalogic/models/threshold/_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy.typing as npt
from sklearn.base import BaseEstimator
from typing_extensions import Self


class StaticThreshold(BaseEstimator):
r"""
Simple and stateless static thresholding as an estimator.
Values more than upper_limit is considered an outlier,
and are given an outlier_score.
Values less than the upper_limit is considered an inlier,
and are given an inlier_score.
Args:
upper_limit: upper threshold
outlier_score: static score given to values above upper threshold;
this has to be greater than inlier_score
inlier_score: static score given to values below upper threshold
"""
__slots__ = ("upper_limit", "outlier_score", "inlier_score")

def __init__(self, upper_limit: float, outlier_score: float = 10.0, inlier_score: float = 0.5):
self.upper_limit = upper_limit
self.outlier_score = outlier_score
self.inlier_score = inlier_score

assert (
self.outlier_score > self.inlier_score
), "Outlier score needs to be greater than inlier score"

def fit(self, _: npt.NDArray[float]) -> Self:
"""Does not do anything. Only for API compatibility"""
return self

def predict(self, x_test: npt.NDArray[float]) -> npt.NDArray[float]:
"""
Returns an array of same shape as input.
1 denotes anomaly.
"""
x_test = x_test.copy()
x_test[x_test < self.upper_limit] = 0.0
x_test[x_test >= self.upper_limit] = 1.0
return x_test

def score_samples(self, x_test: npt.NDArray[float]) -> npt.NDArray[float]:
"""
Returns an array of same shape as input
with values being anomaly scores.
"""
x_test = x_test.copy()
x_test[x_test < self.upper_limit] = self.inlier_score
x_test[x_test >= self.upper_limit] = self.outlier_score
return x_test
22 changes: 21 additions & 1 deletion tests/models/test_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from numalogic.models.threshold import StdDevThreshold
from numalogic.models.threshold import StdDevThreshold, StaticThreshold


class TestStdDevThreshold(unittest.TestCase):
Expand All @@ -23,5 +23,25 @@ def test_estimator_score(self):
self.assertAlmostEqual(0.93317, np.mean(score), places=2)


class TestStaticThreshold(unittest.TestCase):
def setUp(self) -> None:
self.x = np.arange(20).reshape(10, 2).astype(float)

def test_predict(self):
clf = StaticThreshold(upper_limit=5)
clf.fit(self.x)
y = clf.predict(self.x)
self.assertTupleEqual(self.x.shape, y.shape)
self.assertEqual(np.max(y), 1)
self.assertEqual(np.min(y), 0)

def test_score(self):
clf = StaticThreshold(upper_limit=5.0)
y = clf.score_samples(self.x)
self.assertTupleEqual(self.x.shape, y.shape)
self.assertEqual(np.max(y), clf.outlier_score)
self.assertEqual(np.min(y), clf.inlier_score)


if __name__ == "__main__":
unittest.main()

0 comments on commit 2ac1c2f

Please sign in to comment.