-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add regex test and rich predict test
- Loading branch information
Showing
7 changed files
with
327 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,301 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"from typing import Literal\n", | ||
"from autointent.pipeline.optimization.utils import load_config\n", | ||
"\n", | ||
"TaskType = Literal[\"multiclass\", \"multilabel\", \"description\"]\n", | ||
"\n", | ||
"\n", | ||
"def setup_environment() -> tuple[str, str]:\n", | ||
" logs_dir = Path.cwd() / \"logs\"\n", | ||
" db_dir = logs_dir / \"db\"\n", | ||
" dump_dir = logs_dir / \"modules_dump\"\n", | ||
" return db_dir, dump_dir, logs_dir\n", | ||
"\n", | ||
"def get_search_space_path(task_type: TaskType):\n", | ||
" return Path(\"/home/voorhs/repos/AutoIntent/tests/assets/configs\").joinpath(f\"{task_type}.yaml\")\n", | ||
"\n", | ||
"\n", | ||
"def get_search_space(task_type: TaskType):\n", | ||
" path = get_search_space_path(task_type)\n", | ||
" return load_config(str(path), multilabel=task_type == \"multilabel\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs/db'),\n", | ||
" PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs/modules_dump'),\n", | ||
" PosixPath('/home/voorhs/repos/AutoIntent/experiments/predict_with_metadata/logs'))" | ||
] | ||
}, | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"setup_environment()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'nodes': [{'node_type': 'retrieval',\n", | ||
" 'metric': 'retrieval_hit_rate',\n", | ||
" 'search_space': [{'module_type': 'vector_db',\n", | ||
" 'k': [10],\n", | ||
" 'embedder_name': ['sentence-transformers/all-MiniLM-L6-v2',\n", | ||
" 'avsolatorio/GIST-small-Embedding-v0']}]},\n", | ||
" {'node_type': 'scoring',\n", | ||
" 'metric': 'scoring_roc_auc',\n", | ||
" 'search_space': [{'module_type': 'knn',\n", | ||
" 'k': [5, 10],\n", | ||
" 'weights': ['uniform', 'distance', 'closest']},\n", | ||
" {'module_type': 'linear'},\n", | ||
" {'module_type': 'dnnc',\n", | ||
" 'cross_encoder_name': ['cross-encoder/ms-marco-MiniLM-L-6-v2',\n", | ||
" 'avsolatorio/GIST-small-Embedding-v0'],\n", | ||
" 'k': [1, 3],\n", | ||
" 'train_head': [False, True]}]},\n", | ||
" {'node_type': 'prediction',\n", | ||
" 'metric': 'prediction_accuracy',\n", | ||
" 'search_space': [{'module_type': 'threshold',\n", | ||
" 'thresh': [0.5, [0.5, 0.5, 0.5]]},\n", | ||
" {'module_type': 'tunable'},\n", | ||
" {'module_type': 'argmax'},\n", | ||
" {'module_type': 'jinoos'}]}]}" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"get_search_space(\"multiclass\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from autointent.context.utils import load_data\n", | ||
"\n", | ||
"\n", | ||
"def get_dataset_path():\n", | ||
" return Path(\"/home/voorhs/repos/AutoIntent/tests/assets/data\").joinpath(\"clinc_subset.json\")\n", | ||
"\n", | ||
"\n", | ||
"def get_dataset():\n", | ||
" return load_data(get_dataset_path())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"task_type = \"multiclass\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | ||
"To disable this warning, you can either:\n", | ||
"\t- Avoid using `tokenizers` before the fork if possible\n", | ||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", | ||
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", | ||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | ||
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", | ||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | ||
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", | ||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | ||
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at avsolatorio/GIST-small-Embedding-v0 and are newly initialized: ['classifier.bias', 'classifier.weight']\n", | ||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | ||
"[I 2024-11-11 13:13:01,596] A new study created in memory with name: no-name-066e9d3e-65d1-45d8-b87b-29160b3a9f1f\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from autointent.pipeline.optimization import PipelineOptimizer\n", | ||
"from autointent.configs.optimization_cli import LoggingConfig, VectorIndexConfig, EmbedderConfig\n", | ||
"\n", | ||
"db_dir, dump_dir, logs_dir = setup_environment()\n", | ||
"search_space = get_search_space(task_type)\n", | ||
"\n", | ||
"pipeline_optimizer = PipelineOptimizer.from_dict_config(search_space)\n", | ||
"\n", | ||
"pipeline_optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=True))\n", | ||
"pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device=\"cpu\", save_db=True))\n", | ||
"pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))\n", | ||
"\n", | ||
"\n", | ||
"dataset = get_dataset()\n", | ||
"context = pipeline_optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == \"multilabel\"))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from autointent.pipeline.inference import InferencePipeline\n", | ||
"\n", | ||
"\n", | ||
"inference_pipeline = InferencePipeline.from_context(context)\n", | ||
"prediction = inference_pipeline.predict_with_metadata([\"123\", \"hello world\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[2, 2]\n", | ||
"[InferencePipelineUtteranceOutput(utterance='123', prediction=2, regexp_prediction=None, regexp_prediction_metadata=None, score=[0.0, 0.4, 0.6], score_metadata={'neighbors': ['set my alarm for getting up', 'wake me up at noon tomorrow', 'i am nost sure why my account is blocked', 'i think my account is blocked but i do not know the reason', 'please set an alarm for mid day']}),\n", | ||
" InferencePipelineUtteranceOutput(utterance='hello world', prediction=2, regexp_prediction=None, regexp_prediction_metadata=None, score=[0.0, 0.4, 0.6], score_metadata={'neighbors': ['wake me up at noon tomorrow', 'set my alarm for getting up', 'please set an alarm for mid day', 'why is there a hold on my american saving bank account', 'i am nost sure why my account is blocked']})]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from pprint import pprint\n", | ||
"\n", | ||
"pprint(prediction.predictions)\n", | ||
"pprint(prediction.utterances)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"if task_type == \"multilabel\":\n", | ||
" assert prediction.shape == (2, len(dataset.intents))\n", | ||
"else:\n", | ||
" assert prediction.shape == (2,)\n", | ||
"\n", | ||
"context.dump()\n", | ||
"context.vector_index_client.delete_db()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "autointent-D7M6VOhJ-py3.12", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.