diff --git a/numalogic/config/factory.py b/numalogic/config/factory.py
index 347226f0..0004425c 100644
--- a/numalogic/config/factory.py
+++ b/numalogic/config/factory.py
@@ -23,7 +23,7 @@
     TransformerAE,
     SparseTransformerAE,
 )
-from numalogic.models.threshold import StdDevThreshold
+from numalogic.models.threshold import StdDevThreshold, StaticThreshold
 from numalogic.postprocess import TanhNorm
 from numalogic.preprocess import LogTransformer, StaticPowerTransformer
 from numalogic.tools.exceptions import UnknownConfigArgsError
@@ -62,7 +62,7 @@ class PostprocessFactory(_ObjectFactory):
 
 
 class ThresholdFactory(_ObjectFactory):
-    _CLS_MAP = {"StdDevThreshold": StdDevThreshold}
+    _CLS_MAP = {"StdDevThreshold": StdDevThreshold, "StaticThreshold": StaticThreshold}
 
 
 class ModelFactory(_ObjectFactory):
diff --git a/numalogic/models/threshold/__init__.py b/numalogic/models/threshold/__init__.py
index 072397d4..e96c0110 100644
--- a/numalogic/models/threshold/__init__.py
+++ b/numalogic/models/threshold/__init__.py
@@ -1,3 +1,4 @@
 from numalogic.models.threshold._std import StdDevThreshold
+from numalogic.models.threshold._static import StaticThreshold
 
-__all__ = ["StdDevThreshold"]
+__all__ = ["StdDevThreshold", "StaticThreshold"]
diff --git a/numalogic/models/threshold/_static.py b/numalogic/models/threshold/_static.py
new file mode 100644
index 00000000..14b8c461
--- /dev/null
+++ b/numalogic/models/threshold/_static.py
@@ -0,0 +1,80 @@
+# Copyright 2022 The Numaproj Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy.typing as npt
+from sklearn.base import BaseEstimator
+from typing_extensions import Self
+
+
+class StaticThreshold(BaseEstimator):
+    r"""
+    Simple and stateless static thresholding as an estimator.
+
+    Values greater than or equal to upper_limit are considered outliers,
+    and are given an outlier_score.
+
+    Values less than upper_limit are considered inliers,
+    and are given an inlier_score.
+
+    Args:
+        upper_limit: upper threshold
+        outlier_score: static score given to values at or above the upper threshold;
+                       this has to be greater than inlier_score
+        inlier_score: static score given to values below upper threshold
+
+    Raises:
+        ValueError: if outlier_score is not greater than inlier_score
+    """
+
+    # NOTE: no __slots__ here; sklearn's BaseEstimator keeps parameters in the
+    # instance __dict__ for pickling, and newer releases refuse to pickle with __slots__.
+
+    def __init__(self, upper_limit: float, outlier_score: float = 10.0, inlier_score: float = 0.5):
+        # Raise explicitly: a bare assert is stripped when running under `python -O`.
+        if not outlier_score > inlier_score:
+            raise ValueError("Outlier score needs to be greater than inlier score")
+
+        self.upper_limit = upper_limit
+        self.outlier_score = outlier_score
+        self.inlier_score = inlier_score
+
+    def _assign(
+        self, x_test: npt.NDArray[float], inlier_value: float, outlier_value: float
+    ) -> npt.NDArray[float]:
+        """Returns a copy of x_test with inliers and outliers replaced by the given values."""
+        # Both masks are computed from the untouched input before any write, so a
+        # value written for inliers can never be re-classified as an outlier
+        # (e.g. when upper_limit <= inlier_value). NaNs match neither mask.
+        is_outlier = x_test >= self.upper_limit
+        is_inlier = x_test < self.upper_limit
+        labelled = x_test.copy()
+        labelled[is_inlier] = inlier_value
+        labelled[is_outlier] = outlier_value
+        return labelled
+
+    def fit(self, _: npt.NDArray[float]) -> Self:
+        """Does not do anything. Only for API compatibility"""
+        return self
+
+    def predict(self, x_test: npt.NDArray[float]) -> npt.NDArray[float]:
+        """
+        Returns an array of same shape as input.
+        1 denotes anomaly.
+        """
+        return self._assign(x_test, 0.0, 1.0)
+
+    def score_samples(self, x_test: npt.NDArray[float]) -> npt.NDArray[float]:
+        """
+        Returns an array of same shape as input
+        with values being anomaly scores.
+        """
+        return self._assign(x_test, self.inlier_score, self.outlier_score)
diff --git a/tests/models/test_threshold.py b/tests/models/test_threshold.py
index 8a87cc11..9bc8f693 100644
--- a/tests/models/test_threshold.py
+++ b/tests/models/test_threshold.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from numalogic.models.threshold import StdDevThreshold
+from numalogic.models.threshold import StdDevThreshold, StaticThreshold
 
 
 class TestStdDevThreshold(unittest.TestCase):
@@ -23,5 +23,39 @@ def test_estimator_score(self):
         self.assertAlmostEqual(0.93317, np.mean(score), places=2)
 
 
+class TestStaticThreshold(unittest.TestCase):
+    def setUp(self) -> None:
+        self.x = np.arange(20).reshape(10, 2).astype(float)
+
+    def test_predict(self):
+        clf = StaticThreshold(upper_limit=5)
+        clf.fit(self.x)
+        y = clf.predict(self.x)
+        self.assertTupleEqual(self.x.shape, y.shape)
+        self.assertEqual(np.max(y), 1)
+        self.assertEqual(np.min(y), 0)
+
+    def test_score(self):
+        clf = StaticThreshold(upper_limit=5.0)
+        y = clf.score_samples(self.x)
+        self.assertTupleEqual(self.x.shape, y.shape)
+        self.assertEqual(np.max(y), clf.outlier_score)
+        self.assertEqual(np.min(y), clf.inlier_score)
+
+    def test_predict_low_limit(self):
+        clf = StaticThreshold(upper_limit=-1.0)
+        x = np.array([-5.0, -1.0, 3.0])
+        np.testing.assert_array_equal(clf.predict(x), [0.0, 1.0, 1.0])
+
+    def test_score_limit_below_inlier_score(self):
+        clf = StaticThreshold(upper_limit=0.3)
+        x = np.array([0.1, 0.3, 2.0])
+        np.testing.assert_array_equal(clf.score_samples(x), [0.5, 10.0, 10.0])
+
+    def test_invalid_scores(self):
+        with self.assertRaises(ValueError):
+            StaticThreshold(upper_limit=5.0, outlier_score=0.5, inlier_score=1.0)
+
+
 if __name__ == "__main__":
     unittest.main()