Skip to content

Commit

Permalink
Feat/pipeline simpler fitting (#36)
Browse files Browse the repository at this point in the history
Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: Darinka <[email protected]>
  • Loading branch information
3 people committed Nov 15, 2024
1 parent 953636f commit 4d6ee75
Show file tree
Hide file tree
Showing 61 changed files with 2,737 additions and 526 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/test-inference.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# CI workflow that runs only the pipeline inference test module.
name: test inference

# Runs on pushes to the dev branch and on pull requests targeting dev.
on:
push:
branches:
- dev
pull_request:
branches:
- dev

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
# Let the remaining matrix jobs finish even if one of them fails.
fail-fast: false
# Ubuntu is tested on Python 3.10-3.12; the include entry adds one Windows job on Python 3.10.
matrix:
os: [ ubuntu-latest ]
python-version: [ "3.10", "3.11", "3.12" ]
include:
- os: windows-latest
python-version: "3.10"

steps:
- name: Checkout code
uses: actions/checkout@v4

# Installs the matrix Python version with pip caching enabled.
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"

# Installs the package itself, then the test tooling (pytest, pytest-asyncio).
- name: Install dependencies
run: |
pip install .
pip install pytest pytest-asyncio
- name: Run tests
run: |
pytest tests/pipeline/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Run Tests
name: test nodes

on:
push:
Expand Down Expand Up @@ -37,4 +37,4 @@ jobs:
- name: Run tests
run: |
pytest
pytest tests/nodes
40 changes: 40 additions & 0 deletions .github/workflows/test-optimization.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# CI workflow that runs only the pipeline optimization test module.
name: test optimization

# Runs on pushes to the dev branch and on pull requests targeting dev.
on:
push:
branches:
- dev
pull_request:
branches:
- dev

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
# Let the remaining matrix jobs finish even if one of them fails.
fail-fast: false
# Ubuntu is tested on Python 3.10-3.12; the include entry adds one Windows job on Python 3.10.
matrix:
os: [ ubuntu-latest ]
python-version: [ "3.10", "3.11", "3.12" ]
include:
- os: windows-latest
python-version: "3.10"

steps:
- name: Checkout code
uses: actions/checkout@v4

# Installs the matrix Python version with pip caching enabled.
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"

# Installs the package itself, then the test tooling (pytest, pytest-asyncio).
- name: Install dependencies
run: |
pip install .
pip install pytest pytest-asyncio
- name: Run tests
run: |
pytest tests/pipeline/test_optimization.py
40 changes: 40 additions & 0 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# CI workflow that runs every test except those under tests/nodes and tests/pipeline,
# which are excluded via --ignore in the final step.
name: unit tests

# Runs on pushes to the dev branch and on pull requests targeting dev.
on:
push:
branches:
- dev
pull_request:
branches:
- dev

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
# Let the remaining matrix jobs finish even if one of them fails.
fail-fast: false
# Ubuntu is tested on Python 3.10-3.12; the include entry adds one Windows job on Python 3.10.
matrix:
os: [ ubuntu-latest ]
python-version: [ "3.10", "3.11", "3.12" ]
include:
- os: windows-latest
python-version: "3.10"

steps:
- name: Checkout code
uses: actions/checkout@v4

# Installs the matrix Python version with pip caching enabled.
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"

# Installs the package itself, then the test tooling (pytest, pytest-asyncio).
- name: Install dependencies
run: |
pip install .
pip install pytest pytest-asyncio
- name: Run tests
run: |
pytest --ignore=tests/nodes --ignore=tests/pipeline
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from datetime import datetime

adjectives = [
"adorable",
Expand Down Expand Up @@ -342,3 +343,9 @@ def generate_name() -> str:
adjective = random.choice(adjectives)
noun = random.choice(nouns)
return f"{adjective}_{noun}"


def get_run_name(run_name: str | None = None) -> str:
if run_name is None:
run_name = generate_name()
return f"{run_name}_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S')}" # noqa: DTZ005
2 changes: 1 addition & 1 deletion autointent/configs/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class InferenceNodeConfig:
node_type: str = MISSING
module_type: str = MISSING
module_config: dict[str, Any] = MISSING
load_path: str = MISSING
load_path: str | None = None
_target_: str = "autointent.nodes.InferenceNode"


Expand Down
6 changes: 4 additions & 2 deletions autointent/configs/optimization_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from autointent.pipeline.optimization.utils import generate_name
from .name import generate_name


@dataclass
Expand All @@ -28,6 +28,8 @@ class LoggingConfig:
run_name: str | None = None
dirpath: Path | None = None
dump_dir: Path | None = None
dump_modules: bool = False
clear_ram: bool = True

def __post_init__(self) -> None:
self.define_run_name()
Expand All @@ -44,7 +46,6 @@ def define_dirpath(self) -> None:
if self.run_name is None:
raise ValueError
self.dirpath = dirpath / self.run_name
self.dirpath.mkdir(parents=True)

def define_dump_dir(self) -> None:
if self.dump_dir is None:
Expand All @@ -57,6 +58,7 @@ def define_dump_dir(self) -> None:
# Settings for the vector index.
class VectorIndexConfig:
# Directory for the vector database; None means no directory was given explicitly.
db_dir: Path | None = None
# Device name string; defaults to "cpu".
device: str = "cpu"
# NOTE(review): presumably controls whether the vector DB is kept after a run - confirm against callers.
save_db: bool = False


@dataclass
Expand Down
160 changes: 133 additions & 27 deletions autointent/context/context.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,79 @@
import json
import logging
from dataclasses import asdict
from pathlib import Path
from typing import Any

import yaml

from autointent.configs.optimization_cli import (
AugmentationConfig,
DataConfig,
EmbedderConfig,
LoggingConfig,
VectorIndexConfig,
)

from .data_handler import DataAugmenter, DataHandler, Dataset
from .optimization_info import OptimizationInfo
from .utils import NumpyEncoder, load_data
from .vector_index_client import VectorIndex, VectorIndexClient


class Context:
def __init__( # noqa: PLR0913
data_handler: DataHandler
vector_index_client: VectorIndexClient
optimization_info: OptimizationInfo

def __init__(
self,
dataset: Dataset,
test_dataset: Dataset | None = None,
device: str = "cpu",
multilabel_generation_config: str | None = None,
regex_sampling: int = 0,
seed: int = 42,
db_dir: str | Path | None = None,
dump_dir: str | Path | None = None,
force_multilabel: bool = False,
embedder_batch_size: int = 32,
embedder_max_length: int | None = None,
) -> None:
augmenter = DataAugmenter(multilabel_generation_config, regex_sampling, seed)
self.seed = seed
self._logger = logging.getLogger(__name__)

# Store the logging configuration on the context and start a fresh OptimizationInfo,
# so any previously held optimization info object is replaced.
def configure_logging(self, config: LoggingConfig) -> None:
self.logging_config = config
self.optimization_info = OptimizationInfo()

def configure_vector_index(self, config: VectorIndexConfig, embedder_config: EmbedderConfig | None = None) -> None:
    """Store the vector index / embedder settings and create the index client.

    :param config: vector index settings; its ``device`` and ``db_dir`` are
        passed to ``VectorIndexClient``.
    :param embedder_config: embedder settings supplying ``batch_size`` and
        ``max_length``; a default ``EmbedderConfig()`` is used when omitted.
    """
    self.vector_index_config = config
    self.embedder_config = EmbedderConfig() if embedder_config is None else embedder_config

    index_cfg = self.vector_index_config
    emb_cfg = self.embedder_config
    self.vector_index_client = VectorIndexClient(
        index_cfg.device,
        index_cfg.db_dir,
        emb_cfg.batch_size,
        emb_cfg.max_length,
    )

def configure_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
    """Load the datasets described by ``config`` and build the data handler.

    :param config: data settings; ``train_path`` is always loaded, ``test_path``
        only when it is not ``None``, and ``force_multilabel`` is forwarded to
        ``DataHandler``.
    :param augmentation_config: optional augmentation settings. When given, they
        are stored on the context and used to build a ``DataAugmenter`` seeded
        with ``self.seed``; when omitted, no augmenter is used.
    """
    if augmentation_config is not None:
        # Honor the caller's settings. Storing a fresh AugmentationConfig() here
        # would silently discard whatever was passed in.
        self.augmentation_config = augmentation_config
        augmenter = DataAugmenter(
            self.augmentation_config.multilabel_generation_config,
            self.augmentation_config.regex_sampling,
            self.seed,
        )
    else:
        augmenter = None

    self.data_handler = DataHandler(
        dataset=load_data(config.train_path),
        test_dataset=None if config.test_path is None else load_data(config.test_path),
        random_seed=self.seed,
        force_multilabel=config.force_multilabel,
        augmenter=augmenter,
    )

# Build the data handler directly from in-memory datasets.
# val_data is passed to DataHandler as its test_dataset; the handler is seeded with self.seed.
# No augmenter is passed on this path.
def set_datasets(
self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
) -> None:
self.data_handler = DataHandler(
dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
)
self.optimization_info = OptimizationInfo()
self.vector_index_client = VectorIndexClient(device, db_dir, embedder_batch_size, embedder_max_length)

self.db_dir = self.vector_index_client.db_dir
self.embedder_max_length = embedder_max_length
self.embedder_batch_size = embedder_batch_size
self.device = device
self.multilabel = self.data_handler.multilabel
self.n_classes = self.data_handler.n_classes
self.seed = seed
self.dump_dir = Path.cwd() / "modules_dumps" if dump_dir is None else Path(dump_dir)

def get_best_index(self) -> VectorIndex:
model_name = self.optimization_info.get_best_embedder()
Expand All @@ -48,10 +85,79 @@ def get_inference_config(self) -> dict[str, Any]:
cfg.pop("_target_")
return {
"metadata": {
"device": self.device,
"multilabel": self.multilabel,
"n_classes": self.n_classes,
"device": self.get_device(),
"multilabel": self.is_multilabel(),
"n_classes": self.get_n_classes(),
"seed": self.seed,
},
"nodes_configs": nodes_configs,
}

def dump(self) -> None:
    """Write the optimization artifacts into ``self.logging_config.dirpath``.

    The directory is created if needed and receives four files:
    ``logs.json`` (evaluation results), ``train_data.json`` and
    ``test_data.json`` (data splits from the data handler), and
    ``inference_config.yaml`` (the config returned by ``get_inference_config``).

    Every file is opened explicitly as UTF-8. The JSON is written with
    ``ensure_ascii=False``, so relying on the platform default encoding
    (e.g. cp1252 on Windows) could fail on, or mis-encode, non-ASCII text.

    :raises ValueError: if the logging config has no ``dirpath``.
    """
    self._logger.debug("dumping logs...")
    optimization_results = self.optimization_info.dump_evaluation_results()

    logs_dir = self.logging_config.dirpath
    if logs_dir is None:
        msg = "something's wrong with LoggingConfig"
        raise ValueError(msg)

    # create appropriate directory
    logs_dir.mkdir(parents=True, exist_ok=True)

    # dump search space and evaluation results
    logs_path = logs_dir / "logs.json"
    with logs_path.open("w", encoding="utf-8") as file:
        json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
    # config_path = logs_dir / "config.yaml"
    # with config_path.open("w") as file:
    #     yaml.dump(self.config, file)

    # self._logger.info(make_report(optimization_results, nodes=nodes))

    # dump train and test data splits
    train_data, test_data = self.data_handler.dump()
    train_path = logs_dir / "train_data.json"
    test_path = logs_dir / "test_data.json"
    with train_path.open("w", encoding="utf-8") as file:
        json.dump(train_data, file, indent=4, ensure_ascii=False)
    with test_path.open("w", encoding="utf-8") as file:
        json.dump(test_data, file, indent=4, ensure_ascii=False)

    self._logger.info("logs and other assets are saved to %s", logs_dir)

    # dump optimization results (config for inference)
    inference_config = self.get_inference_config()
    inference_config_path = logs_dir / "inference_config.yaml"
    with inference_config_path.open("w", encoding="utf-8") as file:
        yaml.dump(inference_config, file)

# Return the database directory held by the vector index client
# (the client is created in configure_vector_index).
def get_db_dir(self) -> Path:
return self.vector_index_client.db_dir

# Return the device string held by the vector index client.
def get_device(self) -> str:
return self.vector_index_client.device

# Return the embedder batch size held by the vector index client.
def get_batch_size(self) -> int:
return self.vector_index_client.embedder_batch_size

# Return the embedder max length held by the vector index client; may be None.
def get_max_length(self) -> int | None:
return self.vector_index_client.embedder_max_length

def get_dump_dir(self) -> Path | None:
if self.logging_config.dump_modules:
return self.logging_config.dump_dir
return None

# Return the multilabel flag of the data handler
# (the handler is created in configure_data or set_datasets).
def is_multilabel(self) -> bool:
return self.data_handler.multilabel

# Return the number of classes reported by the data handler.
def get_n_classes(self) -> int:
return self.data_handler.n_classes

# Return the clear_ram flag from the logging configuration.
def is_ram_to_clear(self) -> bool:
return self.logging_config.clear_ram

def has_saved_modules(self) -> bool:
    """Tell whether any node type has at least one stored module.

    Looks up the "regexp", "retrieval", "scoring" and "prediction" entries of
    ``optimization_info.modules`` in that order and stops at the first
    non-empty one.
    """
    for node_type in ("regexp", "retrieval", "scoring", "prediction"):
        stored = self.optimization_info.modules.get(node_type)
        if len(stored) > 0:
            return True
    return False
Loading

0 comments on commit 4d6ee75

Please sign in to comment.