Skip to content

Commit

Permalink
refactor cli and dump more metadata into logs
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Oct 1, 2024
1 parent 34ff962 commit f4807c8
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 7 deletions.
13 changes: 13 additions & 0 deletions autointent/context/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from .data_handler import DataHandler
from .optimization_info import OptimizationInfo
from .vector_index import VectorIndex
Expand Down Expand Up @@ -36,3 +38,14 @@ def __init__(
def get_best_collection(self):
model_name = self.optimization_info.get_best_embedder()
return self.vector_index.get_collection(model_name)

def dump(self) -> dict[str, Any]:
optim_logs = self.optimization_info.dump()
optim_logs["metadata"] = {
"device": self.device,
"multilabel": self.multilabel,
"n_classes": self.n_classes,
"seed": self.seed,
"db_dir": self.vector_index.db_dir
}
return optim_logs
4 changes: 2 additions & 2 deletions autointent/context/vector_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os

from chromadb import PersistentClient
from chromadb.config import Settings
Expand All @@ -9,9 +8,10 @@


class VectorIndex:
def __init__(self, db_dir: os.PathLike, device: str, multilabel: bool, n_classes: int):
def __init__(self, db_dir: str, device: str, multilabel: bool, n_classes: int):
self._logger = logging.getLogger(__name__)

self.db_dir = db_dir
self.device = device
self.multilabel = multilabel
self.n_classes = n_classes
Expand Down
Empty file.
File renamed without changes.
2 changes: 1 addition & 1 deletion autointent/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def optimize(self, context: Context):

def dump(self, logs_dir: str, run_name: str):
self._logger.debug("dumping logs...")
optimization_results = self.context.optimization_info.dump()
optimization_results = self.context.dump()

# create appropriate directory
logs_dir = Path.cwd() if logs_dir == "" else Path(logs_dir)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ optional = true
ruff = "^0.6.8"

[tool.poetry.scripts]
"autointent" = "autointent.pipeline.main:main"
"autointent" = "autointent.pipeline.optimization:main"
"autointent-inference" = "autointent.pipeline.inference:main"
"clear-cache" = "autointent.pipeline.utils.cache:clear_chroma_cache"

[tool.ruff]
Expand Down
2 changes: 1 addition & 1 deletion tests/minimal-optimization/test_minimal_optimization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from autointent import Context
from autointent.pipeline import Pipeline
from autointent.pipeline.main import get_db_dir, get_run_name, load_data, setup_logging
from autointent.pipeline.optimization import get_db_dir, get_run_name, load_data, setup_logging


def test_multiclass():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modules/test_scoring/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from autointent import Context
from autointent.metrics import retrieval_hit_rate, scoring_roc_auc
from autointent.modules import KNNScorer, VectorDBModule
from autointent.pipeline.main import get_db_dir, get_run_name, load_data, setup_logging
from autointent.pipeline.optimization import get_db_dir, get_run_name, load_data, setup_logging


def test_base_knn():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modules/test_scoring/test_mlknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from autointent.metrics import retrieval_hit_rate_macro, scoring_f1
from autointent.modules import VectorDBModule
from autointent.modules.scoring.mlknn.mlknn import MLKnnScorer
from autointent.pipeline.main import get_db_dir, get_run_name, load_data, setup_logging
from autointent.pipeline.optimization import get_db_dir, get_run_name, load_data, setup_logging


def test_base_mlknn():
Expand Down

0 comments on commit f4807c8

Please sign in to comment.