Add enhanced output (#38)
Co-authored-by: voorhs <[email protected]>
Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: Darinka <[email protected]>
4 people authored Nov 12, 2024
1 parent ad097e8 commit de779c0
Showing 29 changed files with 694 additions and 255 deletions.
1 change: 1 addition & 0 deletions autointent/configs/inference_cli.py
@@ -11,6 +11,7 @@ class InferenceConfig:
     source_dir: str
     output_path: str
     log_level: LogLevel = LogLevel.ERROR
+    with_metadata: bool = False
 
 
 cs = ConfigStore.instance()
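The new `with_metadata` flag defaults to `False`, so existing configurations keep their current output; setting it to `True` asks the inference runner to attach the per-utterance metadata that the module changes below produce (matched regexes, retrieved neighbors, cross-encoder scores). A minimal sketch, assuming `InferenceConfig` can be constructed directly as a dataclass; the paths are placeholders:

```python
from autointent.configs.inference_cli import InferenceConfig

# Placeholder paths; with_metadata=True requests enriched predictions.
config = InferenceConfig(
    source_dir="runs/best_pipeline",
    output_path="predictions.json",
    with_metadata=True,
)
```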
5 changes: 4 additions & 1 deletion autointent/context/embedder.py
@@ -81,5 +81,8 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         if self.max_length is not None:
             self.embedding_model.max_seq_length = self.max_length
         return self.embedding_model.encode(
-            utterances, convert_to_numpy=True, batch_size=self.batch_size, normalize_embeddings=True,
+            utterances,
+            convert_to_numpy=True,
+            batch_size=self.batch_size,
+            normalize_embeddings=True,
         )  # type: ignore[return-value]
7 changes: 7 additions & 0 deletions autointent/modules/base.py
@@ -48,6 +48,13 @@ def load(self, path: str) -> None:
     def predict(self, *args: list[str] | npt.NDArray[Any], **kwargs: dict[str, Any]) -> npt.NDArray[Any]:
         """inference"""
 
+    def predict_with_metadata(
+        self,
+        *args: list[str] | npt.NDArray[Any],
+        **kwargs: dict[str, Any],
+    ) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
+        return self.predict(*args, **kwargs), None
+
     @classmethod
     @abstractmethod
     def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> Self:
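The base `Module` now ships a default `predict_with_metadata` that wraps `predict` and returns `None` for the metadata, so subclasses with nothing extra to report need no changes and callers can request metadata unconditionally. A hedged usage sketch, where `module` stands for any fitted `Module` subclass and the utterance is illustrative:

```python
scores, metadata = module.predict_with_metadata(["turn on the lights"])

if metadata is None:
    print("module provides no extra metadata")  # default behaviour
else:
    print(metadata[0])  # one module-specific dict per utterance
```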
108 changes: 75 additions & 33 deletions autointent/modules/regexp.py
@@ -1,10 +1,13 @@
+import json
 import re
-from typing import TypedDict
+from pathlib import Path
+from typing import Any, TypedDict
 
 from typing_extensions import Self
 
 from autointent import Context
 from autointent.context.data_handler.data_handler import RegexPatterns
+from autointent.context.data_handler.schemas import Intent
 from autointent.context.optimization_info.data_models import Artifact
 from autointent.custom_types import LabelType
 from autointent.metrics.regexp import RegexpMetricFn
@@ -19,43 +22,60 @@ class RegexPatternsCompiled(TypedDict):
 
 
 class RegExp(Module):
     name = "regexp"
 
-    def __init__(self, regexp_patterns: list[RegexPatterns]) -> None:
-        self.regexp_patterns = regexp_patterns
-
     @classmethod
     def from_context(cls, context: Context) -> Self:
-        return cls(
-            regexp_patterns=context.data_handler.regexp_patterns,
-        )
-
-    def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
-        self.regexp_patterns_compiled: list[RegexPatternsCompiled] = [
-            {
-                "id": dct["id"],
-                "regexp_full_match": [re.compile(ptn, flags=re.IGNORECASE) for ptn in dct["regexp_full_match"]],
-                "regexp_partial_match": [re.compile(ptn, flags=re.IGNORECASE) for ptn in dct["regexp_partial_match"]],
-            }
-            for dct in self.regexp_patterns
+        return cls()
+
+    def fit(self, intents: list[dict[str, Any]]) -> None:
+        intents_parsed = [Intent(**dct) for dct in intents]
+        self.regexp_patterns = [
+            RegexPatterns(
+                id=intent.id,
+                regexp_full_match=intent.regexp_full_match,
+                regexp_partial_match=intent.regexp_partial_match,
+            )
+            for intent in intents_parsed
         ]
+        self._compile_regex_patterns()
 
     def predict(self, utterances: list[str]) -> list[LabelType]:
-        return [list(self._predict_single(ut)) for ut in utterances]
-
-    def _match(self, text: str, intent_record: RegexPatternsCompiled) -> bool:
-        full_match = any(ptn.fullmatch(text) for ptn in intent_record["regexp_full_match"])
-        if full_match:
-            return True
-        return any(ptn.match(text) for ptn in intent_record["regexp_partial_match"])
+        return [self._predict_single(utterance)[0] for utterance in utterances]
+
+    def predict_with_metadata(
+        self,
+        utterances: list[str],
+    ) -> tuple[list[LabelType], list[dict[str, Any]] | None]:
+        predictions, metadata = [], []
+        for utterance in utterances:
+            prediction, matches = self._predict_single(utterance)
+            predictions.append(prediction)
+            metadata.append(matches)
+        return predictions, metadata
+
+    def _match(self, utterance: str, intent_record: RegexPatternsCompiled) -> dict[str, list[str]]:
+        full_matches = [
+            pattern.pattern
+            for pattern in intent_record["regexp_full_match"]
+            if pattern.fullmatch(utterance) is not None
+        ]
+        partial_matches = [
+            pattern.pattern
+            for pattern in intent_record["regexp_partial_match"]
+            if pattern.search(utterance) is not None
+        ]
+        return {"full_matches": full_matches, "partial_matches": partial_matches}
 
-    def _predict_single(self, utterance: str) -> set[int]:
+    def _predict_single(self, utterance: str) -> tuple[LabelType, dict[str, list[str]]]:
         # todo test this
-        return {
-            intent_record["id"]
-            for intent_record in self.regexp_patterns_compiled
-            if self._match(utterance, intent_record)
-        }
+        prediction = set()
+        matches: dict[str, list[str]] = {"full_matches": [], "partial_matches": []}
+        for intent_record in self.regexp_patterns_compiled:
+            intent_matches = self._match(utterance, intent_record)
+            if intent_matches["full_matches"] or intent_matches["partial_matches"]:
+                prediction.add(intent_record["id"])
+                matches["full_matches"].extend(intent_matches["full_matches"])
+                matches["partial_matches"].extend(intent_matches["partial_matches"])
+        return list(prediction), matches
 
     def score(self, context: Context, metric_fn: RegexpMetricFn) -> float:
         # TODO add parameter to a whole pipeline (or just to regexp module):
@@ -78,7 +98,29 @@ def get_assets(self) -> Artifact:
         return Artifact()
 
     def load(self, path: str) -> None:
-        pass
+        dump_dir = Path(path)
+
+        with (dump_dir / self.metadata_dict_name).open() as file:
+            self.regexp_patterns = json.load(file)
+
+        self._compile_regex_patterns()
 
     def dump(self, path: str) -> None:
-        pass
+        dump_dir = Path(path)
+
+        with (dump_dir / self.metadata_dict_name).open("w") as file:
+            json.dump(self.regexp_patterns, file, indent=4)
+
+    def _compile_regex_patterns(self) -> None:
+        self.regexp_patterns_compiled = [
+            RegexPatternsCompiled(
+                id=regexp_patterns["id"],
+                regexp_full_match=[
+                    re.compile(pattern, flags=re.IGNORECASE) for pattern in regexp_patterns["regexp_full_match"]
+                ],
+                regexp_partial_match=[
+                    re.compile(pattern, flags=re.IGNORECASE) for pattern in regexp_patterns["regexp_partial_match"]
+                ],
+            )
+            for regexp_patterns in self.regexp_patterns
+        ]
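With this refactor, `RegExp.fit` consumes raw intent dicts (validated through `Intent`) rather than pre-extracted patterns, and `predict_with_metadata` reports exactly which patterns fired per utterance. A minimal sketch, assuming `Intent` accepts these fields and defaults the rest; the intents and expected outputs are illustrative:

```python
from autointent.modules.regexp import RegExp

intents = [
    {"id": 0, "regexp_full_match": [r"hello( there)?"], "regexp_partial_match": [r"\bhi\b"]},
    {"id": 1, "regexp_full_match": [r"goodbye"], "regexp_partial_match": [r"\bbye\b"]},
]

module = RegExp()
module.fit(intents)

predictions, metadata = module.predict_with_metadata(["hello there", "bye for now"])
# Expected, roughly: predictions == [[0], [1]]
# metadata[0] == {"full_matches": ["hello( there)?"], "partial_matches": []}
```

Note that `dump` and `load` now round-trip the uncompiled patterns through JSON and recompile them on load, so a persisted module keeps its regexes.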
10 changes: 7 additions & 3 deletions autointent/modules/retrieval/vectordb.py
@@ -32,11 +32,9 @@ def __init__(
         batch_size: int = 32,
         max_length: int | None = None,
     ) -> None:
-        if db_dir is None:
-            db_dir = str(get_db_dir())
         self.embedder_name = embedder_name
         self.device = device
-        self.db_dir = db_dir
+        self._db_dir = db_dir
         self.batch_size = batch_size
         self.max_length = max_length
 
@@ -58,6 +56,12 @@ def from_context(
             max_length=context.get_max_length(),
         )
 
+    @property
+    def db_dir(self) -> str:
+        if self._db_dir is None:
+            self._db_dir = str(get_db_dir())
+        return self._db_dir
+
     def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
         vector_index_client = VectorIndexClient(
             self.device, self.db_dir, embedder_batch_size=self.batch_size, embedder_max_length=self.max_length
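The eager `if db_dir is None` fallback in `__init__` becomes a lazy `db_dir` property, so `get_db_dir()` runs only when the index is first used rather than at construction time. The pattern in isolation, as a generic sketch with `mkdtemp` standing in for autointent's `get_db_dir()`:

```python
from tempfile import mkdtemp


class LazyDbDir:
    """Generic sketch of the lazy default-directory pattern used above."""

    def __init__(self, db_dir: str | None = None) -> None:
        self._db_dir = db_dir  # resolution deferred until first access

    @property
    def db_dir(self) -> str:
        if self._db_dir is None:  # resolve the default exactly once
            self._db_dir = mkdtemp(prefix="vector_db_")
        return self._db_dir
```

The same property is added to the DNNC and KNN modules below, keeping construction cheap and deferring filesystem work to `fit`/`predict`.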
8 changes: 2 additions & 6 deletions autointent/modules/scoring/description/description.py
@@ -11,7 +11,6 @@
 from autointent.context import Context
 from autointent.context.embedder import Embedder
 from autointent.context.vector_index_client import VectorIndex, VectorIndexClient
-from autointent.context.vector_index_client.cache import get_db_dir
 from autointent.custom_types import LabelType
 from autointent.modules.scoring.base import ScoringModule
 
@@ -30,22 +29,19 @@ class DescriptionScorer(ScoringModule):
     precomputed_embeddings: bool = False
     embedding_model_subdir: str = "embedding_model"
     _vector_index: VectorIndex
+    db_dir: str
     name = "description"
 
     def __init__(
         self,
         embedder_name: str,
-        db_dir: Path | None = None,
         temperature: float = 1.0,
         device: str = "cpu",
         batch_size: int = 32,
         max_length: int | None = None,
     ) -> None:
-        if db_dir is None:
-            db_dir = get_db_dir()
         self.temperature = temperature
         self.device = device
-        self.db_dir = db_dir
         self.embedder_name = embedder_name
         self.batch_size = batch_size
         self.max_length = max_length
@@ -66,10 +62,10 @@ def from_context(
         instance = cls(
             temperature=temperature,
             device=context.get_device(),
-            db_dir=context.get_db_dir(),
             embedder_name=embedder_name,
         )
         instance.precomputed_embeddings = precomputed_embeddings
+        instance.db_dir = str(context.get_db_dir())
         return instance
 
     def get_embedder_name(self) -> str:
47 changes: 31 additions & 16 deletions autointent/modules/scoring/dnnc/dnnc.py
@@ -52,18 +52,21 @@ def __init__(
         batch_size: int = 32,
         max_length: int | None = None,
     ) -> None:
-        if db_dir is None:
-            db_dir = str(get_db_dir())
-
         self.cross_encoder_name = cross_encoder_name
         self.embedder_name = embedder_name
         self.k = k
         self.train_head = train_head
         self.device = device
-        self.db_dir = db_dir
+        self._db_dir = db_dir
         self.batch_size = batch_size
         self.max_length = max_length
 
+    @property
+    def db_dir(self) -> str:
+        if self._db_dir is None:
+            self._db_dir = str(get_db_dir())
+        return self._db_dir
+
     @classmethod
     def from_context(
         cls,
@@ -114,19 +117,18 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
         self.model = model
 
     def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
-        """
-        Return
-        ---
-        `(n_queries, n_classes)` matrix with zeros everywhere except the class of the best neighbor utterance
-        """
-        labels, _, texts = self.vector_index.query(
-            utterances,
-            self.k,
-        )
+        return self._predict(utterances)[0]
 
-        cross_encoder_scores = self._get_cross_encoder_scores(utterances, texts)
-
-        return self._build_result(cross_encoder_scores, labels)
+    def predict_with_metadata(
+        self,
+        utterances: list[str],
+    ) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
+        scores, neighbors, neighbors_scores = self._predict(utterances)
+        metadata = [
+            {"neighbors": utterance_neighbors, "scores": utterance_neighbors_scores}
+            for utterance_neighbors, utterance_neighbors_scores in zip(neighbors, neighbors_scores, strict=True)
+        ]
+        return scores, metadata
 
     def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list[str]]) -> list[list[float]]:
         """
@@ -215,6 +217,19 @@ def load(self, path: str) -> None:
         else:
             self.model = CrossEncoder(crossencoder_dir, device=self.device)
 
+    def _predict(
+        self,
+        utterances: list[str],
+    ) -> tuple[npt.NDArray[Any], list[list[str]], list[list[float]]]:
+        labels, _, neighbors = self.vector_index.query(
+            utterances,
+            self.k,
+        )
+
+        cross_encoder_scores = self._get_cross_encoder_scores(utterances, neighbors)
+
+        return self._build_result(cross_encoder_scores, labels), neighbors, cross_encoder_scores
+
 
 def build_result(scores: npt.NDArray[Any], labels: npt.NDArray[Any], n_classes: int) -> npt.NDArray[Any]:
     res = np.zeros((len(scores), n_classes))
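`predict` is now a thin wrapper over a shared `_predict`, and `predict_with_metadata` additionally surfaces the retrieved neighbor texts with their cross-encoder scores. An illustrative sketch, assuming `scorer` is a fitted instance of this module; the outputs are assumptions, not real model output:

```python
scores, metadata = scorer.predict_with_metadata(["how do I reset my password?"])

# scores: (n_queries, n_classes), nonzero only at the best neighbor's class
# (per the docstring removed above).
# metadata[0]["neighbors"]: the k training utterances retrieved for query 0
# metadata[0]["scores"]: the cross-encoder score for each of those neighbors
```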
26 changes: 21 additions & 5 deletions autointent/modules/scoring/knn/knn.py
@@ -49,16 +49,20 @@ def __init__(
             - closest: each sample has a non-zero weight iff it is the closest sample of some class
         - `device`: str, something like "cuda:0" or "cuda:0,1,2", a device to store embedding function
         """
-        if db_dir is None:
-            db_dir = str(get_db_dir())
         self.embedder_name = embedder_name
         self.k = k
         self.weights = weights
-        self.db_dir = db_dir
+        self._db_dir = db_dir
         self.device = device
         self.batch_size = batch_size
         self.max_length = max_length
 
+    @property
+    def db_dir(self) -> str:
+        if self._db_dir is None:
+            self._db_dir = str(get_db_dir())
+        return self._db_dir
+
     @classmethod
     def from_context(
         cls,
@@ -107,8 +111,15 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
         self._vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
 
     def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
-        labels, distances, _ = self._vector_index.query(utterances, self.k)
-        return apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel)
+        return self._predict(utterances)[0]
+
+    def predict_with_metadata(
+        self,
+        utterances: list[str],
+    ) -> tuple[npt.NDArray[Any], list[dict[str, Any]] | None]:
+        scores, neighbors = self._predict(utterances)
+        metadata = [{"neighbors": utterance_neighbors} for utterance_neighbors in neighbors]
+        return scores, metadata
 
     def clear_cache(self) -> None:
         self._vector_index.clear_ram()
@@ -145,3 +156,8 @@ def load(self, path: str) -> None:
             embedder_max_length=self.metadata["max_length"],
         )
         self._vector_index = vector_index_client.get_index(self.embedder_name)
+
+    def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
+        labels, distances, neighbors = self._vector_index.query(utterances, self.k)
+        scores = apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel)
+        return scores, neighbors
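KNN gets the same split: `_predict` returns the weighted class scores together with the raw neighbor texts, and `predict_with_metadata` packs the neighbors into one dict per query. A sketch under the same assumptions (a fitted `scorer`; illustrative output):

```python
scores, metadata = scorer.predict_with_metadata(["where is my order?"])

# scores: class scores from apply_weights, shape (1, n_classes) here
# metadata[0]["neighbors"]: the k nearest training utterances for the query
```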
[Diffs for the remaining 21 changed files are not shown in this capture.]