-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/pipeline simpler fitting #36
Changes from 9 commits
f3d7777
b192dc8
0f6f568
a118e27
0e3fe2a
c4e15d9
109d47e
2b0d371
d7c4066
d305bb5
d648849
e7d0fbd
975c8df
378e582
2b80888
8c2eaff
322340b
c487363
a2e4dea
c349f18
b26a878
7ccbca2
db7f207
11bb883
616ba32
11bdb6f
9545bbf
1096e04
d6884d4
9e55a65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,79 @@ | ||
import json | ||
import logging | ||
from dataclasses import asdict | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import yaml | ||
|
||
from autointent.configs.optimization_cli import ( | ||
AugmentationConfig, | ||
DataConfig, | ||
EmbedderConfig, | ||
LoggingConfig, | ||
VectorIndexConfig, | ||
) | ||
|
||
from .data_handler import DataAugmenter, DataHandler, Dataset | ||
from .optimization_info import OptimizationInfo | ||
from .utils import NumpyEncoder, load_data | ||
from .vector_index_client import VectorIndex, VectorIndexClient | ||
|
||
|
||
class Context: | ||
def __init__( # noqa: PLR0913 | ||
data_handler: DataHandler | ||
vector_index_client: VectorIndexClient | ||
optimization_info: OptimizationInfo | ||
|
||
def __init__( | ||
self, | ||
dataset: Dataset, | ||
test_dataset: Dataset | None = None, | ||
device: str = "cpu", | ||
multilabel_generation_config: str | None = None, | ||
regex_sampling: int = 0, | ||
seed: int = 42, | ||
db_dir: str | Path | None = None, | ||
dump_dir: str | Path | None = None, | ||
force_multilabel: bool = False, | ||
embedder_batch_size: int = 32, | ||
embedder_max_length: int | None = None, | ||
) -> None: | ||
augmenter = DataAugmenter(multilabel_generation_config, regex_sampling, seed) | ||
self.seed = seed | ||
self._logger = logging.getLogger(__name__) | ||
|
||
def config_logs(self, config: LoggingConfig) -> None:
    """Remember the logging settings and create an empty optimization record.

    :param config: logging configuration, stored as ``self.logging_config``
    """
    # Stored first so that later steps (e.g. ``dump``) can read the target directory.
    self.logging_config = config

    # A new OptimizationInfo is created every time logging is (re)configured.
    self.optimization_info = OptimizationInfo()
|
||
def config_vector_index(self, config: VectorIndexConfig, embedder_config: EmbedderConfig | None = None) -> None:
    """Store vector index / embedder settings and build the vector index client.

    :param config: vector index configuration (``device`` and ``db_dir`` are read from it)
    :param embedder_config: embedder configuration (``batch_size`` and ``max_length`` are
        read from it); a default ``EmbedderConfig()`` is used when omitted
    """
    self.vector_index_config = config
    self.embedder_config = EmbedderConfig() if embedder_config is None else embedder_config

    index_cfg = self.vector_index_config
    embedder_cfg = self.embedder_config
    self.vector_index_client = VectorIndexClient(
        index_cfg.device,
        index_cfg.db_dir,
        embedder_cfg.batch_size,
        embedder_cfg.max_length,
    )
|
||
def config_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
    """Load the datasets referenced by ``config`` and build the data handler.

    :param config: data configuration; ``train_path``, ``test_path`` and
        ``force_multilabel`` are read from it
    :param augmentation_config: optional augmentation settings; when given, a
        ``DataAugmenter`` is built from them, otherwise no augmenter is used
    """
    if augmentation_config is not None:
        # Keep the settings the caller passed in. Previously a fresh default
        # AugmentationConfig() was stored here, silently discarding the argument.
        self.augmentation_config = augmentation_config
        augmenter = DataAugmenter(
            augmentation_config.multilabel_generation_config,
            augmentation_config.regex_sampling,
            self.seed,
        )
    else:
        augmenter = None

    self.data_handler = DataHandler(
        dataset=load_data(config.train_path),
        test_dataset=None if config.test_path is None else load_data(config.test_path),
        random_seed=self.seed,
        force_multilabel=config.force_multilabel,
        augmenter=augmenter,
    )
|
||
def set_datasets(
    self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
) -> None:
    """Build the data handler from datasets that are already loaded in memory.

    :param train_data: passed to the handler as ``dataset``
    :param val_data: passed to the handler as ``test_dataset``; may be omitted
    :param force_multilabel: forwarded unchanged to the handler
    """
    # NOTE(review): this method drew a question during PR review (the remark is truncated in this capture).
    handler = DataHandler(
        dataset=train_data,
        test_dataset=val_data,
        random_seed=self.seed,
        force_multilabel=force_multilabel,
    )
    self.data_handler = handler
self.optimization_info = OptimizationInfo() | ||
self.vector_index_client = VectorIndexClient(device, db_dir, embedder_batch_size, embedder_max_length) | ||
|
||
self.db_dir = self.vector_index_client.db_dir | ||
self.embedder_max_length = embedder_max_length | ||
self.embedder_batch_size = embedder_batch_size | ||
self.device = device | ||
self.multilabel = self.data_handler.multilabel | ||
self.n_classes = self.data_handler.n_classes | ||
self.seed = seed | ||
self.dump_dir = Path.cwd() / "modules_dumps" if dump_dir is None else Path(dump_dir) | ||
|
||
def get_best_index(self) -> VectorIndex: | ||
model_name = self.optimization_info.get_best_embedder() | ||
|
@@ -48,10 +85,72 @@ def get_inference_config(self) -> dict[str, Any]: | |
cfg.pop("_target_") | ||
return { | ||
"metadata": { | ||
"device": self.device, | ||
"multilabel": self.multilabel, | ||
"n_classes": self.n_classes, | ||
"device": self.get_device(), | ||
"multilabel": self.is_multilabel(), | ||
"n_classes": self.get_n_classes(), | ||
"seed": self.seed, | ||
}, | ||
"nodes_configs": nodes_configs, | ||
} | ||
|
||
def dump(self) -> None:
    """Write optimization logs, data splits and the inference config to the logging directory.

    The following files are created inside ``self.logging_config.dirpath``:
    ``logs.json``, ``train_data.json``, ``test_data.json`` and ``inference_config.yaml``.

    :raises ValueError: if ``self.logging_config.dirpath`` is None
    """
    self._logger.debug("dumping logs...")
    optimization_results = self.optimization_info.dump_evaluation_results()

    logs_dir = self.logging_config.dirpath
    if logs_dir is None:
        msg = "something's wrong with LoggingConfig"
        raise ValueError(msg)

    # create appropriate directory
    logs_dir.mkdir(parents=True, exist_ok=True)

    # dump search space and evaluation results
    # ensure_ascii=False emits raw non-ASCII characters, so the file is opened as UTF-8
    # explicitly instead of relying on the platform's default encoding.
    with (logs_dir / "logs.json").open("w", encoding="utf-8") as file:
        json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)

    # TODO: dumping the config to config.yaml and logging a report are currently disabled.

    # dump train and test data splits
    train_data, test_data = self.data_handler.dump()
    for filename, split in (("train_data.json", train_data), ("test_data.json", test_data)):
        with (logs_dir / filename).open("w", encoding="utf-8") as file:
            json.dump(split, file, indent=4, ensure_ascii=False)

    self._logger.info("logs and other assets are saved to %s", logs_dir)

    # dump optimization results (config for inference)
    inference_config = self.get_inference_config()
    with (logs_dir / "inference_config.yaml").open("w", encoding="utf-8") as file:
        yaml.dump(inference_config, file)
|
||
def get_db_dir(self) -> Path: | ||
return self.vector_index_client.db_dir | ||
|
||
def get_device(self) -> str: | ||
return self.vector_index_client.device | ||
|
||
def get_batch_size(self) -> int: | ||
return self.vector_index_client.embedder_batch_size | ||
|
||
def get_max_length(self) -> int | None: | ||
return self.vector_index_client.embedder_max_length | ||
|
||
def get_dump_dir(self) -> Path | None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Сделай get... просто как property There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
if self.logging_config.dump_modules: | ||
return self.logging_config.dump_dir | ||
return None | ||
|
||
def is_multilabel(self) -> bool: | ||
return self.data_handler.multilabel | ||
|
||
def get_n_classes(self) -> int: | ||
return self.data_handler.n_classes |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import importlib.resources as ires | ||
import json | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import numpy as np | ||
from omegaconf import ListConfig | ||
|
||
from .data_handler import Dataset | ||
|
||
|
||
class NumpyEncoder(json.JSONEncoder): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Как будто этот класс больше не нужен There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. он используется еще в |
||
"""Helper for dumping logs. Problem explained: https://stackoverflow.com/q/50916422""" | ||
|
||
def default(self, obj: Any) -> str | int | float | list[Any] | Any: # noqa: ANN401 | ||
if isinstance(obj, np.integer): | ||
return int(obj) | ||
if isinstance(obj, np.floating): | ||
return float(obj) | ||
if isinstance(obj, np.ndarray): | ||
return obj.tolist() | ||
if isinstance(obj, ListConfig): | ||
return list(obj) | ||
return super().default(obj) | ||
|
||
|
||
def load_data(data_path: str | Path) -> Dataset:
    """Load a dataset from ``data_path`` or from the samples shipped with the autointent package.

    :param data_path: path to a JSON file, or one of the special names
        ``"default-multiclass"`` (bundled ``banking77.json``) and
        ``"default-multilabel"`` (bundled ``dstc3-20shot.json``)
    :return: the parsed JSON validated via ``Dataset.model_validate``
    """
    if data_path == "default-multiclass":
        source = ires.files("autointent.datafiles").joinpath("banking77.json")
    elif data_path == "default-multilabel":
        source = ires.files("autointent.datafiles").joinpath("dstc3-20shot.json")
    else:
        source = Path(data_path)

    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # platform's default text encoding.
    with source.open(encoding="utf-8") as file:
        res = json.load(file)

    return Dataset.model_validate(res)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Я бы подобные методы называл `configure_logging` или `setup_logging`.