Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
Samoed committed Oct 1, 2024
1 parent 43d1b80 commit db16df6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
7 changes: 3 additions & 4 deletions autointent/modules/scoring/mlknn/mlknn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,7 +50,7 @@ def _compute_cond(
idx_helper = np.arange(self._n_classes)
deltas_idx = deltas[idx_helper]
c[idx_helper, deltas_idx] += y[i]
cn[idx_helper, deltas_idx] += (1 - y[i])
cn[idx_helper, deltas_idx] += 1 - y[i]

c_sum = c.sum(axis=1)
cn_sum = cn.sum(axis=1)
Expand Down Expand Up @@ -82,13 +82,12 @@ def _get_neighbors(
[self._converter(candidates[self.ignore_first_neighbours :]) for candidates in query_res["metadatas"]]
)

def predict_labels(self, utterances: list[str]) -> NDArray[np.int64]:
def predict_labels(self, utterances: list[str], thresh: float = 0.5) -> NDArray[np.int64]:
probas = self.predict(utterances)
thresh = 0.5
return (probas > thresh).astype(int)

def predict(self, utterances: list[str]) -> NDArray[np.float64]:
result = np.zeros((len(utterances), self._n_classes), dtype=int)
result = np.zeros((len(utterances), self._n_classes), dtype=float)
neighbors_labels = self._get_neighbors(texts=utterances)

for instance in range(neighbors_labels.shape[0]):
Expand Down
22 changes: 10 additions & 12 deletions tests/test_modules/test_scoring/test_mlknn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_base_mlknn():
run_name = get_run_name("multiclass-cpu")
db_dir = get_db_dir("", run_name)

data = load_data("tests/minimal-optimization/data/clinc_subset.json", multilabel=False)
data = load_data("../../minimal-optimization/data/clinc_subset.json", multilabel=False)
utterance = [
{
"utterance": "why is there a hold on my american saving bank account",
Expand Down Expand Up @@ -41,16 +41,14 @@ def test_base_mlknn():

context.optimization_logs.cache["best_assets"]["retrieval"] = "sergeyzh/rubert-tiny-turbo"
scorer.fit(context)
np.testing.assert_almost_equal(0.75, scorer.score(context, scoring_f1))
np.testing.assert_almost_equal(0.6663752913752914, scorer.score(context, scoring_f1))
predictions = scorer.predict_labels(
np.array(
[
"why is there a hold on my american saving bank account",
"i am nost sure why my account is blocked",
"why is there a hold on my capital one checking account",
"i think my account is blocked but i do not know the reason",
"can you tell me why is my bank account frozen",
]
)
[
"why is there a hold on my american saving bank account",
"i am nost sure why my account is blocked",
"why is there a hold on my capital one checking account",
"i think my account is blocked but i do not know the reason",
"can you tell me why is my bank account frozen",
]
)
assert (predictions == np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])).all()
assert (predictions == np.array([[0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]])).all()

0 comments on commit db16df6

Please sign in to comment.