Skip to content

Commit

Permalink
add tests for AdaptivePredictor
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Nov 19, 2024
1 parent ad401b3 commit 6743faf
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 12 deletions.
5 changes: 3 additions & 2 deletions autointent/datafiles/default-multilabel-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ nodes:
search_space:
- module_type: vector_db
k: [10]
model_name:
embedder_name:
- deepvk/USER-bge-m3
- node_type: scoring
metric: scoring_roc_auc
Expand All @@ -18,4 +18,5 @@ nodes:
metric: prediction_accuracy
search_space:
- module_type: threshold
thresh: [0.5]
thresh: [0.5]
- module_type: adaptive
18 changes: 9 additions & 9 deletions autointent/modules/prediction/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@

class AdaptivePredictorDumpMetadata(TypedDict):
    """Schema of the JSON metadata that ``AdaptivePredictor.dump`` writes and ``load`` reads back.

    The former ``multilabel`` flag is no longer part of the schema; the
    number of classes is persisted instead.
    """

    # Mixing coefficient stored by the predictor as ``_r``.
    r: float
    # Tags handed to the predictor, or None when there are none.
    tags: list[Tag] | None
    # Number of classes stored by the predictor as ``n_classes``.
    n_classes: int


class AdaptivePredictor(PredictionModule):
metadata_dict_name = "metadata.json"
n_classes: int
_r: float
tags: list[Tag] | None
name = "adapt"
name = "adaptive"

def __init__(self, search_space: list[float] | None = None) -> None:
self.search_space = search_space if search_space is not None else default_search_space
Expand All @@ -53,7 +53,7 @@ def fit(
consider using other predictor algorithms"""
raise WrongClassificationError(msg)
self.n_classes = (
len(labels[0]) if self.multilabel and isinstance(labels[0], list) else len(set(labels).difference([-1]))
len(labels[0]) if multilabel and isinstance(labels[0], list) else len(set(labels).difference([-1]))
)

metrics_list = []
Expand All @@ -73,7 +73,7 @@ def predict(self, scores: npt.NDArray[Any]) -> npt.NDArray[Any]:
def dump(self, path: str) -> None:
    """Serialize the predictor's state as JSON inside the directory ``path``.

    The file name is taken from ``self.metadata_dict_name``. The payload
    holds ``r``, ``tags`` and ``n_classes``.

    :param path: directory the metadata file is written into.
    """
    dump_dir = Path(path)

    # Only the current schema is written; the old payload read
    # ``self.multilabel``, which is no longer stored on the instance.
    metadata = AdaptivePredictorDumpMetadata(r=self._r, tags=self.tags, n_classes=self.n_classes)

    with (dump_dir / self.metadata_dict_name).open("w") as file:
        json.dump(metadata, file, indent=4)
Expand All @@ -85,18 +85,18 @@ def load(self, path: str) -> None:
metadata: AdaptivePredictorDumpMetadata = json.load(file)

self._r = metadata["r"]
self.multilabel = metadata["multilabel"]
self.n_classes = metadata["n_classes"]
self.tags = [Tag(**tag) for tag in metadata["tags"] if metadata["tags"] and isinstance(metadata["tags"], list)] # type: ignore[arg-type, union-attr]
self.metadata = metadata


def get_adapted_threshes(r: float, scores: npt.NDArray[Any]) -> npt.NDArray[Any]:
    """Compute one threshold per row of ``scores``.

    Each threshold is ``r * max + (1 - r) * min`` of that row, so ``r = 1``
    yields the row maximum and ``r = 0`` the row minimum.

    :param r: mixing coefficient between the row maximum and row minimum.
    :param scores: 2-D array; the reduction runs over axis 1.
    :return: 1-D array with one threshold per row of ``scores``.
    """
    return r * np.max(scores, axis=1) + (1 - r) * np.min(scores, axis=1)  # type: ignore[no-any-return]


# Backward-compatible alias for the previous private name of this helper.
_find_threshes = get_adapted_threshes


def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | None) -> npt.NDArray[Any]:
    """Binarize ``scores`` against per-row adapted thresholds.

    :param scores: 2-D score array, one row per sample.
    :param r: mixing coefficient forwarded to ``get_adapted_threshes``.
    :param tags: optional tags; when truthy, the binarized result is
        post-processed with ``apply_tags``.
    :return: integer array of 0/1 with the same shape as ``scores``.
    """
    thresh = get_adapted_threshes(r, scores)
    # ``thresh`` holds one value per row, so it needs a trailing axis to be
    # compared row-wise. ``thresh[None, :]`` would broadcast it along the
    # columns instead, comparing each score to another sample's threshold.
    res = (scores >= thresh[:, None]).astype(int)
    if tags:
        res = apply_tags(res, scores, tags)
    return res
Expand All @@ -105,4 +105,4 @@ def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | Non
def multilabel_score(y_true: list[LabelType], y_pred: npt.NDArray[Any]) -> float:
    """Return the weighted F1 score of ``y_pred`` against ``y_true``.

    Both inputs are first passed through ``transform``; the score is
    computed on the transformed values, not on the raw arguments.

    :param y_true: ground-truth labels.
    :param y_pred: predicted labels.
    :return: F1 score with ``average="weighted"``.
    """
    y_true_, y_pred_ = transform(y_true, y_pred)

    # scikit-learn expects (y_true, y_pred). With average="weighted" the
    # per-label weights come from the first argument's support, so the
    # ground truth must go first.
    return f1_score(y_true_, y_pred_, average="weighted")  # type: ignore[no-any-return]
3 changes: 2 additions & 1 deletion tests/assets/configs/multilabel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ nodes:
search_space:
- module_type: threshold
thresh: [0.5, [0.5, 0.5, 0.5]]
- module_type: tunable
- module_type: tunable
- module_type: adaptive
29 changes: 29 additions & 0 deletions tests/modules/prediction/test_adaptive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np
import pytest

from autointent.modules import AdaptivePredictor
from autointent.modules.prediction.utils import InvalidNumClassesError, WrongClassificationError


def test_multilabel(multilabel_fit_data):
    """A predictor fitted on the multilabel fixture yields the expected 0/1 matrix."""
    adaptive = AdaptivePredictor()
    adaptive.fit(*multilabel_fit_data)

    raw_scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]])
    expected = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]])

    np.testing.assert_array_equal(adaptive.predict(raw_scores), expected)


def test_fails_on_wrong_n_classes_predict(multilabel_fit_data):
    """``predict`` raises ``InvalidNumClassesError`` for a two-column score matrix.

    Presumably the fixture is fitted with a different number of classes;
    verify against the fixture definition.
    """
    adaptive = AdaptivePredictor()
    adaptive.fit(*multilabel_fit_data)

    two_column_scores = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])

    with pytest.raises(InvalidNumClassesError):
        adaptive.predict(two_column_scores)


def test_fails_on_wrong_clf_problem(multiclass_fit_data):
    """Fitting on the multiclass fixture raises ``WrongClassificationError``."""
    adaptive = AdaptivePredictor()

    with pytest.raises(WrongClassificationError):
        adaptive.fit(*multiclass_fit_data)
1 change: 1 addition & 0 deletions tests/nodes/test_predicton.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_prediction_multilabel(scoring_optimizer_multilabel):
"search_space": [
{"module_type": "threshold", "thresh": [0.5]},
{"module_type": "tunable", "n_trials": [None, 3]},
{"module_type": "adaptive"}
],
}

Expand Down

0 comments on commit 6743faf

Please sign in to comment.