diff --git a/autointent/modules/scoring/mlknn/mlknn.py b/autointent/modules/scoring/mlknn/mlknn.py index 3133dc2..7939440 100644 --- a/autointent/modules/scoring/mlknn/mlknn.py +++ b/autointent/modules/scoring/mlknn/mlknn.py @@ -50,7 +50,7 @@ def _compute_cond( idx_helper = np.arange(self._n_classes) deltas_idx = deltas[idx_helper] c[idx_helper, deltas_idx] += y[i] - cn[idx_helper, deltas_idx] += (1 - y[i]) + cn[idx_helper, deltas_idx] += 1 - y[i] c_sum = c.sum(axis=1) cn_sum = cn.sum(axis=1) @@ -82,13 +82,12 @@ def _get_neighbors( [self._converter(candidates[self.ignore_first_neighbours :]) for candidates in query_res["metadatas"]] ) - def predict_labels(self, utterances: list[str]) -> NDArray[np.int64]: + def predict_labels(self, utterances: list[str], thresh: float = 0.5) -> NDArray[np.int64]: probas = self.predict(utterances) - thresh = 0.5 return (probas > thresh).astype(int) def predict(self, utterances: list[str]) -> NDArray[np.float64]: - result = np.zeros((len(utterances), self._n_classes), dtype=int) + result = np.zeros((len(utterances), self._n_classes), dtype=float) neighbors_labels = self._get_neighbors(texts=utterances) for instance in range(neighbors_labels.shape[0]): diff --git a/tests/test_modules/test_scoring/test_mlknn.py b/tests/test_modules/test_scoring/test_mlknn.py index e5a2fcf..c18858a 100644 --- a/tests/test_modules/test_scoring/test_mlknn.py +++ b/tests/test_modules/test_scoring/test_mlknn.py @@ -12,7 +12,7 @@ def test_base_mlknn(): run_name = get_run_name("multiclass-cpu") db_dir = get_db_dir("", run_name) - data = load_data("tests/minimal-optimization/data/clinc_subset.json", multilabel=False) + data = load_data("../../minimal-optimization/data/clinc_subset.json", multilabel=False) utterance = [ { "utterance": "why is there a hold on my american saving bank account", @@ -41,16 +41,14 @@ def test_base_mlknn(): context.optimization_logs.cache["best_assets"]["retrieval"] = "sergeyzh/rubert-tiny-turbo" scorer.fit(context) - np.testing.assert_almost_equal(0.75, scorer.score(context, scoring_f1)) + np.testing.assert_almost_equal(0.6663752913752914, scorer.score(context, scoring_f1)) predictions = scorer.predict_labels( - np.array( - [ - "why is there a hold on my american saving bank account", - "i am nost sure why my account is blocked", - "why is there a hold on my capital one checking account", - "i think my account is blocked but i do not know the reason", - "can you tell me why is my bank account frozen", - ] - ) + [ + "why is there a hold on my american saving bank account", + "i am nost sure why my account is blocked", + "why is there a hold on my capital one checking account", + "i think my account is blocked but i do not know the reason", + "can you tell me why is my bank account frozen", + ] ) - assert (predictions == np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])).all() + assert (predictions == np.array([[0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]])).all()