From de779c089a406912ab4f6ae9bb414e613480a800 Mon Sep 17 00:00:00 2001 From: truff4ut <43514910+truff4ut@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:58:25 +0300 Subject: [PATCH] Add enhanced output (#38) Co-authored-by: voorhs Co-authored-by: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Co-authored-by: Darinka <39233990+Darinochka@users.noreply.github.com> --- autointent/configs/inference_cli.py | 1 + autointent/context/embedder.py | 5 +- autointent/modules/base.py | 7 + autointent/modules/regexp.py | 108 +++++-- autointent/modules/retrieval/vectordb.py | 10 +- .../scoring/description/description.py | 8 +- autointent/modules/scoring/dnnc/dnnc.py | 47 ++- autointent/modules/scoring/knn/knn.py | 26 +- autointent/modules/scoring/mlknn/mlknn.py | 67 ++-- autointent/nodes/__init__.py | 12 +- autointent/nodes/nodes_info/__init__.py | 11 +- autointent/nodes/nodes_info/regexp.py | 2 +- .../nodes/optimization/node_optimizer.py | 2 + autointent/pipeline/inference/cli_endpoint.py | 17 +- .../pipeline/inference/inference_pipeline.py | 45 +++ .../predict_with_metadata/testbed.ipynb | 301 ++++++++++++++++++ tests/conftest.py | 7 +- tests/modules/retrieval/test_vectordb.py | 6 +- tests/modules/scoring/test_description.py | 6 +- tests/modules/scoring/test_dnnc.py | 18 +- tests/modules/scoring/test_knn.py | 24 +- tests/modules/scoring/test_linear.py | 6 +- tests/modules/scoring/test_mlknn.py | 24 +- tests/modules/test_regex.py | 127 ++------ tests/nodes/conftest.py | 26 +- tests/nodes/test_retrieval.py | 10 +- tests/nodes/test_scoring.py | 10 +- tests/pipeline/test_inference.py | 15 +- tests/pipeline/test_optimization.py | 1 + 29 files changed, 694 insertions(+), 255 deletions(-) create mode 100644 experiments/predict_with_metadata/testbed.ipynb diff --git a/autointent/configs/inference_cli.py b/autointent/configs/inference_cli.py index b0d8a6d4..992759eb 100644 --- a/autointent/configs/inference_cli.py +++ b/autointent/configs/inference_cli.py @@ -11,6 +11,7 @@ class InferenceConfig: source_dir: str output_path: str log_level: LogLevel = LogLevel.ERROR + with_metadata: bool = False cs = ConfigStore.instance() diff --git a/autointent/context/embedder.py b/autointent/context/embedder.py index e7d5b70b..c0ccae5a 100644 --- a/autointent/context/embedder.py +++ b/autointent/context/embedder.py @@ -81,5 +81,8 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]: if self.max_length is not None: self.embedding_model.max_seq_length = self.max_length return self.embedding_model.encode( - utterances, convert_to_numpy=True, batch_size=self.batch_size, normalize_embeddings=True, + utterances, + convert_to_numpy=True, + batch_size=self.batch_size, + normalize_embeddings=True, ) # type: ignore[return-value] diff --git a/autointent/modules/base.py b/autointent/modules/base.py index e89ef49c..6706828c 100644 --- a/autointent/modules/base.py +++ b/autointent/modules/base.py @@ -48,6 +48,13 @@ def load(self, path: str) -> None: def predict(self, *args: list[str] | npt.NDArray[Any], **kwargs: dict[str, Any]) -> npt.NDArray[Any]: """inference""" + def predict_with_metadata( + self, + *args: list[str] | npt.NDArray[Any], + **kwargs: dict[str, Any], + ) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]: + return self.predict(*args, **kwargs), None + @classmethod @abstractmethod def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> Self: diff --git a/autointent/modules/regexp.py b/autointent/modules/regexp.py index 2c71a81c..d5567634 100644 --- 
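The `predict_with_metadata` default added to the module base class gives every module a uniform rich-output entry point without forcing overrides: modules with nothing extra to report return their plain predictions paired with None. A minimal self-contained sketch of that contract (ToyModule is hypothetical; only the method shape mirrors the patch):

from typing import Any

import numpy as np
import numpy.typing as npt


class ToyModule:
    """Hypothetical stand-in for autointent.modules.base.Module."""

    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
        # pretend scores over 3 classes
        return np.zeros((len(utterances), 3))

    # the default this patch adds to the base class: no metadata -> None
    def predict_with_metadata(
        self,
        utterances: list[str],
    ) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
        return self.predict(utterances), None


scores, metadata = ToyModule().predict_with_metadata(["hello"])
assert scores.shape == (1, 3) and metadata is None  # e.g. LinearScorer keeps this default
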
a/autointent/modules/regexp.py +++ b/autointent/modules/regexp.py @@ -1,10 +1,13 @@ +import json import re -from typing import TypedDict +from pathlib import Path +from typing import Any, TypedDict from typing_extensions import Self from autointent import Context from autointent.context.data_handler.data_handler import RegexPatterns +from autointent.context.data_handler.schemas import Intent from autointent.context.optimization_info.data_models import Artifact from autointent.custom_types import LabelType from autointent.metrics.regexp import RegexpMetricFn @@ -19,43 +22,60 @@ class RegexPatternsCompiled(TypedDict): class RegExp(Module): - name = "regexp" - - def __init__(self, regexp_patterns: list[RegexPatterns]) -> None: - self.regexp_patterns = regexp_patterns - @classmethod def from_context(cls, context: Context) -> Self: - return cls( - regexp_patterns=context.data_handler.regexp_patterns, - ) - - def fit(self, utterances: list[str], labels: list[LabelType]) -> None: - self.regexp_patterns_compiled: list[RegexPatternsCompiled] = [ - { - "id": dct["id"], - "regexp_full_match": [re.compile(ptn, flags=re.IGNORECASE) for ptn in dct["regexp_full_match"]], - "regexp_partial_match": [re.compile(ptn, flags=re.IGNORECASE) for ptn in dct["regexp_partial_match"]], - } - for dct in self.regexp_patterns + return cls() + + def fit(self, intents: list[dict[str, Any]]) -> None: + intents_parsed = [Intent(**dct) for dct in intents] + self.regexp_patterns = [ + RegexPatterns( + id=intent.id, + regexp_full_match=intent.regexp_full_match, + regexp_partial_match=intent.regexp_partial_match, + ) + for intent in intents_parsed ] + self._compile_regex_patterns() def predict(self, utterances: list[str]) -> list[LabelType]: - return [list(self._predict_single(ut)) for ut in utterances] - - def _match(self, text: str, intent_record: RegexPatternsCompiled) -> bool: - full_match = any(ptn.fullmatch(text) for ptn in intent_record["regexp_full_match"]) - if full_match: - return True - return any(ptn.match(text) for ptn in intent_record["regexp_partial_match"]) + return [self._predict_single(utterance)[0] for utterance in utterances] + + def predict_with_metadata( + self, + utterances: list[str], + ) -> tuple[list[LabelType], list[dict[str, Any]] | None]: + predictions, metadata = [], [] + for utterance in utterances: + prediction, matches = self._predict_single(utterance) + predictions.append(prediction) + metadata.append(matches) + return predictions, metadata + + def _match(self, utterance: str, intent_record: RegexPatternsCompiled) -> dict[str, list[str]]: + full_matches = [ + pattern.pattern + for pattern in intent_record["regexp_full_match"] + if pattern.fullmatch(utterance) is not None + ] + partial_matches = [ + pattern.pattern + for pattern in intent_record["regexp_partial_match"] + if pattern.search(utterance) is not None + ] + return {"full_matches": full_matches, "partial_matches": partial_matches} - def _predict_single(self, utterance: str) -> set[int]: + def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str]]]: # todo test this - return { - intent_record["id"] - for intent_record in self.regexp_patterns_compiled - if self._match(utterance, intent_record) - } + prediction = set() + matches: dict[str, list[str]] = {"full_matches": [], "partial_matches": []} + for intent_record in self.regexp_patterns_compiled: + intent_matches = self._match(utterance, intent_record) + if intent_matches["full_matches"] or intent_matches["partial_matches"]: + prediction.add(intent_record["id"]) 
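Note that `_match` now reports which patterns fired instead of a bare boolean, and partial matching switches from the prefix-anchored `re.Pattern.match` to `re.Pattern.search`, so a partial pattern can hit anywhere in the utterance. A self-contained sketch of those semantics (the patterns here are illustrative, not from the dataset):

import re

full = [re.compile(r"hello( world)?", flags=re.IGNORECASE)]
partial = [re.compile(r"frozen", flags=re.IGNORECASE)]

utterance = "why is my bank account frozen"
matches = {
    "full_matches": [p.pattern for p in full if p.fullmatch(utterance) is not None],
    "partial_matches": [p.pattern for p in partial if p.search(utterance) is not None],
}
print(matches)  # {'full_matches': [], 'partial_matches': ['frozen']}
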
+ matches["full_matches"].extend(intent_matches["full_matches"]) + matches["partial_matches"].extend(intent_matches["partial_matches"]) + return list(prediction), matches def score(self, context: Context, metric_fn: RegexpMetricFn) -> float: # TODO add parameter to a whole pipeline (or just to regexp module): @@ -78,7 +98,29 @@ def get_assets(self) -> Artifact: return Artifact() def load(self, path: str) -> None: - pass + dump_dir = Path(path) + + with (dump_dir / self.metadata_dict_name).open() as file: + self.regexp_patterns = json.load(file) + + self._compile_regex_patterns() def dump(self, path: str) -> None: - pass + dump_dir = Path(path) + + with (dump_dir / self.metadata_dict_name).open("w") as file: + json.dump(self.regexp_patterns, file, indent=4) + + def _compile_regex_patterns(self) -> None: + self.regexp_patterns_compiled = [ + RegexPatternsCompiled( + id=regexp_patterns["id"], + regexp_full_match=[ + re.compile(pattern, flags=re.IGNORECASE) for pattern in regexp_patterns["regexp_full_match"] + ], + regexp_partial_match=[ + re.compile(ptn, flags=re.IGNORECASE) for ptn in regexp_patterns["regexp_partial_match"] + ], + ) + for regexp_patterns in self.regexp_patterns + ] diff --git a/autointent/modules/retrieval/vectordb.py b/autointent/modules/retrieval/vectordb.py index 2802409f..0fe0a22a 100644 --- a/autointent/modules/retrieval/vectordb.py +++ b/autointent/modules/retrieval/vectordb.py @@ -32,11 +32,9 @@ def __init__( batch_size: int = 32, max_length: int | None = None, ) -> None: - if db_dir is None: - db_dir = str(get_db_dir()) self.embedder_name = embedder_name self.device = device - self.db_dir = db_dir + self._db_dir = db_dir self.batch_size = batch_size self.max_length = max_length @@ -58,6 +56,12 @@ def from_context( max_length=context.get_max_length(), ) + @property + def db_dir(self) -> str: + if self._db_dir is None: + self._db_dir = str(get_db_dir()) + return self._db_dir + def fit(self, utterances: list[str], labels: list[LabelType]) -> None: vector_index_client = VectorIndexClient( self.device, self.db_dir, embedder_batch_size=self.batch_size, embedder_max_length=self.max_length diff --git a/autointent/modules/scoring/description/description.py b/autointent/modules/scoring/description/description.py index 37f06924..b56586bd 100644 --- a/autointent/modules/scoring/description/description.py +++ b/autointent/modules/scoring/description/description.py @@ -11,7 +11,6 @@ from autointent.context import Context from autointent.context.embedder import Embedder from autointent.context.vector_index_client import VectorIndex, VectorIndexClient -from autointent.context.vector_index_client.cache import get_db_dir from autointent.custom_types import LabelType from autointent.modules.scoring.base import ScoringModule @@ -30,22 +29,19 @@ class DescriptionScorer(ScoringModule): precomputed_embeddings: bool = False embedding_model_subdir: str = "embedding_model" _vector_index: VectorIndex + db_dir: str name = "description" def __init__( self, embedder_name: str, - db_dir: Path | None = None, temperature: float = 1.0, device: str = "cpu", batch_size: int = 32, max_length: int | None = None, ) -> None: - if db_dir is None: - db_dir = get_db_dir() self.temperature = temperature self.device = device - self.db_dir = db_dir self.embedder_name = embedder_name self.batch_size = batch_size self.max_length = max_length @@ -66,10 +62,10 @@ def from_context( instance = cls( temperature=temperature, device=context.get_device(), - db_dir=context.get_db_dir(), embedder_name=embedder_name, ) 
instance.precomputed_embeddings = precomputed_embeddings + instance.db_dir = str(context.get_db_dir()) return instance def get_embedder_name(self) -> str: diff --git a/autointent/modules/scoring/dnnc/dnnc.py b/autointent/modules/scoring/dnnc/dnnc.py index 806c0c5f..f79c736d 100644 --- a/autointent/modules/scoring/dnnc/dnnc.py +++ b/autointent/modules/scoring/dnnc/dnnc.py @@ -52,18 +52,21 @@ def __init__( batch_size: int = 32, max_length: int | None = None, ) -> None: - if db_dir is None: - db_dir = str(get_db_dir()) - self.cross_encoder_name = cross_encoder_name self.embedder_name = embedder_name self.k = k self.train_head = train_head self.device = device - self.db_dir = db_dir + self._db_dir = db_dir self.batch_size = batch_size self.max_length = max_length + @property + def db_dir(self) -> str: + if self._db_dir is None: + self._db_dir = str(get_db_dir()) + return self._db_dir + @classmethod def from_context( cls, @@ -114,19 +117,18 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None: self.model = model def predict(self, utterances: list[str]) -> npt.NDArray[Any]: - """ - Return - --- - `(n_queries, n_classes)` matrix with zeros everywhere except the class of the best neighbor utterance - """ - labels, _, texts = self.vector_index.query( - utterances, - self.k, - ) + return self._predict(utterances)[0] - cross_encoder_scores = self._get_cross_encoder_scores(utterances, texts) - - return self._build_result(cross_encoder_scores, labels) + def predict_with_metadata( + self, + utterances: list[str], + ) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]: + scores, neighbors, neighbors_scores = self._predict(utterances) + metadata = [ + {"neighbors": utterance_neighbors, "scores": utterance_neighbors_scores} + for utterance_neighbors, utterance_neighbors_scores in zip(neighbors, neighbors_scores, strict=True) + ] + return scores, metadata def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list[str]]) -> list[list[float]]: """ @@ -215,6 +217,19 @@ def load(self, path: str) -> None: else: self.model = CrossEncoder(crossencoder_dir, device=self.device) + def _predict( + self, + utterances: list[str], + ) -> tuple[npt.NDArray[Any], list[list[str]], list[list[float]]]: + labels, _, neigbors = self.vector_index.query( + utterances, + self.k, + ) + + cross_encoder_scores = self._get_cross_encoder_scores(utterances, neigbors) + + return self._build_result(cross_encoder_scores, labels), neigbors, cross_encoder_scores + def build_result(scores: npt.NDArray[Any], labels: npt.NDArray[Any], n_classes: int) -> npt.NDArray[Any]: res = np.zeros((len(scores), n_classes)) diff --git a/autointent/modules/scoring/knn/knn.py b/autointent/modules/scoring/knn/knn.py index 4bce5ba7..81167285 100644 --- a/autointent/modules/scoring/knn/knn.py +++ b/autointent/modules/scoring/knn/knn.py @@ -49,16 +49,20 @@ def __init__( - closest: each sample has a non zero weight iff is the closest sample of some class - `device`: str, something like "cuda:0" or "cuda:0,1,2", a device to store embedding function """ - if db_dir is None: - db_dir = str(get_db_dir()) self.embedder_name = embedder_name self.k = k self.weights = weights - self.db_dir = db_dir + self._db_dir = db_dir self.device = device self.batch_size = batch_size self.max_length = max_length + @property + def db_dir(self) -> str: + if self._db_dir is None: + self._db_dir = str(get_db_dir()) + return self._db_dir + @classmethod def from_context( cls, @@ -107,8 +111,15 @@ def fit(self, utterances: list[str], labels: 
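DNNCScorer's metadata pairs each utterance's retrieved neighbor texts with their cross-encoder scores, with `strict=True` guarding against ragged results from the vector index. A shape sketch with made-up values:

neighbors = [["set my alarm for getting up", "wake me up at noon tomorrow"]]
neighbor_scores = [[0.91, 0.47]]

metadata = [
    {"neighbors": texts, "scores": scores}
    for texts, scores in zip(neighbors, neighbor_scores, strict=True)
]
print(metadata[0])  # one dict per query utterance
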
list[LabelType]) -> None: self._vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels) def predict(self, utterances: list[str]) -> npt.NDArray[Any]: - labels, distances, _ = self._vector_index.query(utterances, self.k) - return apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel) + return self._predict(utterances)[0] + + def predict_with_metadata( + self, + utterances: list[str], + ) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]: + scores, neighbors = self._predict(utterances) + metadata = [{"neighbors": utterance_neighbors} for utterance_neighbors in neighbors] + return scores, metadata def clear_cache(self) -> None: self._vector_index.clear_ram() @@ -145,3 +156,8 @@ def load(self, path: str) -> None: embedder_max_length=self.metadata["max_length"], ) self._vector_index = vector_index_client.get_index(self.embedder_name) + + def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]: + labels, distances, neigbors = self._vector_index.query(utterances, self.k) + scores = apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel) + return scores, neigbors diff --git a/autointent/modules/scoring/mlknn/mlknn.py b/autointent/modules/scoring/mlknn/mlknn.py index b3b903ea..04c7f1cf 100644 --- a/autointent/modules/scoring/mlknn/mlknn.py +++ b/autointent/modules/scoring/mlknn/mlknn.py @@ -50,17 +50,21 @@ def __init__( batch_size: int = 32, max_length: int | None = None, ) -> None: - if db_dir is None: - db_dir = str(get_db_dir()) self.k = k self.embedder_name = embedder_name self.s = s self.ignore_first_neighbours = ignore_first_neighbours - self.db_dir = db_dir + self._db_dir = db_dir self.device = device self.batch_size = batch_size self.max_length = max_length + @property + def db_dir(self) -> str: + if self._db_dir is None: + self._db_dir = str(get_db_dir()) + return self._db_dir + @classmethod def from_context( cls, @@ -128,7 +132,7 @@ def _compute_cond(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]: c = np.zeros((self.n_classes, self.k + 1), dtype=int) cn = np.zeros((self.n_classes, self.k + 1), dtype=int) - neighbors_labels = self._get_neighbors(self.features) + neighbors_labels, _ = self._get_neighbors(self.features) for i in range(self.labels.shape[0]): deltas = np.sum(neighbors_labels[i], axis=0).astype(int) @@ -145,37 +149,33 @@ def _compute_cond(self) -> tuple[NDArray[np.float64], NDArray[np.float64]]: return cond_prob_true, cond_prob_false - def _get_neighbors(self, queries: list[str] | NDArray[Any]) -> NDArray[np.int64]: - """ - retrieve nearest neighbors and return their labels in binary format - - Return - --- - array of shape (n_queries, n_candidates, n_classes) - """ - labels, _, _ = self.vector_index.query( + def _get_neighbors( + self, + queries: list[str] | NDArray[Any], + ) -> tuple[NDArray[np.int64], list[list[str]]]: + labels, _, neighbors = self.vector_index.query( queries, self.k + self.ignore_first_neighbours, ) - return np.array([candidates[self.ignore_first_neighbours :] for candidates in labels]) + return ( + np.array([candidates[self.ignore_first_neighbours :] for candidates in labels]), + neighbors, + ) def predict_labels(self, utterances: list[str], thresh: float = 0.5) -> NDArray[np.int64]: probas = self.predict(utterances) return (probas > thresh).astype(int) def predict(self, utterances: list[str]) -> NDArray[np.float64]: - result = np.zeros((len(utterances), self.n_classes), dtype=float) - 
neighbors_labels = self._get_neighbors(utterances) - - for instance in range(neighbors_labels.shape[0]): - deltas = np.sum(neighbors_labels[instance], axis=0).astype(int) - - for label in range(self.n_classes): - p_true = self._prior_prob_true[label] * self._cond_prob_true[label, deltas[label]] - p_false = self._prior_prob_false[label] * self._cond_prob_false[label, deltas[label]] - result[instance, label] = p_true / (p_true + p_false) + return self._predict(utterances)[0] - return result + def predict_with_metadata( + self, + utterances: list[str], + ) -> tuple[NDArray[Any], list[dict[str, Any]] | None]: + scores, neighbors = self._predict(utterances) + metadata = [{"neighbors": utterance_neighbors} for utterance_neighbors in neighbors] + return scores, metadata def clear_cache(self) -> None: self.vector_index.clear_ram() @@ -222,3 +222,20 @@ def load(self, path: str) -> None: embedder_max_length=self.metadata["max_length"], ) self.vector_index = vector_index_client.get_index(self.embedder_name) + + def _predict( + self, + utterances: list[str], + ) -> tuple[NDArray[np.float64], list[list[str]]]: + result = np.zeros((len(utterances), self.n_classes), dtype=float) + neighbors_labels, neighbors = self._get_neighbors(utterances) + + for instance in range(neighbors_labels.shape[0]): + deltas = np.sum(neighbors_labels[instance], axis=0).astype(int) + + for label in range(self.n_classes): + p_true = self._prior_prob_true[label] * self._cond_prob_true[label, deltas[label]] + p_false = self._prior_prob_false[label] * self._cond_prob_false[label, deltas[label]] + result[instance, label] = p_true / (p_true + p_false) + + return result, neighbors diff --git a/autointent/nodes/__init__.py b/autointent/nodes/__init__.py index c2e4ba4a..a436675e 100644 --- a/autointent/nodes/__init__.py +++ b/autointent/nodes/__init__.py @@ -1,5 +1,13 @@ from .inference import InferenceNode -from .nodes_info import NodeInfo, PredictionNodeInfo, RetrievalNodeInfo, ScoringNodeInfo +from .nodes_info import NodeInfo, PredictionNodeInfo, RegExpNodeInfo, RetrievalNodeInfo, ScoringNodeInfo from .optimization import NodeOptimizer -__all__ = ["InferenceNode", "NodeInfo", "PredictionNodeInfo", "RetrievalNodeInfo", "ScoringNodeInfo", "NodeOptimizer"] +__all__ = [ + "InferenceNode", + "NodeInfo", + "PredictionNodeInfo", + "RegExpNodeInfo", + "RetrievalNodeInfo", + "ScoringNodeInfo", + "NodeOptimizer", +] diff --git a/autointent/nodes/nodes_info/__init__.py b/autointent/nodes/nodes_info/__init__.py index affb3272..9923f30d 100644 --- a/autointent/nodes/nodes_info/__init__.py +++ b/autointent/nodes/nodes_info/__init__.py @@ -2,6 +2,7 @@ from .base import NodeInfo from .prediction import PredictionNodeInfo +from .regexp import RegExpNodeInfo from .retrieval import RetrievalNodeInfo from .scoring import ScoringNodeInfo @@ -9,6 +10,14 @@ NodeType.retrieval: RetrievalNodeInfo(), NodeType.scoring: ScoringNodeInfo(), NodeType.prediction: PredictionNodeInfo(), + NodeType.regexp: RegExpNodeInfo(), } -__all__ = ["NodeInfo", "PredictionNodeInfo", "RetrievalNodeInfo", "ScoringNodeInfo", "NODES_INFO"] +__all__ = [ + "NodeInfo", + "PredictionNodeInfo", + "RegExpNodeInfo", + "RetrievalNodeInfo", + "ScoringNodeInfo", + "NODES_INFO", +] diff --git a/autointent/nodes/nodes_info/regexp.py b/autointent/nodes/nodes_info/regexp.py index c87bffb4..ccf9aef3 100644 --- a/autointent/nodes/nodes_info/regexp.py +++ b/autointent/nodes/nodes_info/regexp.py @@ -10,7 +10,7 @@ from .base import NodeInfo -class RegExpNode(NodeInfo): +class RegExpNodeInfo(NodeInfo): 
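MLKnnScorer._predict applies the classic ML-kNN Bayes rule per label: the posterior that a label is on, given that `delta` of the k neighbors carry it. A numeric sketch with made-up prior and conditional values:

prior_true, prior_false = 0.4, 0.6   # smoothed label frequencies from fit()
cond_true, cond_false = 0.7, 0.2     # P(delta | label on) vs P(delta | label off)

p_true = prior_true * cond_true
p_false = prior_false * cond_false
print(p_true / (p_true + p_false))   # ~0.70, the score written into result[instance, label]
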
metrics_available: ClassVar[Mapping[str, RegexpMetricFn]] = funcs_to_dict( regexp_partial_accuracy, regexp_partial_precision, diff --git a/autointent/nodes/optimization/node_optimizer.py b/autointent/nodes/optimization/node_optimizer.py index 40ad41d3..b785e99c 100644 --- a/autointent/nodes/optimization/node_optimizer.py +++ b/autointent/nodes/optimization/node_optimizer.py @@ -95,6 +95,8 @@ def module_fit(self, module: Module, context: Context) -> None: elif self.node_info.node_type == "prediction": labels, scores = get_prediction_evaluation_data(context) args = (scores, labels, context.data_handler.tags) # type: ignore[assignment] + elif self.node_info.node_type == "regexp": + args = () # type: ignore[assignment] else: msg = "something's wrong" self._logger.error(msg) diff --git a/autointent/pipeline/inference/cli_endpoint.py b/autointent/pipeline/inference/cli_endpoint.py index 40f5d6b6..9756b0bb 100644 --- a/autointent/pipeline/inference/cli_endpoint.py +++ b/autointent/pipeline/inference/cli_endpoint.py @@ -1,7 +1,6 @@ import json import logging from pathlib import Path -from typing import TYPE_CHECKING import hydra import yaml @@ -11,9 +10,6 @@ from .inference_pipeline import InferencePipeline -if TYPE_CHECKING: - from autointent.custom_types import LabelType - @hydra.main(config_name="inference_config", config_path=".", version_base=None) def main(cfg: InferenceConfig) -> None: @@ -31,9 +27,12 @@ def main(cfg: InferenceConfig) -> None: # instantiate pipeline pipeline = InferencePipeline.from_dict_config(inference_config["nodes_configs"]) - # send data to pipeline - labels: list[LabelType] = pipeline.predict(data) + if cfg.with_metadata: + rich_output = pipeline.predict_with_metadata(data) + with Path(cfg.output_path).open("w") as file: + json.dump(rich_output.model_dump(exclude_none=True), file, indent=4) - # save results - with Path(cfg.output_path).open("w") as file: - json.dump(labels, file, cls=NumpyEncoder) + else: + output = pipeline.predict(data) + with Path(cfg.output_path).open("w") as file: + json.dump(output, file, indent=4, cls=NumpyEncoder) diff --git a/autointent/pipeline/inference/inference_pipeline.py b/autointent/pipeline/inference/inference_pipeline.py index 581352e0..e6f76206 100644 --- a/autointent/pipeline/inference/inference_pipeline.py +++ b/autointent/pipeline/inference/inference_pipeline.py @@ -1,5 +1,6 @@ from typing import Any +from pydantic import BaseModel from typing_extensions import Self from autointent.configs.node import InferenceNodeConfig @@ -8,6 +9,21 @@ from autointent.nodes.inference import InferenceNode +class InferencePipelineUtteranceOutput(BaseModel): + utterance: str + prediction: LabelType + regexp_prediction: LabelType | None + regexp_prediction_metadata: Any + score: list[float] + score_metadata: Any + + +class InferencePipelineOutput(BaseModel): + predictions: list[LabelType] + regexp_predictions: list[LabelType] | None = None + utterances: list[InferencePipelineUtteranceOutput] | None = None + + class InferencePipeline: def __init__(self, nodes: list[InferenceNode]) -> None: self.nodes = {node.node_type: node for node in nodes} @@ -27,6 +43,35 @@ def predict(self, utterances: list[str]) -> list[LabelType]: scores = self.nodes[NodeType.scoring].module.predict(utterances) return self.nodes[NodeType.prediction].module.predict(scores) # type: ignore[return-value] + def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput: + scores, scores_metadata = self.nodes["scoring"].module.predict_with_metadata(utterances) + 
predictions = self.nodes["prediction"].module.predict(scores) + regexp_predictions, regexp_predictions_metadata = None, None + if "regexp" in self.nodes: + regexp_predictions, regexp_predictions_metadata = self.nodes["regexp"].module.predict_with_metadata( + utterances, + ) + + outputs = [] + for idx, utterance in enumerate(utterances): + output = InferencePipelineUtteranceOutput( + utterance=utterance, + prediction=predictions[idx], + regexp_prediction=regexp_predictions[idx] if regexp_predictions is not None else None, + regexp_prediction_metadata=regexp_predictions_metadata[idx] + if regexp_predictions_metadata is not None + else None, + score=scores[idx], + score_metadata=scores_metadata[idx] if scores_metadata is not None else None, + ) + outputs.append(output) + + return InferencePipelineOutput( + predictions=predictions, + regexp_predictions=regexp_predictions, + utterances=outputs, + ) + def fit(self, utterances: list[str], labels: list[LabelType]) -> None: pass diff --git a/experiments/predict_with_metadata/testbed.ipynb b/experiments/predict_with_metadata/testbed.ipynb new file mode 100644 index 00000000..d88b2b64 --- /dev/null +++ b/experiments/predict_with_metadata/testbed.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from typing import Literal\n", + "from autointent.pipeline.optimization.utils import load_config\n", + "\n", + "TaskType = Literal[\"multiclass\", \"multilabel\", \"description\"]\n", + "\n", + "\n", + "def setup_environment() -> tuple[str, str]:\n", + " logs_dir = Path.cwd() / \"logs\"\n", + " db_dir = logs_dir / \"db\"\n", + " dump_dir = logs_dir / \"modules_dump\"\n", + " return db_dir, dump_dir, logs_dir\n", + "\n", + "def get_search_space_path(task_type: TaskType):\n", + " return Path(\"/home/voorhs/repos/AutoIntent/tests/assets/configs\").joinpath(f\"{task_type}.yaml\")\n", + "\n", + "\n", + "def get_search_space(task_type: TaskType):\n", + " path = get_search_space_path(task_type)\n", + " return load_config(str(path), multilabel=task_type == \"multilabel\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs/db'),\n", + " PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs/modules_dump'),\n", + " PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs'))" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "setup_environment()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'nodes': [{'node_type': 'retrieval',\n", + " 'metric': 'retrieval_hit_rate',\n", + " 'search_space': [{'module_type': 'vector_db',\n", + " 'k': [10],\n", + " 'embedder_name': ['sentence-transformers/all-MiniLM-L6-v2',\n", + " 'avsolatorio/GIST-small-Embedding-v0']}]},\n", + " {'node_type': 'scoring',\n", + " 'metric': 'scoring_roc_auc',\n", + " 'search_space': [{'module_type': 'knn',\n", + " 'k': [5, 10],\n", + " 'weights': ['uniform', 'distance', 'closest']},\n", + " {'module_type': 'linear'},\n", + " {'module_type': 'dnnc',\n", + " 'cross_encoder_name': ['cross-encoder/ms-marco-MiniLM-L-6-v2',\n", + " 'avsolatorio/GIST-small-Embedding-v0'],\n", + " 'k': [1, 3],\n", + " 'train_head': [False, True]}]},\n", + " 
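The two pydantic models give `predict_with_metadata` a typed envelope that the CLI serializes via `model_dump(exclude_none=True)`. A self-contained sketch assuming pydantic v2 (LabelType is aliased locally for autointent.custom_types.LabelType; field values are illustrative):

from typing import Any

from pydantic import BaseModel

LabelType = int | list[int]  # local alias


class InferencePipelineUtteranceOutput(BaseModel):
    utterance: str
    prediction: LabelType
    regexp_prediction: LabelType | None
    regexp_prediction_metadata: Any
    score: list[float]
    score_metadata: Any


class InferencePipelineOutput(BaseModel):
    predictions: list[LabelType]
    regexp_predictions: list[LabelType] | None = None
    utterances: list[InferencePipelineUtteranceOutput] | None = None


output = InferencePipelineOutput(
    predictions=[2],
    utterances=[
        InferencePipelineUtteranceOutput(
            utterance="hello world",
            prediction=2,
            regexp_prediction=None,
            regexp_prediction_metadata=None,
            score=[0.0, 0.4, 0.6],
            score_metadata={"neighbors": ["wake me up at noon tomorrow"]},
        )
    ],
)
print(output.model_dump(exclude_none=True))  # None-valued regexp fields drop out, matching the CLI output
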
{'node_type': 'prediction',\n", + " 'metric': 'prediction_accuracy',\n", + " 'search_space': [{'module_type': 'threshold',\n", + " 'thresh': [0.5, [0.5, 0.5, 0.5]]},\n", + " {'module_type': 'tunable'},\n", + " {'module_type': 'argmax'},\n", + " {'module_type': 'jinoos'}]}]}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_search_space(\"multiclass\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from autointent.context.utils import load_data\n", + "\n", + "\n", + "def get_dataset_path():\n", + " return Path(\"/home/voorhs/repos/AutoIntent/tests/assets/data\").joinpath(\"clinc_subset.json\")\n", + "\n", + "\n", + "def get_dataset():\n", + " return load_data(get_dataset_path())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "task_type = \"multiclass\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "... (the same tokenizers fork warning repeats verbatim for every forked worker) ...\n", "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "[I 2024-11-11 13:13:01,596] A new study created in memory with name: no-name-066e9d3e-65d1-45d8-b87b-29160b3a9f1f\n" + ] + } + ], + "source": [ + "from autointent.pipeline.optimization import PipelineOptimizer\n", + "from autointent.configs.optimization_cli import LoggingConfig, VectorIndexConfig, EmbedderConfig\n", + "\n", + "db_dir, dump_dir, logs_dir = setup_environment()\n", + "search_space = get_search_space(task_type)\n", + "\n", + "pipeline_optimizer = PipelineOptimizer.from_dict_config(search_space)\n", + "\n", + "pipeline_optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=True))\n", + "pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device=\"cpu\", save_db=True))\n", + "pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))\n", + "\n", + "\n", + "dataset = get_dataset()\n", + "context = pipeline_optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == \"multilabel\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from autointent.pipeline.inference import InferencePipeline\n", + "\n", + "\n", + "inference_pipeline = InferencePipeline.from_context(context)\n", + "prediction = inference_pipeline.predict_with_metadata([\"123\", \"hello world\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2, 2]\n", + "[InferencePipelineUtteranceOutput(utterance='123', prediction=2, regexp_prediction=None, regexp_prediction_metadata=None, score=[0.0, 0.4, 0.6], score_metadata={'neighbors': ['set my alarm for getting up', 'wake me up at noon tomorrow', 'i am nost sure why my account is blocked', 'i think my account is blocked but i do not know the reason', 'please set an alarm for mid day']}),\n", + " InferencePipelineUtteranceOutput(utterance='hello world', prediction=2, regexp_prediction=None, 
regexp_prediction_metadata=None, score=[0.0, 0.4, 0.6], score_metadata={'neighbors': ['wake me up at noon tomorrow', 'set my alarm for getting up', 'please set an alarm for mid day', 'why is there a hold on my american saving bank account', 'i am nost sure why my account is blocked']})]\n" + ] + } + ], + "source": [ + "from pprint import pprint\n", + "\n", + "pprint(prediction.predictions)\n", + "pprint(prediction.utterances)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if task_type == \"multilabel\":\n", + " assert prediction.shape == (2, len(dataset.intents))\n", + "else:\n", + " assert prediction.shape == (2,)\n", + "\n", + "context.dump()\n", + "context.vector_index_client.delete_db()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "autointent-D7M6VOhJ-py3.12", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/conftest.py b/tests/conftest.py index af217b98..5de9a125 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,11 +13,10 @@ def setup_environment() -> tuple[str, str]: return db_dir, dump_dir, logs_dir -@pytest.fixture -def dataset_path(): +def get_dataset_path(): return ires.files("tests.assets.data").joinpath("clinc_subset.json") @pytest.fixture -def dataset(dataset_path): - return load_data(dataset_path) +def dataset(): + return load_data(get_dataset_path()) diff --git a/tests/modules/retrieval/test_vectordb.py b/tests/modules/retrieval/test_vectordb.py index 2d72753e..d188afb2 100644 --- a/tests/modules/retrieval/test_vectordb.py +++ b/tests/modules/retrieval/test_vectordb.py @@ -1,7 +1,9 @@ from autointent.modules.retrieval.vectordb import VectorDBModule +from tests.conftest import setup_environment -def test_get_assets_returns_correct_artifact(tmp_path): - module = VectorDBModule(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=str(tmp_path)) +def test_get_assets_returns_correct_artifact(): + db_dir, dump_dir, logs_dir = setup_environment() + module = VectorDBModule(k=5, embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir) artifact = module.get_assets() assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo" diff --git a/tests/modules/scoring/test_description.py b/tests/modules/scoring/test_description.py index c50e0a2e..d3668d67 100644 --- a/tests/modules/scoring/test_description.py +++ b/tests/modules/scoring/test_description.py @@ -17,7 +17,7 @@ def test_description_scorer(dataset, expected_prediction, multilabel): db_dir, dump_dir, logs_dir = setup_environment() data_handler = DataHandler(dataset, force_multilabel=multilabel) - scorer = DescriptionScorer(embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir, temperature=0.3) + scorer = DescriptionScorer(embedder_name="sergeyzh/rubert-tiny-turbo", temperature=0.3, device="cpu") scorer.fit(data_handler.utterances_train, data_handler.labels_train, data_handler.label_description) assert scorer.description_vectors.shape[0] == len(data_handler.label_description) @@ -36,4 +36,8 @@ def test_description_scorer(dataset, expected_prediction, multilabel): assert predictions.shape == (len(test_utterances), len(data_handler.label_description)) np.testing.assert_almost_equal(predictions, 
np.array(expected_prediction).reshape(predictions.shape), decimal=1) + predictions, metadata = scorer.predict_with_metadata(test_utterances) + assert len(predictions) == len(test_utterances) + assert metadata is None + scorer.clear_cache() diff --git a/tests/modules/scoring/test_dnnc.py b/tests/modules/scoring/test_dnnc.py index efadda20..9bc3d941 100644 --- a/tests/modules/scoring/test_dnnc.py +++ b/tests/modules/scoring/test_dnnc.py @@ -6,14 +6,20 @@ from tests.conftest import setup_environment -@pytest.mark.xfail(reason="This test is failing on windows, because have different score") -@pytest.mark.parametrize(("train_head", "pred_score"), [(True, 1), (False, 0.5)]) +@pytest.mark.parametrize(("train_head", "pred_score"), [(True, 1)]) def test_base_dnnc(dataset, train_head, pred_score): db_dir, dump_dir, logs_dir = setup_environment() data_handler = DataHandler(dataset) - scorer = DNNCScorer("sergeyzh/rubert-tiny-turbo", k=3, train_head=train_head) + scorer = DNNCScorer( + cross_encoder_name="cross-encoder/ms-marco-MiniLM-L-6-v2", + embedder_name="sergeyzh/rubert-tiny-turbo", + k=3, + train_head=train_head, + db_dir=db_dir, + device="cpu", + ) scorer.fit(data_handler.utterances_train, data_handler.labels_train) test_data = [ @@ -25,4 +31,10 @@ def test_base_dnnc(dataset, train_head, pred_score): ] predictions = scorer.predict(test_data) np.testing.assert_almost_equal(np.array([[0.0, pred_score, 0.0]] * len(test_data)), predictions, decimal=2) + + predictions, metadata = scorer.predict_with_metadata(test_data) + assert len(predictions) == len(test_data) + assert "neighbors" in metadata[0] + assert "scores" in metadata[0] + scorer.clear_cache() diff --git a/tests/modules/scoring/test_knn.py b/tests/modules/scoring/test_knn.py index 735adc1e..994491ee 100644 --- a/tests/modules/scoring/test_knn.py +++ b/tests/modules/scoring/test_knn.py @@ -10,18 +10,22 @@ def test_base_knn(dataset): data_handler = DataHandler(dataset) - scorer = KNNScorer(k=3, weights="distance", embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir) + scorer = KNNScorer(k=3, weights="distance", embedder_name="sergeyzh/rubert-tiny-turbo", db_dir=db_dir, device="cpu") + + test_data = [ + "why is there a hold on my american saving bank account", + "i am nost sure why my account is blocked", + "why is there a hold on my capital one checking account", + "i think my account is blocked but i do not know the reason", + "can you tell me why is my bank account frozen", + ] scorer.fit(data_handler.utterances_train, data_handler.labels_train) - predictions = scorer.predict( - [ - "why is there a hold on my american saving bank account", - "i am nost sure why my account is blocked", - "why is there a hold on my capital one checking account", - "i think my account is blocked but i do not know the reason", - "can you tell me why is my bank account frozen", - ] - ) + predictions = scorer.predict(test_data) assert ( predictions == np.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]) ).all() + + predictions, metadata = scorer.predict_with_metadata(test_data) + assert len(predictions) == len(test_data) + assert "neighbors" in metadata[0] diff --git a/tests/modules/scoring/test_linear.py b/tests/modules/scoring/test_linear.py index 9ac62182..77400371 100644 --- a/tests/modules/scoring/test_linear.py +++ b/tests/modules/scoring/test_linear.py @@ -10,7 +10,7 @@ def test_base_linear(dataset): data_handler = DataHandler(dataset) - scorer = LinearScorer("sergeyzh/rubert-tiny-turbo") + scorer = 
LinearScorer(embedder_name="sergeyzh/rubert-tiny-turbo", device="cpu") scorer.fit(data_handler.utterances_train, data_handler.labels_train) test_data = [ @@ -42,3 +42,7 @@ def test_base_linear(dataset): predictions, decimal=2, ) + + predictions, metadata = scorer.predict_with_metadata(test_data) + assert len(predictions) == len(test_data) + assert metadata is None diff --git a/tests/modules/scoring/test_mlknn.py b/tests/modules/scoring/test_mlknn.py index 6cf44386..e8cfae50 100644 --- a/tests/modules/scoring/test_mlknn.py +++ b/tests/modules/scoring/test_mlknn.py @@ -24,16 +24,20 @@ def test_base_mlknn(dataset): ) data_handler = DataHandler(dataset, test_dataset, force_multilabel=True) - scorer = MLKnnScorer(db_dir=db_dir, k=3, embedder_name="sergeyzh/rubert-tiny-turbo") + scorer = MLKnnScorer(embedder_name="sergeyzh/rubert-tiny-turbo", k=3, db_dir=db_dir, device="cpu") scorer.fit(data_handler.utterances_train, data_handler.labels_train) - predictions = scorer.predict_labels( - [ - "why is there a hold on my american saving bank account", - "i am nost sure why my account is blocked", - "why is there a hold on my capital one checking account", - "i think my account is blocked but i do not know the reason", - "can you tell me why is my bank account frozen", - ] - ) + test_data = [ + "why is there a hold on my american saving bank account", + "i am nost sure why my account is blocked", + "why is there a hold on my capital one checking account", + "i think my account is blocked but i do not know the reason", + "can you tell me why is my bank account frozen", + ] + + predictions = scorer.predict_labels(test_data) assert (predictions == np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0]])).all() + + predictions, metadata = scorer.predict_with_metadata(test_data) + assert len(predictions) == len(test_data) + assert "neighbors" in metadata[0] diff --git a/tests/modules/test_regex.py b/tests/modules/test_regex.py index bd834d89..e6480c9b 100644 --- a/tests/modules/test_regex.py +++ b/tests/modules/test_regex.py @@ -1,101 +1,30 @@ import pytest -from autointent import Context -from autointent.context.data_handler import Dataset -from autointent.metrics import retrieval_hit_rate, scoring_roc_auc -from autointent.modules import RegExp, VectorDBModule -from tests.conftest import setup_environment - - -@pytest.mark.xfail(reason="Issues with intent_id") -def test_base_regex(): - db_dir, dump_dir, logs_dir = setup_environment() - - data = { - "utterances": [ - { - "text": "can i make a reservation for redrobin", - "label": 0, - }, - { - "text": "is it possible to make a reservation at redrobin", - "label": 0, - }, - { - "text": "does redrobin take reservations", - "label": 0, - }, - { - "text": "are reservations taken at redrobin", - "label": 0, - }, - { - "text": "does redrobin do reservations", - "label": 0, - }, - { - "text": "why is there a hold on my american saving bank account", - "label": 1, - }, - { - "text": "i am nost sure why my account is blocked", - "label": 1, - }, - { - "text": "why is there a hold on my capital one checking account", - "label": 1, - }, - { - "text": "i think my account is blocked but i do not know the reason", - "label": 1, - }, - { - "text": "can you tell me why is my bank account frozen", - "label": 1, - }, - ], - "intents": [ - { - "id": 0, - "name": "accept_reservations", - "regexp_full_match": [".*"], - "regexp_partial_match": [".*"], - }, - { - "id": 1, - "name": "account_blocked", - "regexp_full_match": [".*"], - "regexp_partial_match": [".*"], - }, - 
], - } - - context = Context( - dataset=Dataset.model_validate(data), - dump_dir=dump_dir, - db_dir=db_dir, - ) - - retrieval_params = {"k": 3, "embedder_name": "sergeyzh/rubert-tiny-turbo"} - vector_db = VectorDBModule(**retrieval_params) - vector_db.fit(context) - metric_value = vector_db.score(context, retrieval_hit_rate) - artifact = vector_db.get_assets() - context.optimization_info.log_module_optimization( - node_type="retrieval", - module_type="vector_db", - module_params=retrieval_params, - metric_value=metric_value, - metric_name="retrieval_hit_rate_macro", - artifact=artifact, - module_dump_dir="", - ) +from autointent.modules import RegExp + + +@pytest.mark.parametrize( + ("partial_match", "expected_predictions"), + [(".*", [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]), ("frozen", [[0], [0], [0], [0], [0, 1]])], +) +def test_base_regex(partial_match, expected_predictions): + train_data = [ + { + "id": 0, + "name": "accept_reservations", + "regexp_full_match": [".*"], + "regexp_partial_match": [".*"], + }, + { + "id": 1, + "name": "account_blocked", + "regexp_partial_match": [partial_match], + }, + ] - scorer = RegExp() + matcher = RegExp() + matcher.fit(train_data) - scorer.fit(context) - score, _ = scorer.score(context, scoring_roc_auc) - assert score == 0.5 test_data = [ "why is there a hold on my american saving bank account", "i am nost sure why my account is blocked", @@ -103,5 +32,11 @@ def test_base_regex(): "i think my account is blocked but i do not know the reason", "can you tell me why is my bank account frozen", ] - predictions = scorer.predict(test_data) - assert predictions == [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]] + predictions = matcher.predict(test_data) + assert predictions == expected_predictions + + predictions, metadata = matcher.predict_with_metadata(test_data) + assert len(predictions) == len(test_data) == len(metadata) + + assert "partial_matches" in metadata[0] + assert "full_matches" in metadata[0] diff --git a/tests/nodes/conftest.py b/tests/nodes/conftest.py index 9cde1f71..8930ce6e 100644 --- a/tests/nodes/conftest.py +++ b/tests/nodes/conftest.py @@ -3,7 +3,7 @@ from autointent import Context from autointent.configs.optimization_cli import DataConfig, LoggingConfig, VectorIndexConfig from autointent.nodes.optimization import NodeOptimizer -from tests.conftest import setup_environment +from tests.conftest import get_dataset_path, setup_environment @pytest.fixture @@ -38,8 +38,8 @@ def get_retrieval_optimizer(multilabel: bool): @pytest.fixture -def scoring_optimizer_multiclass(context, retrieval_optimizer_multiclass): - context = context(multilabel=False) +def scoring_optimizer_multiclass(retrieval_optimizer_multiclass): + context = get_context(multilabel=False) retrieval_optimizer_multiclass.fit(context) scoring_optimizer_config = { @@ -54,8 +54,8 @@ def scoring_optimizer_multiclass(context, retrieval_optimizer_multiclass): @pytest.fixture -def scoring_optimizer_multilabel(context, retrieval_optimizer_multilabel): - context = context(multilabel=True) +def scoring_optimizer_multilabel(retrieval_optimizer_multilabel): + context = get_context(multilabel=True) retrieval_optimizer_multilabel.fit(context) scoring_optimizer_config = { @@ -69,15 +69,11 @@ def scoring_optimizer_multilabel(context, retrieval_optimizer_multilabel): return context, NodeOptimizer.from_dict_config(scoring_optimizer_config) -@pytest.fixture -def context(dataset_path): +def get_context(multilabel): db_dir, dump_dir, logs_dir = setup_environment() - def _context(multilabel: bool): - 
res = Context() - res.configure_data(DataConfig(dataset_path, force_multilabel=multilabel)) - res.configure_logging(LoggingConfig(dirpath=logs_dir, dump_dir=dump_dir, dump_modules=True)) - res.configure_vector_index(VectorIndexConfig(db_dir=db_dir)) - return res - - return _context + res = Context() + res.configure_data(DataConfig(get_dataset_path(), force_multilabel=multilabel)) + res.configure_logging(LoggingConfig(dirpath=logs_dir, dump_dir=dump_dir, dump_modules=True)) + res.configure_vector_index(VectorIndexConfig(db_dir=db_dir, device="cpu")) + return res diff --git a/tests/nodes/test_retrieval.py b/tests/nodes/test_retrieval.py index 683ca59d..40e8c379 100644 --- a/tests/nodes/test_retrieval.py +++ b/tests/nodes/test_retrieval.py @@ -6,11 +6,13 @@ from autointent.configs.node import InferenceNodeConfig from autointent.nodes import InferenceNode, NodeOptimizer +from .conftest import get_context + logger = logging.getLogger(__name__) -def test_retrieval_multiclass(context): - context = context(multilabel=False) +def test_retrieval_multiclass(): + context = get_context(multilabel=False) retrieval_optimizer = get_retrieval_optimizer(multilabel=False) retrieval_optimizer.fit(context) @@ -28,8 +30,8 @@ def test_retrieval_multiclass(context): torch.cuda.empty_cache() -def test_retrieval_multilabel(context): - context = context(multilabel=True) +def test_retrieval_multilabel(): + context = get_context(multilabel=True) retrieval_optimizer = get_retrieval_optimizer(multilabel=True) retrieval_optimizer.fit(context) diff --git a/tests/nodes/test_scoring.py b/tests/nodes/test_scoring.py index 70eddf91..ca96d20e 100644 --- a/tests/nodes/test_scoring.py +++ b/tests/nodes/test_scoring.py @@ -7,11 +7,13 @@ from autointent.nodes import InferenceNode from autointent.nodes.optimization import NodeOptimizer +from .conftest import get_context + logger = logging.getLogger(__name__) -def test_scoring_multiclass(context, retrieval_optimizer_multiclass): - context = context(multilabel=False) +def test_scoring_multiclass(retrieval_optimizer_multiclass): + context = get_context(multilabel=False) retrieval_optimizer_multiclass.fit(context) scoring_optimizer_config = { @@ -61,8 +63,8 @@ def test_scoring_multiclass(context, retrieval_optimizer_multiclass): torch.cuda.empty_cache() -def test_scoring_multilabel(context, retrieval_optimizer_multilabel): - context = context(multilabel=True) +def test_scoring_multilabel(retrieval_optimizer_multilabel): + context = get_context(multilabel=True) retrieval_optimizer_multilabel.fit(context) scoring_optimizer_config = { diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index d7facbea..29dea405 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -46,12 +46,16 @@ def test_inference_config(dataset, task_type): inference_config = context.optimization_info.get_inference_nodes_config() inference_pipeline = InferencePipeline.from_config(inference_config) - prediction = inference_pipeline.predict(["123", "hello world"]) + utterances = ["123", "hello world"] + prediction = inference_pipeline.predict(utterances) if task_type == "multilabel": assert prediction.shape == (2, len(dataset.intents)) else: assert prediction.shape == (2,) + rich_outputs = inference_pipeline.predict_with_metadata(utterances) + assert len(rich_outputs.predictions) == len(utterances) + context.dump() context.vector_index_client.delete_db() @@ -66,19 +70,23 @@ def test_inference_context(dataset, task_type): pipeline_optimizer = 
PipelineOptimizer.from_dict_config(search_space) - pipeline_optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=True)) + pipeline_optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=False, clear_ram=False)) pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device="cpu", save_db=True)) pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32)) context = pipeline_optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == "multilabel")) inference_pipeline = InferencePipeline.from_context(context) - prediction = inference_pipeline.predict(["123", "hello world"]) + utterances = ["123", "hello world"] + prediction = inference_pipeline.predict(utterances) if task_type == "multilabel": assert prediction.shape == (2, len(dataset.intents)) else: assert prediction.shape == (2,) + rich_outputs = inference_pipeline.predict_with_metadata(utterances) + assert len(rich_outputs.predictions) == len(utterances) + context.dump() context.vector_index_client.delete_db() @@ -107,6 +115,7 @@ def test_inference_pipeline_cli(dataset, task_type): source_dir=logging_config.dirpath, output_path=logging_config.dump_dir / "predictions.json", log_level="CRITICAL", + with_metadata=False, ) inference_pipeline(config) context.vector_index_client.delete_db() diff --git a/tests/pipeline/test_optimization.py b/tests/pipeline/test_optimization.py index d34c2e29..c5c8d367 100644 --- a/tests/pipeline/test_optimization.py +++ b/tests/pipeline/test_optimization.py @@ -103,6 +103,7 @@ def test_optimization_pipeline_cli(task_type): search_space_path=get_search_space_path(task_type), ), vector_index=VectorIndexConfig( + db_dir=db_dir, device="cpu", ), logs=LoggingConfig(
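Taken together, the endpoint change reduces to two serialization branches keyed on the new `with_metadata` flag. A condensed sketch of that logic (`pipeline` stands for any object exposing both predict methods; the real endpoint additionally passes `cls=NumpyEncoder` on the plain branch to handle numpy arrays):

import json
from pathlib import Path
from typing import Any


def write_predictions(pipeline: Any, data: list[str], output_path: str, with_metadata: bool) -> None:
    with Path(output_path).open("w") as file:
        if with_metadata:
            rich_output = pipeline.predict_with_metadata(data)
            # pydantic v2 dump; exclude_none hides regexp fields when no regexp node ran
            json.dump(rich_output.model_dump(exclude_none=True), file, indent=4)
        else:
            json.dump(pipeline.predict(data), file, indent=4)
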