From 6743faff4d1b3280fc45dd8918ce35f4966c28d5 Mon Sep 17 00:00:00 2001 From: voorhs Date: Tue, 19 Nov 2024 16:05:19 +0300 Subject: [PATCH] add tests for AdaptivePredictor --- .../datafiles/default-multilabel-config.yaml | 5 ++-- autointent/modules/prediction/adaptive.py | 18 ++++++------ tests/assets/configs/multilabel.yaml | 3 +- tests/modules/prediction/test_adaptive.py | 29 +++++++++++++++++++ tests/nodes/test_predicton.py | 1 + 5 files changed, 44 insertions(+), 12 deletions(-) create mode 100644 tests/modules/prediction/test_adaptive.py diff --git a/autointent/datafiles/default-multilabel-config.yaml b/autointent/datafiles/default-multilabel-config.yaml index 941dd732..c76822cc 100644 --- a/autointent/datafiles/default-multilabel-config.yaml +++ b/autointent/datafiles/default-multilabel-config.yaml @@ -5,7 +5,7 @@ nodes: search_space: - module_type: vector_db k: [10] - model_name: + embedder_name: - deepvk/USER-bge-m3 - node_type: scoring metric: scoring_roc_auc @@ -18,4 +18,5 @@ nodes: metric: prediction_accuracy search_space: - module_type: threshold - thresh: [0.5] \ No newline at end of file + thresh: [0.5] + - module_type: adaptive \ No newline at end of file diff --git a/autointent/modules/prediction/adaptive.py b/autointent/modules/prediction/adaptive.py index dc1e6674..2f1bb761 100644 --- a/autointent/modules/prediction/adaptive.py +++ b/autointent/modules/prediction/adaptive.py @@ -20,8 +20,8 @@ class AdaptivePredictorDumpMetadata(TypedDict): r: float - multilabel: bool tags: list[Tag] | None + n_classes: int class AdaptivePredictor(PredictionModule): @@ -29,7 +29,7 @@ class AdaptivePredictor(PredictionModule): n_classes: int _r: float tags: list[Tag] | None - name = "adapt" + name = "adaptive" def __init__(self, search_space: list[float] | None = None) -> None: self.search_space = search_space if search_space is not None else default_search_space @@ -53,7 +53,7 @@ def fit( consider using other predictor algorithms""" raise WrongClassificationError(msg) self.n_classes = ( - len(labels[0]) if self.multilabel and isinstance(labels[0], list) else len(set(labels).difference([-1])) + len(labels[0]) if multilabel and isinstance(labels[0], list) else len(set(labels).difference([-1])) ) metrics_list = [] @@ -73,7 +73,7 @@ def predict(self, scores: npt.NDArray[Any]) -> npt.NDArray[Any]: def dump(self, path: str) -> None: dump_dir = Path(path) - metadata = AdaptivePredictorDumpMetadata(r=self._r, multilabel=self.multilabel, tags=self.tags) + metadata = AdaptivePredictorDumpMetadata(r=self._r, tags=self.tags, n_classes=self.n_classes) with (dump_dir / self.metadata_dict_name).open("w") as file: json.dump(metadata, file, indent=4) @@ -85,18 +85,18 @@ def load(self, path: str) -> None: metadata: AdaptivePredictorDumpMetadata = json.load(file) self._r = metadata["r"] - self.multilabel = metadata["multilabel"] + self.n_classes = metadata["n_classes"] self.tags = [Tag(**tag) for tag in metadata["tags"] if metadata["tags"] and isinstance(metadata["tags"], list)] # type: ignore[arg-type, union-attr] self.metadata = metadata -def _find_threshes(r: float, scores: npt.NDArray[Any]) -> npt.NDArray[Any]: +def get_adapted_threshes(r: float, scores: npt.NDArray[Any]) -> npt.NDArray[Any]: return r * np.max(scores, axis=1) + (1 - r) * np.min(scores, axis=1) # type: ignore[no-any-return] def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | None) -> npt.NDArray[Any]: - thresh = _find_threshes(r, scores) - res = (scores >= thresh[None, :]).astype(int) # suspicious + thresh = get_adapted_threshes(r, scores) + res = (scores >= thresh[:, None]).astype(int) # suspicious if tags: res = apply_tags(res, scores, tags) return res @@ -105,4 +105,4 @@ def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | Non def multilabel_score(y_true: list[LabelType], y_pred: npt.NDArray[Any]) -> float: y_true_, y_pred_ = transform(y_true, y_pred) - return f1_score(y_pred, y_true, average="weighted") # type: ignore[no-any-return] + return f1_score(y_pred_, y_true_, average="weighted") # type: ignore[no-any-return] diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index e9d439da..59937073 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -21,4 +21,5 @@ nodes: search_space: - module_type: threshold thresh: [0.5, [0.5, 0.5, 0.5]] - - module_type: tunable \ No newline at end of file + - module_type: tunable + - module_type: adaptive \ No newline at end of file diff --git a/tests/modules/prediction/test_adaptive.py b/tests/modules/prediction/test_adaptive.py new file mode 100644 index 00000000..551ab01f --- /dev/null +++ b/tests/modules/prediction/test_adaptive.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest + +from autointent.modules import AdaptivePredictor +from autointent.modules.prediction.utils import InvalidNumClassesError, WrongClassificationError + + +def test_multilabel(multilabel_fit_data): + predictor = AdaptivePredictor() + predictor.fit(*multilabel_fit_data) + scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]]) + predictions = predictor.predict(scores) + desired = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]]) + + np.testing.assert_array_equal(predictions, desired) + + +def test_fails_on_wrong_n_classes_predict(multilabel_fit_data): + predictor = AdaptivePredictor() + predictor.fit(*multilabel_fit_data) + scores = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]) + with pytest.raises(InvalidNumClassesError): + predictor.predict(scores) + + +def test_fails_on_wrong_clf_problem(multiclass_fit_data): + predictor = AdaptivePredictor() + with pytest.raises(WrongClassificationError): + predictor.fit(*multiclass_fit_data) diff --git a/tests/nodes/test_predicton.py b/tests/nodes/test_predicton.py index 9c675434..41301c0d 100644 --- a/tests/nodes/test_predicton.py +++ b/tests/nodes/test_predicton.py @@ -58,6 +58,7 @@ def test_prediction_multilabel(scoring_optimizer_multilabel): "search_space": [ {"module_type": "threshold", "thresh": [0.5]}, {"module_type": "tunable", "n_trials": [None, 3]}, + {"module_type": "adaptive"} ], }