Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/pipeline simpler fitting #36

Merged
merged 30 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f3d7777
stage result
voorhs Nov 5, 2024
b192dc8
decompose `Context.__init__()` and implement `get_` methods
voorhs Nov 5, 2024
0f6f568
fix tests
voorhs Nov 5, 2024
a118e27
fix typing
voorhs Nov 5, 2024
0e3fe2a
add `Context.set_datasets` and allow not dumping modules
voorhs Nov 5, 2024
c4e15d9
implement `PipelineOptimizer.fit_from_dataset`
voorhs Nov 5, 2024
109d47e
enable configuration for python api
voorhs Nov 5, 2024
2b0d371
fix typing
voorhs Nov 5, 2024
d7c4066
fix tests
voorhs Nov 5, 2024
d305bb5
add `clear_ram` option
voorhs Nov 6, 2024
d648849
infering modules from ram after optimization
voorhs Nov 6, 2024
e7d0fbd
minor change
voorhs Nov 6, 2024
975c8df
fix unintended `runs` directory creation
voorhs Nov 6, 2024
378e582
add `save_db` option
voorhs Nov 6, 2024
2b80888
Merge branch 'refs/heads/dev' into feat/pipeline-simpler-fitting
Samoed Nov 6, 2024
8c2eaff
fix circular imports
Samoed Nov 6, 2024
322340b
fix tests
voorhs Nov 8, 2024
c487363
Test/pipeline simpler fitting (#39)
Darinochka Nov 9, 2024
a2e4dea
refactor github actions
voorhs Nov 9, 2024
c349f18
rename actions
voorhs Nov 9, 2024
b26a878
fix `model_name` issue
voorhs Nov 9, 2024
7ccbca2
response to review
voorhs Nov 11, 2024
db7f207
attempt to fix `winerror access denied` problem
voorhs Nov 12, 2024
11bb883
try to fix `unexpected argument` error
voorhs Nov 12, 2024
616ba32
minor bug fix
voorhs Nov 12, 2024
11bdb6f
another attempt to fix permission error
voorhs Nov 12, 2024
9545bbf
stupid bug fix
voorhs Nov 12, 2024
1096e04
refactor cache cleaning
voorhs Nov 12, 2024
d6884d4
another attempt (workaround: ingore permission error)
voorhs Nov 12, 2024
9e55a65
change return type of classmethods-constructors to `Self`
voorhs Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from datetime import datetime

adjectives = [
"adorable",
Expand Down Expand Up @@ -342,3 +343,9 @@ def generate_name() -> str:
adjective = random.choice(adjectives)
noun = random.choice(nouns)
return f"{adjective}_{noun}"


def get_run_name(run_name: str | None = None) -> str:
if run_name is None:
run_name = generate_name()
return f"{run_name}_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S')}" # noqa: DTZ005
2 changes: 1 addition & 1 deletion autointent/configs/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class InferenceNodeConfig:
node_type: str = MISSING
module_type: str = MISSING
module_config: dict[str, Any] = MISSING
load_path: str = MISSING
load_path: str | None = None
_target_: str = "autointent.nodes.InferenceNode"


Expand Down
3 changes: 2 additions & 1 deletion autointent/configs/optimization_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from autointent.pipeline.optimization.utils import generate_name
from .name import generate_name


@dataclass
Expand All @@ -28,6 +28,7 @@ class LoggingConfig:
run_name: str | None = None
dirpath: Path | None = None
dump_dir: Path | None = None
dump_modules: bool = True

def __post_init__(self) -> None:
self.define_run_name()
Expand Down
153 changes: 126 additions & 27 deletions autointent/context/context.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,79 @@
import json
import logging
from dataclasses import asdict
from pathlib import Path
from typing import Any

import yaml

from autointent.configs.optimization_cli import (
AugmentationConfig,
DataConfig,
EmbedderConfig,
LoggingConfig,
VectorIndexConfig,
)

from .data_handler import DataAugmenter, DataHandler, Dataset
from .optimization_info import OptimizationInfo
from .utils import NumpyEncoder, load_data
from .vector_index_client import VectorIndex, VectorIndexClient


class Context:
def __init__( # noqa: PLR0913
data_handler: DataHandler
vector_index_client: VectorIndexClient
optimization_info: OptimizationInfo

def __init__(
self,
dataset: Dataset,
test_dataset: Dataset | None = None,
device: str = "cpu",
multilabel_generation_config: str | None = None,
regex_sampling: int = 0,
seed: int = 42,
db_dir: str | Path | None = None,
dump_dir: str | Path | None = None,
force_multilabel: bool = False,
embedder_batch_size: int = 32,
embedder_max_length: int | None = None,
) -> None:
augmenter = DataAugmenter(multilabel_generation_config, regex_sampling, seed)
self.seed = seed
self._logger = logging.getLogger(__name__)

def config_logs(self, config: LoggingConfig) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Я бы подобные методы называл configure_logging или setup_logging

self.logging_config = config
self.optimization_info = OptimizationInfo()

def config_vector_index(self, config: VectorIndexConfig, embedder_config: EmbedderConfig | None = None) -> None:
self.vector_index_config = config
if embedder_config is None:
embedder_config = EmbedderConfig()
self.embedder_config = embedder_config

self.vector_index_client = VectorIndexClient(
self.vector_index_config.device,
self.vector_index_config.db_dir,
self.embedder_config.batch_size,
self.embedder_config.max_length,
)

def config_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
if augmentation_config is not None:
self.augmentation_config = AugmentationConfig()
augmenter = DataAugmenter(
self.augmentation_config.multilabel_generation_config,
self.augmentation_config.regex_sampling,
self.seed,
)
else:
augmenter = None

self.data_handler = DataHandler(
dataset, test_dataset, random_seed=seed, force_multilabel=force_multilabel, augmenter=augmenter
dataset=load_data(config.train_path),
test_dataset=None if config.test_path is None else load_data(config.test_path),
random_seed=self.seed,
force_multilabel=config.force_multilabel,
augmenter=augmenter,
)

def set_datasets(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Метод set_datasets, а на вход ..._data

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

не оч понял

self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
) -> None:
self.data_handler = DataHandler(
dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
)
self.optimization_info = OptimizationInfo()
self.vector_index_client = VectorIndexClient(device, db_dir, embedder_batch_size, embedder_max_length)

self.db_dir = self.vector_index_client.db_dir
self.embedder_max_length = embedder_max_length
self.embedder_batch_size = embedder_batch_size
self.device = device
self.multilabel = self.data_handler.multilabel
self.n_classes = self.data_handler.n_classes
self.seed = seed
self.dump_dir = Path.cwd() / "modules_dumps" if dump_dir is None else Path(dump_dir)

def get_best_index(self) -> VectorIndex:
model_name = self.optimization_info.get_best_embedder()
Expand All @@ -48,10 +85,72 @@ def get_inference_config(self) -> dict[str, Any]:
cfg.pop("_target_")
return {
"metadata": {
"device": self.device,
"multilabel": self.multilabel,
"n_classes": self.n_classes,
"device": self.get_device(),
"multilabel": self.is_multilabel(),
"n_classes": self.get_n_classes(),
"seed": self.seed,
},
"nodes_configs": nodes_configs,
}

def dump(self) -> None:
self._logger.debug("dumping logs...")
optimization_results = self.optimization_info.dump_evaluation_results()

logs_dir = self.logging_config.dirpath
if logs_dir is None:
msg = "something's wrong with LoggingConfig"
raise ValueError(msg)

# create appropriate directory
logs_dir.mkdir(parents=True, exist_ok=True)

# dump search space and evaluation results
logs_path = logs_dir / "logs.json"
with logs_path.open("w") as file:
json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
# config_path = logs_dir / "config.yaml"
# with config_path.open("w") as file:
# yaml.dump(self.config, file)

# self._logger.info(make_report(optimization_results, nodes=nodes))

# dump train and test data splits
train_data, test_data = self.data_handler.dump()
train_path = logs_dir / "train_data.json"
test_path = logs_dir / "test_data.json"
with train_path.open("w") as file:
json.dump(train_data, file, indent=4, ensure_ascii=False)
with test_path.open("w") as file:
json.dump(test_data, file, indent=4, ensure_ascii=False)

self._logger.info("logs and other assets are saved to %s", logs_dir)

# dump optimization results (config for inference)
inference_config = self.get_inference_config()
inference_config_path = logs_dir / "inference_config.yaml"
with inference_config_path.open("w") as file:
yaml.dump(inference_config, file)

def get_db_dir(self) -> Path:
return self.vector_index_client.db_dir

def get_device(self) -> str:
return self.vector_index_client.device

def get_batch_size(self) -> int:
return self.vector_index_client.embedder_batch_size

def get_max_length(self) -> int | None:
return self.vector_index_client.embedder_max_length

def get_dump_dir(self) -> Path | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Сделай get... просто как property

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

if self.logging_config.dump_modules:
return self.logging_config.dump_dir
return None

def is_multilabel(self) -> bool:
return self.data_handler.multilabel

def get_n_classes(self) -> int:
return self.data_handler.n_classes
2 changes: 1 addition & 1 deletion autointent/context/optimization_info/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Trial(BaseModel):
module_params: dict[str, Any]
metric_name: str
metric_value: float
module_dump_dir: str
module_dump_dir: str | None


class Trials(BaseModel):
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions autointent/context/optimization_info/optimization_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from numpy.typing import NDArray

from autointent.configs.node import InferenceNodeConfig
from autointent.logger import get_logger

from .data_models import Artifact, Artifacts, RetrieverArtifact, ScorerArtifact, Trial, Trials, TrialsIds
from .logger import get_logger


class OptimizationInfo:
Expand All @@ -29,7 +29,7 @@ def log_module_optimization(
metric_value: float,
metric_name: str,
artifact: Artifact,
module_dump_dir: str,
module_dump_dir: str | None,
) -> None:
"""
Purposes:
Expand Down
39 changes: 39 additions & 0 deletions autointent/context/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import importlib.resources as ires
import json
from pathlib import Path
from typing import Any

import numpy as np
from omegaconf import ListConfig

from .data_handler import Dataset


class NumpyEncoder(json.JSONEncoder):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Как будто этот класс больше не нужен

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

он используется еще в Context.dump() и infererence.cli_enpoint.main()

"""Helper for dumping logs. Problem explained: https://stackoverflow.com/q/50916422"""

def default(self, obj: Any) -> str | int | float | list[Any] | Any: # noqa: ANN401
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, ListConfig):
return list(obj)
return super().default(obj)


def load_data(data_path: str | Path) -> Dataset:
"""load data from the given path or load sample data which is distributed along with the autointent package"""
if data_path == "default-multiclass":
with ires.files("autointent.datafiles").joinpath("banking77.json").open() as file:
res = json.load(file)
elif data_path == "default-multilabel":
with ires.files("autointent.datafiles").joinpath("dstc3-20shot.json").open() as file:
res = json.load(file)
else:
with Path(data_path).open() as file:
res = json.load(file)

return Dataset.model_validate(res)
4 changes: 3 additions & 1 deletion autointent/modules/prediction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def get_prediction_evaluation_data(
oos_scores = context.optimization_info.get_best_oos_scores()
return_scores = scores
if oos_scores is not None:
oos_labels = [[0] * context.n_classes] * len(oos_scores) if context.multilabel else [-1] * len(oos_scores) # type: ignore[list-item]
oos_labels = (
[[0] * context.get_n_classes()] * len(oos_scores) if context.is_multilabel() else [-1] * len(oos_scores) # type: ignore[list-item]
)
labels = np.concatenate([labels, np.array(oos_labels)])
return_scores = np.concatenate([scores, oos_scores])

Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/prediction/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(
def from_context(cls, context: Context, thresh: float | npt.NDArray[Any] = 0.5) -> Self:
return cls(
thresh=thresh,
multilabel=context.multilabel,
n_classes=context.n_classes,
multilabel=context.is_multilabel(),
n_classes=context.get_n_classes(),
)

def fit(
Expand Down
8 changes: 4 additions & 4 deletions autointent/modules/retrieval/vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def from_context(
return cls(
k=k,
model_name=model_name,
db_dir=str(context.db_dir),
device=context.device,
batch_size=context.embedder_batch_size,
max_length=context.embedder_max_length,
db_dir=str(context.get_db_dir()),
device=context.get_device(),
batch_size=context.get_batch_size(),
max_length=context.get_max_length(),
)

def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/scoring/description/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def from_context(

instance = cls(
temperature=temperature,
device=context.device,
db_dir=context.db_dir,
device=context.get_device(),
db_dir=context.get_db_dir(),
model_name=model_name,
)
instance.precomputed_embeddings = precomputed_embeddings
Expand Down
8 changes: 4 additions & 4 deletions autointent/modules/scoring/dnnc/dnnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def from_context(
search_model_name=search_model_name,
k=k,
train_head=train_head,
device=context.device,
db_dir=str(context.db_dir),
batch_size=context.embedder_batch_size,
max_length=context.embedder_max_length,
device=context.get_device(),
db_dir=str(context.get_db_dir()),
batch_size=context.get_batch_size(),
max_length=context.get_max_length(),
)
instance.prebuilt_index = prebuilt_index
return instance
Expand Down
8 changes: 4 additions & 4 deletions autointent/modules/scoring/knn/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def from_context(
model_name=model_name,
k=k,
weights=weights,
db_dir=str(context.db_dir),
device=context.device,
batch_size=context.embedder_batch_size,
max_length=context.embedder_max_length,
db_dir=str(context.get_db_dir()),
device=context.get_device(),
batch_size=context.get_batch_size(),
max_length=context.get_max_length(),
)
instance.prebuilt_index = prebuilt_index
return instance
Expand Down
8 changes: 4 additions & 4 deletions autointent/modules/scoring/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def from_context(

instance = cls(
model_name=model_name,
device=context.device,
device=context.get_device(),
seed=context.seed,
batch_size=context.embedder_batch_size,
max_length=context.embedder_max_length,
batch_size=context.get_batch_size(),
max_length=context.get_max_length(),
)
instance.precomputed_embeddings = precomputed_embeddings
instance.db_dir = str(context.db_dir)
instance.db_dir = str(context.get_db_dir())
return instance

def fit(
Expand Down
Loading
Loading