Fix/refactor optimization logs #12

Merged on Oct 1, 2024 (5 commits)
Changes from 2 commits
6 changes: 3 additions & 3 deletions autointent/context/context.py
@@ -1,5 +1,5 @@
 from .data_handler import DataHandler
-from .optimization_logs import OptimizationLogs
+from .optimization_info import OptimizationInfo
 from .vector_index import VectorIndex


@@ -25,7 +25,7 @@ def __init__(
             regex_sampling,
             seed,
         )
-        self.optimization_logs = OptimizationLogs()
+        self.optimization_info = OptimizationInfo()
         self.vector_index = VectorIndex(db_dir, device, self.data_handler.multilabel, self.data_handler.n_classes)

         self.device = device
@@ -34,5 +34,5 @@ def __init__(
         self.seed = seed

     def get_best_collection(self):
-        model_name = self.optimization_logs.get_best_embedder()
+        model_name = self.optimization_info.get_best_embedder()
         return self.vector_index.get_collection(model_name)
2 changes: 2 additions & 0 deletions autointent/context/optimization_info/__init__.py
@@ -0,0 +1,2 @@
from .data_models import RetrieverArtifact, ScorerArtifact
from .optimization_info import OptimizationInfo
91 changes: 91 additions & 0 deletions autointent/context/optimization_info/data_models.py
@@ -0,0 +1,91 @@
from typing import Any

import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel, ConfigDict, Field


class Artifact(BaseModel):
    ...


class RegexpArtifact(Artifact):
    ...


class RetrieverArtifact(Artifact):
    """
    Name of the embedding model chosen after retrieval optimization
    """

    embedder_name: str


class ScorerArtifact(Artifact):
    """
    Outputs from the best scorer: numpy arrays of shape (n_samples, n_classes)
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    test_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for test utterances")
    oos_scores: NDArray[np.float64] | None = Field(None, description="Scorer outputs for out-of-scope utterances")


class PredictorArtifact(Artifact):
    """
    Outputs from the best predictor: a numpy array of shape (n_samples,) or
    (n_samples, n_classes), depending on classification mode (multi-class or multi-label)
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    labels: NDArray[np.float64]


class Artifacts(BaseModel):
    """
    Modules' hyperparams and outputs. The best ones are passed between nodes of the pipeline
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    regexp: list[RegexpArtifact] = []
    retrieval: list[RetrieverArtifact] = []
    scoring: list[ScorerArtifact] = []
    prediction: list[PredictorArtifact] = []

    def __getitem__(self, node_type: str) -> list:
        return getattr(self, node_type)


class Trial(BaseModel):
    """
    Detailed representation of one optimization trial
    """

    module_type: str
    module_params: dict[str, Any]
    metric_name: str
    metric_value: float


class Trials(BaseModel):
    """
    Detailed representation of optimization results
    """

    regexp: list[Trial] = []
    retrieval: list[Trial] = []
    scoring: list[Trial] = []
    prediction: list[Trial] = []

    def __getitem__(self, node_type: str) -> list[Trial]:
        return getattr(self, node_type)


class TrialsIds(BaseModel):
    """
    Indices of the best trials for each node of the pipeline
    """

    regexp: int | None = None
    retrieval: int | None = None
    scoring: int | None = None
    prediction: int | None = None

    def __getitem__(self, node_type: str) -> int | None:
        return getattr(self, node_type)

    def __setitem__(self, node_type: str, idx: int) -> None:
        setattr(self, node_type, idx)
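
A quick sketch of how these containers are meant to be indexed: node_type strings double as attribute names, so dict-style access dispatches to the right per-node list. Everything below except the imported classes is invented for illustration (the embedder name, metric name, and module type are placeholders, not values from this PR):

from autointent.context.optimization_info import RetrieverArtifact
from autointent.context.optimization_info.data_models import Artifacts, Trial, Trials, TrialsIds

artifacts = Artifacts()
trials = Trials()
best_ids = TrialsIds()

# hypothetical embedder and metric, just to exercise the API
artifacts["retrieval"].append(RetrieverArtifact(embedder_name="my-embedder"))
trials["retrieval"].append(
    Trial(module_type="vector_db", module_params={"k": 10}, metric_name="retrieval_hit_rate", metric_value=0.93)
)
best_ids["retrieval"] = 0

assert artifacts["retrieval"][best_ids["retrieval"]].embedder_name == "my-embedder"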
104 changes: 104 additions & 0 deletions autointent/context/optimization_info/optimization_info.py
@@ -0,0 +1,104 @@
import logging
from pprint import pformat

import numpy as np

from .data_models import Artifact, Artifacts, NDArray, RetrieverArtifact, ScorerArtifact, Trial, Trials, TrialsIds


class OptimizationInfo:
    """TODO: continuous IO with the file system (to be able to restore the state of optimization)"""

    def __init__(self):
        self._logger = get_logger()

        self.artifacts = Artifacts()
        self.trials = Trials()
        self._trials_best_ids = TrialsIds()

    def log_module_optimization(
        self,
        node_type: str,
        module_type: str,
        module_params: dict,
        metric_value: float,
        metric_name: str,
        artifact: Artifact,
    ):
        """
        Purposes:
        - save optimization results in a text form (hyperparameters and corresponding metrics)
        - update best assets
        """

        # save trial
        trial = Trial(
            module_type=module_type,
            metric_name=metric_name,
            metric_value=metric_value,
            module_params=module_params,
        )
        self.trials[node_type].append(trial)
        self._logger.info(trial.model_dump())

        # save artifact
        self.artifacts[node_type].append(artifact)

    def _get_metrics_values(self, node_type: str) -> list[float]:
        return [trial.metric_value for trial in self.trials[node_type]]

    def _get_best_trial_idx(self, node_type: str) -> int:
        res = self._trials_best_ids[node_type]
        if res is not None:
            return res
        res = int(np.argmax(self._get_metrics_values(node_type)))  # cast np.intp to plain int
        self._trials_best_ids[node_type] = res
        return res

    def _get_best_artifact(self, node_type: str) -> Artifact:
        i_best = self._get_best_trial_idx(node_type)
        return self.artifacts[node_type][i_best]

    def get_best_embedder(self) -> str:
        best_retriever_artifact: RetrieverArtifact = self._get_best_artifact(node_type="retrieval")
        return best_retriever_artifact.embedder_name

    def get_best_test_scores(self) -> NDArray[np.float64]:
        best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type="scoring")
        return best_scorer_artifact.test_scores

    def get_best_oos_scores(self) -> NDArray[np.float64]:
        best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type="scoring")
        return best_scorer_artifact.oos_scores

    def dump(self):
        node_wise_metrics = {
            node_type: self._get_metrics_values(node_type)
            for node_type in ["regexp", "retrieval", "scoring", "prediction"]
        }
        return {
            "metrics": node_wise_metrics,
            "configs": self.trials.model_dump(),
        }


def get_logger() -> logging.Logger:
    logger = logging.getLogger(__name__)

    if not logger.handlers:  # avoid attaching duplicate handlers on repeated instantiation
        formatter = PPrintFormatter()
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    return logger


class PPrintFormatter(logging.Formatter):
    def __init__(self):
        super().__init__(fmt="{asctime} - {name} - {levelname} - {message}", style="{")

    def format(self, record):
        if isinstance(record.msg, dict):
            format_msg = "module scoring results:\n"
            dct_to_str = pformat(record.msg)
            record.msg = format_msg + dct_to_str
        return super().format(record)
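
A hedged end-to-end sketch of the new logging flow (module type, metric name, and model names are invented placeholders; only OptimizationInfo and RetrieverArtifact come from this PR):

from autointent.context.optimization_info import OptimizationInfo, RetrieverArtifact

info = OptimizationInfo()

# each configuration tried at the "retrieval" node logs one trial and one artifact
for model_name, metric in [("model-a", 0.81), ("model-b", 0.88)]:
    info.log_module_optimization(
        node_type="retrieval",
        module_type="vector_db",
        module_params={"model_name": model_name},
        metric_value=metric,
        metric_name="retrieval_hit_rate",
        artifact=RetrieverArtifact(embedder_name=model_name),
    )

# downstream consumers ask only for the winner; the argmax over metric values is cached in TrialsIds
assert info.get_best_embedder() == "model-b"
print(info.dump()["metrics"]["retrieval"])  # [0.81, 0.88]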
91 changes: 0 additions & 91 deletions autointent/context/optimization_logs.py

This file was deleted.

4 changes: 2 additions & 2 deletions autointent/modules/prediction/base.py
@@ -30,9 +30,9 @@ def clear_cache(self):

 def get_prediction_evaluation_data(context: Context):
     labels = context.data_handler.labels_test
-    scores = context.optimization_logs.get_best_test_scores()
+    scores = context.optimization_info.get_best_test_scores()

-    oos_scores = context.optimization_logs.get_best_oos_scores()
+    oos_scores = context.optimization_info.get_best_oos_scores()
     if oos_scores is not None:
         oos_labels = [[0] * context.n_classes] * len(oos_scores) if context.multilabel else [-1] * len(oos_scores)
         labels = np.concatenate([labels, oos_labels])
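
One subtlety worth spelling out: out-of-scope utterances receive an all-zeros multi-hot row in multilabel mode and the sentinel label -1 in multiclass mode, so they can be concatenated with the in-scope test labels. A tiny illustration with made-up values:

import numpy as np

n_classes, n_oos = 3, 2

multilabel_oos = [[0] * n_classes] * n_oos  # [[0, 0, 0], [0, 0, 0]]: no class is active
multiclass_oos = [-1] * n_oos               # [-1, -1]: reserved "no intent" label

labels = np.concatenate([np.array([0, 2, 1]), multiclass_oos])
print(labels)  # [ 0  2  1 -1 -1]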
7 changes: 4 additions & 3 deletions autointent/modules/prediction/tunable.py
@@ -5,7 +5,7 @@
 from optuna.trial import Trial
 from sklearn.metrics import f1_score

-from .base import Context, PredictionModule
+from .base import Context, PredictionModule, get_prediction_evaluation_data
 from .threshold import multiclass_predict, multilabel_predict


@@ -24,9 +24,10 @@ def fit(self, context: Context):
         )

         thresh_optimizer = ThreshOptimizer(n_classes=context.n_classes, multilabel=context.multilabel)
+        labels, scores = get_prediction_evaluation_data(context)
         thresh_optimizer.fit(
-            probas=context.optimization_logs.get_best_test_scores(),
-            labels=context.data_handler.labels_test,
+            probas=scores,
+            labels=labels,
             seed=context.seed,
             tags=self.tags,
         )
5 changes: 3 additions & 2 deletions autointent/modules/retrieval/vectordb.py
@@ -4,6 +4,7 @@
 from chromadb import Collection

 from autointent.context import Context
+from autointent.context.optimization_info import RetrieverArtifact
 from autointent.metrics import RetrievalMetricFn

 from .base import RetrievalModule
@@ -26,8 +27,8 @@ def score(self, context: Context, metric_fn: RetrievalMetricFn) -> tuple[float,
         )
         return metric_fn(context.data_handler.labels_test, labels_pred)

-    def get_assets(self):
-        return self.model_name
+    def get_assets(self) -> RetrieverArtifact:
+        return RetrieverArtifact(embedder_name=self.model_name)

     def clear_cache(self):
         model = self.collection._embedding_function._model  # noqa: SLF001
5 changes: 3 additions & 2 deletions autointent/modules/scoring/base.py
@@ -2,6 +2,7 @@

 import numpy as np

+from autointent.context.optimization_info import ScorerArtifact
 from autointent.metrics import ScoringMetricFn
 from autointent.modules.base import Context, Module
@@ -21,8 +22,8 @@ def score(self, context: Context, metric_fn: ScoringMetricFn) -> tuple[float, np
             self._oos_scores = self.predict(context.data_handler.oos_utterances)
         return res

-    def get_assets(self):
-        return {"test_scores": self._test_scores, "oos_scores": self._oos_scores}
+    def get_assets(self) -> ScorerArtifact:
+        return ScorerArtifact(test_scores=self._test_scores, oos_scores=self._oos_scores)

     @abstractmethod
     def predict(self, utterances: list[str]):
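
As a quick check of the new return type (the arrays below are dummies): ScorerArtifact accepts numpy arrays thanks to arbitrary_types_allowed, and oos_scores may stay None when the dataset has no out-of-scope utterances.

import numpy as np

from autointent.context.optimization_info import ScorerArtifact

artifact = ScorerArtifact(test_scores=np.random.rand(5, 3), oos_scores=None)
assert artifact.oos_scores is None
assert artifact.test_scores.shape == (5, 3)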
2 changes: 1 addition & 1 deletion autointent/nodes/base.py
@@ -45,7 +45,7 @@ def fit(self, context: Context):
             metric_value = module.score(context, self.metrics_available[self.metric_name])

             assets = module.get_assets()
-            context.optimization_logs.log_module_optimization(
+            context.optimization_info.log_module_optimization(
                 self.node_type,
                 module_type,
                 module_kwargs,