Skip to content

Commit

Permalink
add regex test and rich predict test
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Nov 11, 2024
1 parent c3c44fc commit 1746556
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 16 deletions.
1 change: 1 addition & 0 deletions autointent/configs/inference_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class InferenceConfig:
source_dir: str
output_path: str
log_level: LogLevel = LogLevel.ERROR
with_metadata: bool = False


cs = ConfigStore.instance()
Expand Down
4 changes: 3 additions & 1 deletion autointent/modules/regexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def _match(self, utterance: str, intent_record: RegexPatternsCompiled) -> dict[s
if pattern.fullmatch(utterance) is not None
]
partial_matches = [
pattern.pattern for pattern in intent_record["regexp_partial_match"] if pattern.match(utterance) is not None
pattern.pattern
for pattern in intent_record["regexp_partial_match"]
if pattern.search(utterance) is not None
]
return {"full_matches": full_matches, "partial_matches": partial_matches}

Expand Down
2 changes: 1 addition & 1 deletion autointent/pipeline/inference/cli_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main(cfg: InferenceConfig) -> None:
pipeline = InferencePipeline.from_dict_config(inference_config["nodes_configs"])

# send data to pipeline
output = pipeline.predict_with_metadata(data)
output = pipeline.predict_with_metadata(data) if cfg.with_metadata else pipeline.predict(data)

# save results
with Path(cfg.output_path).open("w") as file:
Expand Down
11 changes: 3 additions & 8 deletions autointent/pipeline/inference/inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class InferencePipelineUtteranceOutput(BaseModel):
utterance: str
prediction: LabelType
regexp_prediction: LabelType
regexp_prediction: LabelType | None
regexp_prediction_metadata: Any
score: list[float]
score_metadata: Any
Expand All @@ -38,14 +38,9 @@ def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "InferencePipe
nodes = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
return cls(nodes)

def predict(self, utterances: list[str]) -> InferencePipelineOutput:
def predict(self, utterances: list[str]) -> list[LabelType]:
scores = self.nodes[NodeType.scoring].module.predict(utterances)
predictions = self.nodes[NodeType.prediction].module.predict(scores)

regexp_predictions = None
if "regexp" in self.nodes:
regexp_predictions = self.nodes["regexp"].module.predict(utterances)
return InferencePipelineOutput(predictions=predictions, regexp_predictions=regexp_predictions)
return self.nodes[NodeType.prediction].module.predict(scores)

def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
scores, scores_metadata = self.nodes["scoring"].module.predict_with_metadata(utterances)
Expand Down
301 changes: 301 additions & 0 deletions experiments/predict_with_metadata/testbed.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from typing import Literal\n",
"from autointent.pipeline.optimization.utils import load_config\n",
"\n",
"TaskType = Literal[\"multiclass\", \"multilabel\", \"description\"]\n",
"\n",
"\n",
"def setup_environment() -> tuple[str, str]:\n",
" logs_dir = Path.cwd() / \"logs\"\n",
" db_dir = logs_dir / \"db\"\n",
" dump_dir = logs_dir / \"modules_dump\"\n",
" return db_dir, dump_dir, logs_dir\n",
"\n",
"def get_search_space_path(task_type: TaskType):\n",
" return Path(\"/home/voorhs/repos/AutoIntent/tests/assets/configs\").joinpath(f\"{task_type}.yaml\")\n",
"\n",
"\n",
"def get_search_space(task_type: TaskType):\n",
" path = get_search_space_path(task_type)\n",
" return load_config(str(path), multilabel=task_type == \"multilabel\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs/db'),\n",
" PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs/modules_dump'),\n",
" PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs'))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"setup_environment()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'nodes': [{'node_type': 'retrieval',\n",
" 'metric': 'retrieval_hit_rate',\n",
" 'search_space': [{'module_type': 'vector_db',\n",
" 'k': [10],\n",
" 'embedder_name': ['sentence-transformers/all-MiniLM-L6-v2',\n",
" 'avsolatorio/GIST-small-Embedding-v0']}]},\n",
" {'node_type': 'scoring',\n",
" 'metric': 'scoring_roc_auc',\n",
" 'search_space': [{'module_type': 'knn',\n",
" 'k': [5, 10],\n",
" 'weights': ['uniform', 'distance', 'closest']},\n",
" {'module_type': 'linear'},\n",
" {'module_type': 'dnnc',\n",
" 'cross_encoder_name': ['cross-encoder/ms-marco-MiniLM-L-6-v2',\n",
" 'avsolatorio/GIST-small-Embedding-v0'],\n",
" 'k': [1, 3],\n",
" 'train_head': [False, True]}]},\n",
" {'node_type': 'prediction',\n",
" 'metric': 'prediction_accuracy',\n",
" 'search_space': [{'module_type': 'threshold',\n",
" 'thresh': [0.5, [0.5, 0.5, 0.5]]},\n",
" {'module_type': 'tunable'},\n",
" {'module_type': 'argmax'},\n",
" {'module_type': 'jinoos'}]}]}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_search_space(\"multiclass\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from autointent.context.utils import load_data\n",
"\n",
"\n",
"def get_dataset_path():\n",
" return Path(\"/home/voorhs/repos/AutoIntent/tests/assets/data\").joinpath(\"clinc_subset.json\")\n",
"\n",
"\n",
"def get_dataset():\n",
" return load_data(get_dataset_path())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"task_type = \"multiclass\""
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"[I 2024-11-11 13:13:01,596] A new study created in memory with name: no-name-066e9d3e-65d1-45d8-b87b-29160b3a9f1f\n"
]
}
],
"source": [
"from autointent.pipeline.optimization import PipelineOptimizer\n",
"from autointent.configs.optimization_cli import LoggingConfig, VectorIndexConfig, EmbedderConfig\n",
"\n",
"db_dir, dump_dir, logs_dir = setup_environment()\n",
"search_space = get_search_space(task_type)\n",
"\n",
"pipeline_optimizer = PipelineOptimizer.from_dict_config(search_space)\n",
"\n",
"pipeline_optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=True))\n",
"pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device=\"cpu\", save_db=True))\n",
"pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))\n",
"\n",
"\n",
"dataset = get_dataset()\n",
"context = pipeline_optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == \"multilabel\"))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from autointent.pipeline.inference import InferencePipeline\n",
"\n",
"\n",
"inference_pipeline = InferencePipeline.from_context(context)\n",
"prediction = inference_pipeline.predict_with_metadata([\"123\", \"hello world\"])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2, 2]\n",
"[InferencePipelineUtteranceOutput(utterance='123', prediction=2, regexp_prediction=None, regexp_prediction_metadata=None, score=[0.0, 0.4, 0.6], score_metadata={'neighbors': ['set my alarm for getting up', 'wake me up at noon tomorrow', 'i am nost sure why my account is blocked', 'i think my account is blocked but i do not know the reason', 'please set an alarm for mid day']}),\n",
" InferencePipelineUtteranceOutput(utterance='hello world', prediction=2, regexp_prediction=None, regexp_prediction_metadata=None, score=[0.0, 0.4, 0.6], score_metadata={'neighbors': ['wake me up at noon tomorrow', 'set my alarm for getting up', 'please set an alarm for mid day', 'why is there a hold on my american saving bank account', 'i am nost sure why my account is blocked']})]\n"
]
}
],
"source": [
"from pprint import pprint\n",
"\n",
"pprint(prediction.predictions)\n",
"pprint(prediction.utterances)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if task_type == \"multilabel\":\n",
" assert prediction.shape == (2, len(dataset.intents))\n",
"else:\n",
" assert prediction.shape == (2,)\n",
"\n",
"context.dump()\n",
"context.vector_index_client.delete_db()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autointent-D7M6VOhJ-py3.12",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
13 changes: 9 additions & 4 deletions tests/modules/test_regex.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import pytest

from autointent.modules import RegExp
from tests.conftest import setup_environment


def test_base_regex():
@pytest.mark.parametrize(
("partial_match", "expected_predictions"),
[(".*", [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]), ("frozen", [[0], [0], [0], [0], [0, 1]])],
)
def test_base_regex(partial_match, expected_predictions):
db_dir, dump_dir, logs_dir = setup_environment()

train_data = [
Expand All @@ -15,8 +21,7 @@ def test_base_regex():
{
"id": 1,
"name": "account_blocked",
"regexp_full_match": [".*"],
"regexp_partial_match": [".*"],
"regexp_partial_match": [partial_match],
},
]

Expand All @@ -31,4 +36,4 @@ def test_base_regex():
"can you tell me why is my bank account frozen",
]
predictions = matcher.predict(test_data)
assert predictions == [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]
assert predictions == expected_predictions
Loading

0 comments on commit 1746556

Please sign in to comment.