diff --git a/autointent/modules/regexp.py b/autointent/modules/regexp.py index e1cb2fa8..ea2156b2 100644 --- a/autointent/modules/regexp.py +++ b/autointent/modules/regexp.py @@ -7,6 +7,7 @@ from autointent import Context from autointent.context.data_handler.data_handler import RegexPatterns +from autointent.context.data_handler.schemas import Intent from autointent.context.optimization_info.data_models import Artifact from autointent.custom_types import LabelType from autointent.metrics.regexp import RegexpMetricFn @@ -21,14 +22,20 @@ class RegexPatternsCompiled(TypedDict): class RegExp(Module): - regexp_patterns: list[RegexPatterns] - regexp_patterns_compiled: list[RegexPatternsCompiled] - @classmethod def from_context(cls, context: Context) -> Self: return cls() - def fit(self) -> None: + def fit(self, intents: list[dict[str, Any]]) -> None: + intents_parsed = [Intent(**dct) for dct in intents] + self.regexp_patterns = [ + RegexPatterns( + id=intent.id, + regexp_full_match=intent.regexp_full_match, + regexp_partial_match=intent.regexp_partial_match, + ) + for intent in intents_parsed + ] self._compile_regex_patterns() def predict(self, utterances: list[str]) -> list[LabelType]: diff --git a/tests/modules/test_regex.py b/tests/modules/test_regex.py index 8138651c..2a0c0b0a 100644 --- a/tests/modules/test_regex.py +++ b/tests/modules/test_regex.py @@ -1,59 +1,11 @@ -import pytest +from autointent.modules import RegExp +from tests.conftest import setup_environment -from autointent import Context -from autointent.context.data_handler import Dataset -from autointent.metrics import retrieval_hit_rate, scoring_roc_auc -from autointent.modules import RegExp, VectorDBModule +def test_base_regex(): + db_dir, dump_dir, logs_dir = setup_environment() -@pytest.mark.xfail(reason="Issues with intent_id") -def test_base_regex(setup_environment): - db_dir, dump_dir, logs_dir = setup_environment - - data = { - "utterances": [ - { - "text": "can i make a reservation for redrobin", - "label": 0, - }, - { - "text": "is it possible to make a reservation at redrobin", - "label": 0, - }, - { - "text": "does redrobin take reservations", - "label": 0, - }, - { - "text": "are reservations taken at redrobin", - "label": 0, - }, - { - "text": "does redrobin do reservations", - "label": 0, - }, - { - "text": "why is there a hold on my american saving bank account", - "label": 1, - }, - { - "text": "i am nost sure why my account is blocked", - "label": 1, - }, - { - "text": "why is there a hold on my capital one checking account", - "label": 1, - }, - { - "text": "i think my account is blocked but i do not know the reason", - "label": 1, - }, - { - "text": "can you tell me why is my bank account frozen", - "label": 1, - }, - ], - "intents": [ + train_data = [ { "id": 0, "name": "accept_reservations", @@ -66,35 +18,11 @@ def test_base_regex(setup_environment): "regexp_full_match": [".*"], "regexp_partial_match": [".*"], }, - ], - } - - context = Context( - dataset=Dataset.model_validate(data), - dump_dir=dump_dir, - db_dir=db_dir(), - ) - - retrieval_params = {"k": 3, "model_name": "sergeyzh/rubert-tiny-turbo"} - vector_db = VectorDBModule(**retrieval_params) - vector_db.fit(context) - metric_value = vector_db.score(context, retrieval_hit_rate) - artifact = vector_db.get_assets() - context.optimization_info.log_module_optimization( - node_type="retrieval", - module_type="vector_db", - module_params=retrieval_params, - metric_value=metric_value, - metric_name="retrieval_hit_rate_macro", - artifact=artifact, - module_dump_dir="", - ) + ] - scorer = RegExp() + matcher = RegExp() + matcher.fit(train_data) - scorer.fit(context) - score, _ = scorer.score(context, scoring_roc_auc) - assert score == 0.5 test_data = [ "why is there a hold on my american saving bank account", "i am nost sure why my account is blocked", @@ -102,5 +30,5 @@ def test_base_regex(setup_environment): "i think my account is blocked but i do not know the reason", "can you tell me why is my bank account frozen", ] - predictions = scorer.predict(test_data) + predictions = matcher.predict(test_data) assert predictions == [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]