Skip to content

Commit

Permalink
small test and minor change
Browse files — browse the repository at this point in the history
  • Loading branch information
voorhs committed Nov 11, 2024
1 parent 4e41978 commit 7dbdaa0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 85 deletions.
15 changes: 11 additions & 4 deletions autointent/modules/regexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from autointent import Context
from autointent.context.data_handler.data_handler import RegexPatterns
from autointent.context.data_handler.schemas import Intent
from autointent.context.optimization_info.data_models import Artifact
from autointent.custom_types import LabelType
from autointent.metrics.regexp import RegexpMetricFn
Expand All @@ -21,14 +22,20 @@ class RegexPatternsCompiled(TypedDict):


class RegExp(Module):
regexp_patterns: list[RegexPatterns]
regexp_patterns_compiled: list[RegexPatternsCompiled]

@classmethod
def from_context(cls, context: Context) -> Self:
return cls()

def fit(self) -> None:
def fit(self, intents: list[dict[str, Any]]) -> None:
    """Store regex patterns for the given intents and compile them.

    Every dict in ``intents`` is unpacked as keyword arguments into
    ``Intent``. The ``id``, ``regexp_full_match`` and
    ``regexp_partial_match`` fields of each parsed intent are then copied
    into a ``RegexPatterns`` entry, and the resulting list is stored on
    ``self.regexp_patterns`` before ``self._compile_regex_patterns()`` is
    called.

    :param intents: raw intent descriptions, one dict per intent.
    """
    # Phase 1: parse every raw dict before any pattern entry is built.
    parsed_intents = []
    for raw_intent in intents:
        parsed_intents.append(Intent(**raw_intent))

    # Phase 2: keep only the fields relevant to regex matching.
    collected = []
    for parsed in parsed_intents:
        entry = RegexPatterns(
            id=parsed.id,
            regexp_full_match=parsed.regexp_full_match,
            regexp_partial_match=parsed.regexp_partial_match,
        )
        collected.append(entry)

    self.regexp_patterns = collected
    self._compile_regex_patterns()

def predict(self, utterances: list[str]) -> list[LabelType]:
Expand Down
90 changes: 9 additions & 81 deletions tests/modules/test_regex.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,11 @@
import pytest
from autointent.modules import RegExp
from tests.conftest import setup_environment

from autointent import Context
from autointent.context.data_handler import Dataset
from autointent.metrics import retrieval_hit_rate, scoring_roc_auc
from autointent.modules import RegExp, VectorDBModule

def test_base_regex():
db_dir, dump_dir, logs_dir = setup_environment()

@pytest.mark.xfail(reason="Issues with intent_id")
def test_base_regex(setup_environment):
db_dir, dump_dir, logs_dir = setup_environment

data = {
"utterances": [
{
"text": "can i make a reservation for redrobin",
"label": 0,
},
{
"text": "is it possible to make a reservation at redrobin",
"label": 0,
},
{
"text": "does redrobin take reservations",
"label": 0,
},
{
"text": "are reservations taken at redrobin",
"label": 0,
},
{
"text": "does redrobin do reservations",
"label": 0,
},
{
"text": "why is there a hold on my american saving bank account",
"label": 1,
},
{
"text": "i am nost sure why my account is blocked",
"label": 1,
},
{
"text": "why is there a hold on my capital one checking account",
"label": 1,
},
{
"text": "i think my account is blocked but i do not know the reason",
"label": 1,
},
{
"text": "can you tell me why is my bank account frozen",
"label": 1,
},
],
"intents": [
train_data = [
{
"id": 0,
"name": "accept_reservations",
Expand All @@ -66,41 +18,17 @@ def test_base_regex(setup_environment):
"regexp_full_match": [".*"],
"regexp_partial_match": [".*"],
},
],
}

context = Context(
dataset=Dataset.model_validate(data),
dump_dir=dump_dir,
db_dir=db_dir(),
)

retrieval_params = {"k": 3, "model_name": "sergeyzh/rubert-tiny-turbo"}
vector_db = VectorDBModule(**retrieval_params)
vector_db.fit(context)
metric_value = vector_db.score(context, retrieval_hit_rate)
artifact = vector_db.get_assets()
context.optimization_info.log_module_optimization(
node_type="retrieval",
module_type="vector_db",
module_params=retrieval_params,
metric_value=metric_value,
metric_name="retrieval_hit_rate_macro",
artifact=artifact,
module_dump_dir="",
)
]

scorer = RegExp()
matcher = RegExp()
matcher.fit(train_data)

scorer.fit(context)
score, _ = scorer.score(context, scoring_roc_auc)
assert score == 0.5
test_data = [
"why is there a hold on my american saving bank account",
"i am nost sure why my account is blocked",
"why is there a hold on my capital one checking account",
"i think my account is blocked but i do not know the reason",
"can you tell me why is my bank account frozen",
]
predictions = scorer.predict(test_data)
predictions = matcher.predict(test_data)
assert predictions == [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]

0 comments on commit 7dbdaa0

Please sign in to comment.