From f4807c840a49dd2a150ed400a579627c0c96f48a Mon Sep 17 00:00:00 2001
From: voorhs
Date: Tue, 1 Oct 2024 17:52:27 +0300
Subject: [PATCH] refactor cli and dump more metadata into logs

---
NOTE(review): autointent/pipeline/inference.py is created empty by this patch
(blob e69de29, 0 lines changed), while pyproject.toml already registers the
"autointent-inference" script as autointent.pipeline.inference:main. Until a
follow-up adds `main` to that module, the new entry point cannot be resolved.

 autointent/context/context.py                    | 13 +++++++++++++
 autointent/context/vector_index.py               |  4 ++--
 autointent/pipeline/inference.py                 |  0
 autointent/pipeline/{main.py => optimization.py} |  0
 autointent/pipeline/pipeline.py                  |  2 +-
 pyproject.toml                                   |  3 ++-
 .../test_minimal_optimization.py                 |  2 +-
 tests/test_modules/test_scoring/test_knn.py      |  2 +-
 tests/test_modules/test_scoring/test_mlknn.py    |  2 +-
 9 files changed, 21 insertions(+), 7 deletions(-)
 create mode 100644 autointent/pipeline/inference.py
 rename autointent/pipeline/{main.py => optimization.py} (100%)

diff --git a/autointent/context/context.py b/autointent/context/context.py
index a8d8a65..0e6d6e3 100644
--- a/autointent/context/context.py
+++ b/autointent/context/context.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from .data_handler import DataHandler
 from .optimization_info import OptimizationInfo
 from .vector_index import VectorIndex
@@ -36,3 +38,14 @@ def __init__(
     def get_best_collection(self):
         model_name = self.optimization_info.get_best_embedder()
         return self.vector_index.get_collection(model_name)
+
+    def dump(self) -> dict[str, Any]:
+        optim_logs = self.optimization_info.dump()
+        optim_logs["metadata"] = {
+            "device": self.device,
+            "multilabel": self.multilabel,
+            "n_classes": self.n_classes,
+            "seed": self.seed,
+            "db_dir": self.vector_index.db_dir
+        }
+        return optim_logs
diff --git a/autointent/context/vector_index.py b/autointent/context/vector_index.py
index 1739e8f..a86d651 100644
--- a/autointent/context/vector_index.py
+++ b/autointent/context/vector_index.py
@@ -1,5 +1,4 @@
 import logging
-import os
 
 from chromadb import PersistentClient
 from chromadb.config import Settings
@@ -9,9 +8,10 @@
 
 
 class VectorIndex:
-    def __init__(self, db_dir: os.PathLike, device: str, multilabel: bool, n_classes: int):
+    def __init__(self, db_dir: str, device: str, multilabel: bool, n_classes: int):
         self._logger = logging.getLogger(__name__)
 
+        self.db_dir = db_dir
         self.device = device
         self.multilabel = multilabel
         self.n_classes = n_classes
diff --git a/autointent/pipeline/inference.py b/autointent/pipeline/inference.py
new file mode 100644
index 0000000..e69de29
diff --git a/autointent/pipeline/main.py b/autointent/pipeline/optimization.py
similarity index 100%
rename from autointent/pipeline/main.py
rename to autointent/pipeline/optimization.py
diff --git a/autointent/pipeline/pipeline.py b/autointent/pipeline/pipeline.py
index ac50f59..14373d8 100644
--- a/autointent/pipeline/pipeline.py
+++ b/autointent/pipeline/pipeline.py
@@ -42,7 +42,7 @@ def optimize(self, context: Context):
 
     def dump(self, logs_dir: str, run_name: str):
         self._logger.debug("dumping logs...")
-        optimization_results = self.context.optimization_info.dump()
+        optimization_results = self.context.dump()
 
         # create appropriate directory
         logs_dir = Path.cwd() if logs_dir == "" else Path(logs_dir)
diff --git a/pyproject.toml b/pyproject.toml
index 9705f55..fc9df85 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,8 @@ optional = true
 ruff = "^0.6.8"
 
 [tool.poetry.scripts]
-"autointent" = "autointent.pipeline.main:main"
+"autointent" = "autointent.pipeline.optimization:main"
+"autointent-inference" = "autointent.pipeline.inference:main"
 "clear-cache" = "autointent.pipeline.utils.cache:clear_chroma_cache"
 
 [tool.ruff]
diff --git a/tests/minimal-optimization/test_minimal_optimization.py b/tests/minimal-optimization/test_minimal_optimization.py
index 2a342eb..931c301 100644
--- a/tests/minimal-optimization/test_minimal_optimization.py
+++ b/tests/minimal-optimization/test_minimal_optimization.py
@@ -1,6 +1,6 @@
 from autointent import Context
 from autointent.pipeline import Pipeline
-from autointent.pipeline.main import get_db_dir, get_run_name, load_data, setup_logging
+from autointent.pipeline.optimization import get_db_dir, get_run_name, load_data, setup_logging
 
 
 def test_multiclass():
diff --git a/tests/test_modules/test_scoring/test_knn.py b/tests/test_modules/test_scoring/test_knn.py
index 2f76c89..129fd17 100644
--- a/tests/test_modules/test_scoring/test_knn.py
+++ b/tests/test_modules/test_scoring/test_knn.py
@@ -3,7 +3,7 @@
 from autointent import Context
 from autointent.metrics import retrieval_hit_rate, scoring_roc_auc
 from autointent.modules import KNNScorer, VectorDBModule
-from autointent.pipeline.main import get_db_dir, get_run_name, load_data, setup_logging
+from autointent.pipeline.optimization import get_db_dir, get_run_name, load_data, setup_logging
 
 
 def test_base_knn():
diff --git a/tests/test_modules/test_scoring/test_mlknn.py b/tests/test_modules/test_scoring/test_mlknn.py
index 9773f2f..7bfc690 100644
--- a/tests/test_modules/test_scoring/test_mlknn.py
+++ b/tests/test_modules/test_scoring/test_mlknn.py
@@ -4,7 +4,7 @@
 from autointent.metrics import retrieval_hit_rate_macro, scoring_f1
 from autointent.modules import VectorDBModule
 from autointent.modules.scoring.mlknn.mlknn import MLKnnScorer
-from autointent.pipeline.main import get_db_dir, get_run_name, load_data, setup_logging
+from autointent.pipeline.optimization import get_db_dir, get_run_name, load_data, setup_logging
 
 
 def test_base_mlknn():