Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/pipeline simpler fitting #36

Merged
merged 30 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f3d7777
stage result
voorhs Nov 5, 2024
b192dc8
decompose `Context.__init__()` and implement `get_` methods
voorhs Nov 5, 2024
0f6f568
fix tests
voorhs Nov 5, 2024
a118e27
fix typing
voorhs Nov 5, 2024
0e3fe2a
add `Context.set_datasets` and allow not dumping modules
voorhs Nov 5, 2024
c4e15d9
implement `PipelineOptimizer.fit_from_dataset`
voorhs Nov 5, 2024
109d47e
enable configuration for python api
voorhs Nov 5, 2024
2b0d371
fix typing
voorhs Nov 5, 2024
d7c4066
fix tests
voorhs Nov 5, 2024
d305bb5
add `clear_ram` option
voorhs Nov 6, 2024
d648849
infering modules from ram after optimization
voorhs Nov 6, 2024
e7d0fbd
minor change
voorhs Nov 6, 2024
975c8df
fix unintended `runs` directory creation
voorhs Nov 6, 2024
378e582
add `save_db` option
voorhs Nov 6, 2024
2b80888
Merge branch 'refs/heads/dev' into feat/pipeline-simpler-fitting
Samoed Nov 6, 2024
8c2eaff
fix circular imports
Samoed Nov 6, 2024
322340b
fix tests
voorhs Nov 8, 2024
c487363
Test/pipeline simpler fitting (#39)
Darinochka Nov 9, 2024
a2e4dea
refactor github actions
voorhs Nov 9, 2024
c349f18
rename actions
voorhs Nov 9, 2024
b26a878
fix `model_name` issue
voorhs Nov 9, 2024
7ccbca2
response to review
voorhs Nov 11, 2024
db7f207
attempt to fix `winerror access denied` problem
voorhs Nov 12, 2024
11bb883
try to fix `unexpected argument` error
voorhs Nov 12, 2024
616ba32
minor bug fix
voorhs Nov 12, 2024
11bdb6f
another attempt to fix permission error
voorhs Nov 12, 2024
9545bbf
stupid bug fix
voorhs Nov 12, 2024
1096e04
refactor cache cleaning
voorhs Nov 12, 2024
d6884d4
another attempt (workaround: ingore permission error)
voorhs Nov 12, 2024
9e55a65
change return type of classmethods-constructors to `Self`
voorhs Nov 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/test-inference.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# CI workflow: runs the pipeline inference tests.
name: test inference

# Triggered by pushes to `dev` and by pull requests targeting `dev`.
on:
  push:
    branches:
      - dev
  pull_request:
    branches:
      - dev

jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Do not cancel the remaining matrix jobs when one of them fails.
      fail-fast: false
      matrix:
        os: [ ubuntu-latest ]
        python-version: [ "3.10", "3.11", "3.12" ]
        # One additional job: Windows, Python 3.10 only.
        include:
          - os: windows-latest
            python-version: "3.10"

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: "pip"

      # Install the package itself plus the test runner and its asyncio plugin.
      - name: Install dependencies
        run: |
          pip install .
          pip install pytest pytest-asyncio
      # Only the inference test module is run by this workflow.
      - name: Run tests
        run: |
          pytest tests/pipeline/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Run Tests
name: test nodes

on:
push:
Expand Down Expand Up @@ -37,4 +37,4 @@ jobs:
- name: Run tests
run: |
pytest
pytest tests/nodes
40 changes: 40 additions & 0 deletions .github/workflows/test-optimization.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# CI workflow: runs the pipeline optimization tests.
name: test optimization

# Triggered by pushes to `dev` and by pull requests targeting `dev`.
on:
  push:
    branches:
      - dev
  pull_request:
    branches:
      - dev

jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Do not cancel the remaining matrix jobs when one of them fails.
      fail-fast: false
      matrix:
        os: [ ubuntu-latest ]
        python-version: [ "3.10", "3.11", "3.12" ]
        # One additional job: Windows, Python 3.10 only.
        include:
          - os: windows-latest
            python-version: "3.10"

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: "pip"

      # Install the package itself plus the test runner and its asyncio plugin.
      - name: Install dependencies
        run: |
          pip install .
          pip install pytest pytest-asyncio
      # Only the optimization test module is run by this workflow.
      - name: Run tests
        run: |
          pytest tests/pipeline/test_optimization.py
40 changes: 40 additions & 0 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# CI workflow: runs the test suite with the node and pipeline test directories excluded.
name: unit tests

# Triggered by pushes to `dev` and by pull requests targeting `dev`.
on:
  push:
    branches:
      - dev
  pull_request:
    branches:
      - dev

jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Do not cancel the remaining matrix jobs when one of them fails.
      fail-fast: false
      matrix:
        os: [ ubuntu-latest ]
        python-version: [ "3.10", "3.11", "3.12" ]
        # One additional job: Windows, Python 3.10 only.
        include:
          - os: windows-latest
            python-version: "3.10"

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: "pip"

      # Install the package itself plus the test runner and its asyncio plugin.
      - name: Install dependencies
        run: |
          pip install .
          pip install pytest pytest-asyncio
      # Everything pytest discovers except tests/nodes and tests/pipeline.
      - name: Run tests
        run: |
          pytest --ignore=tests/nodes --ignore=tests/pipeline
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from datetime import datetime

adjectives = [
"adorable",
Expand Down Expand Up @@ -342,3 +343,9 @@ def generate_name() -> str:
adjective = random.choice(adjectives)
noun = random.choice(nouns)
return f"{adjective}_{noun}"


def get_run_name(run_name: str | None = None) -> str:
if run_name is None:
run_name = generate_name()
return f"{run_name}_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S')}" # noqa: DTZ005
2 changes: 1 addition & 1 deletion autointent/configs/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class InferenceNodeConfig:
node_type: str = MISSING
module_type: str = MISSING
module_config: dict[str, Any] = MISSING
load_path: str = MISSING
load_path: str | None = None
_target_: str = "autointent.nodes.InferenceNode"


Expand Down
6 changes: 4 additions & 2 deletions autointent/configs/optimization_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

from autointent.pipeline.optimization.utils import generate_name
from .name import generate_name


@dataclass
Expand All @@ -28,6 +28,8 @@ class LoggingConfig:
run_name: str | None = None
dirpath: Path | None = None
dump_dir: Path | None = None
dump_modules: bool = False
clear_ram: bool = True

def __post_init__(self) -> None:
self.define_run_name()
Expand All @@ -44,7 +46,6 @@ def define_dirpath(self) -> None:
if self.run_name is None:
raise ValueError
self.dirpath = dirpath / self.run_name
self.dirpath.mkdir(parents=True)

def define_dump_dir(self) -> None:
if self.dump_dir is None:
Expand All @@ -57,6 +58,7 @@ def define_dump_dir(self) -> None:
class VectorIndexConfig:
    """Settings for the vector index."""

    # Directory for the vector index database; None means no explicit directory was given.
    # NOTE(review): presumably the consumer then picks a default location — confirm.
    db_dir: Path | None = None
    # Device identifier for the vector index; defaults to CPU.
    device: str = "cpu"
    # NOTE(review): presumably controls whether the index database is kept after a run — confirm against usage.
    save_db: bool = False


@dataclass
Expand Down
160 changes: 133 additions & 27 deletions autointent/context/context.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,79 @@
import json
import logging
from dataclasses import asdict
from pathlib import Path
from typing import Any

import yaml

from autointent.configs.optimization_cli import (
AugmentationConfig,
DataConfig,
EmbedderConfig,
LoggingConfig,
VectorIndexConfig,
)

from .data_handler import DataAugmenter, DataHandler, Dataset
from .optimization_info import OptimizationInfo
from .utils import NumpyEncoder, load_data
from .vector_index_client import VectorIndex, VectorIndexClient


class Context:
def __init__( # noqa: PLR0913
data_handler: DataHandler
vector_index_client: VectorIndexClient
optimization_info: OptimizationInfo

def __init__(
self,
dataset: Dataset,
test_dataset: Dataset | None = None,
device: str = "cpu",
multilabel_generation_config: str | None = None,
regex_sampling: int = 0,
seed: int = 42,
db_dir: str | Path | None = None,
dump_dir: str | Path | None = None,
force_multilabel: bool = False,
embedder_batch_size: int = 32,
embedder_max_length: int | None = None,
) -> None:
augmenter = DataAugmenter(multilabel_generation_config, regex_sampling, seed)
self.seed = seed
self._logger = logging.getLogger(__name__)

def configure_logging(self, config: LoggingConfig) -> None:
    """Store the logging settings and create a new optimization info object.

    Args:
        config: logging options, kept on the context as ``logging_config``.
    """
    self.logging_config = config
    # Each call replaces ``optimization_info`` with a new ``OptimizationInfo()`` instance.
    self.optimization_info = OptimizationInfo()

def configure_vector_index(self, config: VectorIndexConfig, embedder_config: EmbedderConfig | None = None) -> None:
    """Store the vector index and embedder settings and build the vector index client.

    Args:
        config: vector index options; its ``device`` and ``db_dir`` are passed to the client.
        embedder_config: embedder options; its ``batch_size`` and ``max_length`` are passed
            to the client. A default ``EmbedderConfig()`` is used when omitted.
    """
    self.vector_index_config = config
    embedder_settings = EmbedderConfig() if embedder_config is None else embedder_config
    self.embedder_config = embedder_settings

    client = VectorIndexClient(
        config.device,
        config.db_dir,
        embedder_settings.batch_size,
        embedder_settings.max_length,
    )
    self.vector_index_client = client

def configure_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
    """Load the datasets described by ``config`` and build the data handler.

    Args:
        config: data options. ``train_path`` and, when set, ``test_path`` are read with
            ``load_data``; ``force_multilabel`` is forwarded to ``DataHandler``.
        augmentation_config: augmentation options. When given, they are stored on the
            context and a ``DataAugmenter`` is built from them; when omitted, the data
            handler gets no augmenter.
    """
    if augmentation_config is not None:
        # Use the caller's settings as-is; substituting a default AugmentationConfig()
        # here would silently discard whatever was passed in.
        self.augmentation_config = augmentation_config
        augmenter = DataAugmenter(
            augmentation_config.multilabel_generation_config,
            augmentation_config.regex_sampling,
            self.seed,
        )
    else:
        augmenter = None

    self.data_handler = DataHandler(
        dataset=load_data(config.train_path),
        test_dataset=None if config.test_path is None else load_data(config.test_path),
        random_seed=self.seed,
        force_multilabel=config.force_multilabel,
        augmenter=augmenter,
    )

def set_datasets(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Метод set_datasets, а на вход ..._data

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

не оч понял

self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
) -> None:
self.data_handler = DataHandler(
dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
)
self.optimization_info = OptimizationInfo()
self.vector_index_client = VectorIndexClient(device, db_dir, embedder_batch_size, embedder_max_length)

self.db_dir = self.vector_index_client.db_dir
self.embedder_max_length = embedder_max_length
self.embedder_batch_size = embedder_batch_size
self.device = device
self.multilabel = self.data_handler.multilabel
self.n_classes = self.data_handler.n_classes
self.seed = seed
self.dump_dir = Path.cwd() / "modules_dumps" if dump_dir is None else Path(dump_dir)

def get_best_index(self) -> VectorIndex:
model_name = self.optimization_info.get_best_embedder()
Expand All @@ -48,10 +85,79 @@ def get_inference_config(self) -> dict[str, Any]:
cfg.pop("_target_")
return {
"metadata": {
"device": self.device,
"multilabel": self.multilabel,
"n_classes": self.n_classes,
"device": self.get_device(),
"multilabel": self.is_multilabel(),
"n_classes": self.get_n_classes(),
"seed": self.seed,
},
"nodes_configs": nodes_configs,
}

def dump(self) -> None:
    """Write the artifacts of the optimization run to the logging directory.

    The directory ``logging_config.dirpath`` is created if it does not exist, then the
    following files are written into it: ``logs.json`` (evaluation results),
    ``train_data.json`` and ``test_data.json`` (data splits) and
    ``inference_config.yaml`` (configuration for inference).

    Raises:
        ValueError: if ``logging_config.dirpath`` is ``None``.
    """
    self._logger.debug("dumping logs...")
    optimization_results = self.optimization_info.dump_evaluation_results()

    logs_dir = self.logging_config.dirpath
    if logs_dir is None:
        msg = "something's wrong with LoggingConfig"
        raise ValueError(msg)

    # create appropriate directory
    logs_dir.mkdir(parents=True, exist_ok=True)

    # All files are opened as UTF-8 explicitly: the JSON is written with
    # ensure_ascii=False, so a non-UTF-8 platform default encoding could fail
    # on non-ASCII text.

    # dump search space and evaluation results
    logs_path = logs_dir / "logs.json"
    with logs_path.open("w", encoding="utf-8") as file:
        json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
    # NOTE: dumping the full run config (config.yaml) and logging a textual report
    # are currently disabled.

    # dump train and test data splits
    train_data, test_data = self.data_handler.dump()
    train_path = logs_dir / "train_data.json"
    test_path = logs_dir / "test_data.json"
    with train_path.open("w", encoding="utf-8") as file:
        json.dump(train_data, file, indent=4, ensure_ascii=False)
    with test_path.open("w", encoding="utf-8") as file:
        json.dump(test_data, file, indent=4, ensure_ascii=False)

    self._logger.info("logs and other assets are saved to %s", logs_dir)

    # dump optimization results (config for inference)
    inference_config = self.get_inference_config()
    inference_config_path = logs_dir / "inference_config.yaml"
    with inference_config_path.open("w", encoding="utf-8") as file:
        yaml.dump(inference_config, file)

def get_db_dir(self) -> Path:
    """Return the ``db_dir`` of the vector index client."""
    client = self.vector_index_client
    return client.db_dir

def get_device(self) -> str:
    """Return the ``device`` of the vector index client."""
    client = self.vector_index_client
    return client.device

def get_batch_size(self) -> int:
    """Return the ``embedder_batch_size`` of the vector index client."""
    client = self.vector_index_client
    return client.embedder_batch_size

def get_max_length(self) -> int | None:
return self.vector_index_client.embedder_max_length

def get_dump_dir(self) -> Path | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Сделай get... просто как property

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

if self.logging_config.dump_modules:
return self.logging_config.dump_dir
return None

def is_multilabel(self) -> bool:
    """Return the ``multilabel`` flag of the data handler."""
    handler = self.data_handler
    return handler.multilabel

def get_n_classes(self) -> int:
    """Return ``n_classes`` as reported by the data handler."""
    handler = self.data_handler
    return handler.n_classes

def is_ram_to_clear(self) -> bool:
    """Return the ``clear_ram`` option of the logging config."""
    settings = self.logging_config
    return settings.clear_ram

def has_saved_modules(self) -> bool:
    """Tell whether ``optimization_info.modules`` holds an entry for any node type.

    The node types checked are "regexp", "retrieval", "scoring" and "prediction",
    in that order; the search stops at the first non-empty one.
    """
    modules = self.optimization_info.modules
    for node_type in ("regexp", "retrieval", "scoring", "prediction"):
        if len(modules.get(node_type)) > 0:
            return True
    return False
Loading