Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Whole-series anomaly detection #2326

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions aeon/anomaly_detection/whole_series/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Whole-series anomaly detection methods."""

__all__ = [
"BaseCollectionAnomalyDetector",
"OutlierDetectionClassifier",
]

from aeon.anomaly_detection.whole_series._outlier_detection import (
OutlierDetectionClassifier,
)
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
36 changes: 36 additions & 0 deletions aeon/anomaly_detection/whole_series/_outlier_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Basic outlier detection classifier."""

from sklearn.ensemble import IsolationForest

from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
from aeon.base._base import _clone_estimator


class OutlierDetectionClassifier(BaseCollectionAnomalyDetector):
"""Basic outlier detection classifier."""

_tags = {
"X_inner_type": "numpy2D",
}

def __init__(self, estimator, random_state=None):
self.estimator = estimator
self.random_state = random_state

super().__init__()

def _fit(self, X, y=None):
self.estimator_ = _clone_estimator(
self.estimator, random_state=self.random_state
)
self.estimator_.fit(X, y)
return self

def _predict(self, X):
pred = self.estimator_.predict(X)
pred[pred == -1] = 0
return pred

@classmethod
def _get_test_params(cls, parameter_set="default"):
return {"estimator": IsolationForest(n_estimators=3)}
92 changes: 92 additions & 0 deletions aeon/anomaly_detection/whole_series/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Abstract base class for whole-series/collection anomaly detectors."""

__maintainer__ = ["MatthewMiddlehurst"]
__all__ = ["BaseCollectionAnomalyDetector"]

from abc import abstractmethod
from typing import final

import numpy as np
import pandas as pd

from aeon.base import BaseCollectionEstimator


class BaseCollectionAnomalyDetector(BaseCollectionEstimator):
"""Collection anomaly detector base class."""

_tags = {
"fit_is_empty": False,
"requires_y": False,
}

def __init__(self):
super().__init__()

@final
def fit(self, X, y=None):
"""Fit."""
if self.get_tag("fit_is_empty"):
self.is_fitted = True
return self

if self.get_tag("requires_y"):
if y is None:
raise ValueError("Tag requires_y is true, but fit called with y=None")

# reset estimator at the start of fit
self.reset()

X = self._preprocess_collection(X)
if y is not None:
y = self._check_y(y, self.metadata_["n_cases"])

self._fit(X, y)

# this should happen last
self.is_fitted = True
return self

@final
def predict(self, X):
"""Predict."""
fit_empty = self.get_tag("fit_is_empty")
if not fit_empty:
self._check_is_fitted()

X = self._preprocess_collection(X, store_metadata=False)
# Check if X has the correct shape seen during fitting
self._check_shape(X)

return self._predict(X)

@abstractmethod
def _fit(self, X, y=None): ...

@abstractmethod
def _predict(self, X): ...

def _check_y(self, y, n_cases):
if not isinstance(y, (pd.Series, np.ndarray)):
raise TypeError(
f"y must be a np.array or a pd.Series, but found type: {type(y)}"
)
if isinstance(y, np.ndarray) and y.ndim > 1:
raise TypeError(f"y must be 1-dimensional, found {y.ndim} dimensions")

if not all([x == 0 or x == 1 for x in y]):
raise ValueError(
"y input must only contain 0 (not anomalous) or 1 (anomalous) values."
)

# Check matching number of labels
n_labels = y.shape[0]
if n_cases != n_labels:
raise ValueError(
f"Mismatch in number of cases. Found X = {n_cases} and y = {n_labels}"
)

if isinstance(y, pd.Series):
y = pd.Series.to_numpy(y)

return y
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Tests for all collection anomaly detectors."""


def _yield_collection_anomaly_detection_checks(
estimator_class, estimator_instances, datatypes
):
"""Yield all collection anomaly detection checks for an aeon estimator."""
# nothing currently!
return []
9 changes: 9 additions & 0 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.utils.estimator_checks import check_get_params_invariance

from aeon.anomaly_detection.base import BaseAnomalyDetector
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
from aeon.base import BaseAeonEstimator
from aeon.base._base import _clone_estimator
from aeon.classification import BaseClassifier
Expand All @@ -34,6 +35,9 @@
from aeon.testing.estimator_checking._yield_clustering_checks import (
_yield_clustering_checks,
)
from aeon.testing.estimator_checking._yield_collection_anomaly_detection_checks import (
_yield_collection_anomaly_detection_checks,
)
from aeon.testing.estimator_checking._yield_collection_transformation_checks import (
_yield_collection_transformation_checks,
)
Expand Down Expand Up @@ -152,6 +156,11 @@ def _yield_all_aeon_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseCollectionAnomalyDetector):
yield from _yield_collection_anomaly_detection_checks(
estimator_class, estimator_instances, datatypes
)

if issubclass(estimator_class, BaseSimilaritySearch):
yield from _yield_similarity_search_checks(
estimator_class, estimator_instances, datatypes
Expand Down
2 changes: 2 additions & 0 deletions aeon/testing/testing_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from aeon.anomaly_detection.base import BaseAnomalyDetector
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
from aeon.base import BaseCollectionEstimator, BaseSeriesEstimator
from aeon.classification import BaseClassifier
from aeon.classification.early_classification import BaseEarlyClassifier
Expand Down Expand Up @@ -821,6 +822,7 @@ def _get_label_type_for_estimator(estimator):
or isinstance(estimator, BaseClusterer)
or isinstance(estimator, BaseCollectionTransformer)
or isinstance(estimator, BaseSimilaritySearch)
or isinstance(estimator, BaseCollectionAnomalyDetector)
):
label_type = "Classification"
elif isinstance(estimator, BaseRegressor):
Expand Down
2 changes: 2 additions & 0 deletions aeon/utils/base/_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


from aeon.anomaly_detection.base import BaseAnomalyDetector
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
from aeon.base import BaseAeonEstimator, BaseCollectionEstimator, BaseSeriesEstimator
from aeon.classification.base import BaseClassifier
from aeon.classification.early_classification import BaseEarlyClassifier
Expand All @@ -37,6 +38,7 @@
"transformer": BaseTransformer,
# estimator types
"anomaly-detector": BaseAnomalyDetector,
"collection_anomaly_detector": BaseCollectionAnomalyDetector,
"collection-transformer": BaseCollectionTransformer,
"classifier": BaseClassifier,
"clusterer": BaseClusterer,
Expand Down
7 changes: 6 additions & 1 deletion aeon/utils/tags/_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ class : identifier for the base class of objects this tag applies to
"point belongs to.",
},
"requires_y": {
"class": ["transformer", "anomaly-detector", "segmenter"],
"class": [
"transformer",
"anomaly-detector",
"collection_anomaly_detector",
"segmenter",
],
"type": "bool",
"description": "Does this estimator require y to be passed in its methods?",
},
Expand Down