Skip to content

Commit

Permalink
Test/pipeline simpler fitting (#39)
Browse files Browse the repository at this point in the history
* test: added inference_test

* test: added inference pipeline cli

* test: fixed device

* test: added optimization tests

* fix `inference_config.yaml` not found error

---------

Co-authored-by: voorhs <[email protected]>
  • Loading branch information
Darinochka and voorhs authored Nov 9, 2024
1 parent 322340b commit c487363
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 4 deletions.
2 changes: 1 addition & 1 deletion autointent/context/optimization_info/optimization_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get_inference_nodes_config(self) -> list[InferenceNodeConfig]:
trial = self.trials.get_trial(node_type, idx)
res.append(
InferenceNodeConfig(
node_type=node_type,
node_type=node_type.value,
module_type=trial.module_type,
module_config=trial.module_params,
load_path=trial.module_dump_dir,
Expand Down
2 changes: 1 addition & 1 deletion autointent/pipeline/inference/cli_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main(cfg: InferenceConfig) -> None:
logger.debug("Inference config loaded")

# instantiate pipeline
pipeline = InferencePipeline.from_config(inference_config["nodes_configs"])
pipeline = InferencePipeline.from_dict_config(inference_config["nodes_configs"])

# send data to pipeline
labels: list[LabelType] = pipeline.predict(data)
Expand Down
8 changes: 8 additions & 0 deletions autointent/pipeline/inference/inference_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from autointent.configs.node import InferenceNodeConfig
from autointent.context import Context
from autointent.custom_types import LabelType, NodeType
Expand All @@ -8,6 +10,12 @@ class InferencePipeline:
def __init__(self, nodes: list[InferenceNode]) -> None:
    """Store the given nodes indexed by their node type.

    If two nodes share a node type, the later one wins — same as the
    original dict-comprehension behavior.
    """
    self.nodes = {}
    for node in nodes:
        self.nodes[node.node_type] = node

@classmethod
def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "InferencePipeline":
    """Build a pipeline from raw dict node configs.

    Each dict is first validated into an ``InferenceNodeConfig``; the
    resulting configs are then materialized into ``InferenceNode``s.
    """
    parsed_configs = (InferenceNodeConfig(**raw) for raw in nodes_configs)
    return cls([InferenceNode.from_config(cfg) for cfg in parsed_configs])

@classmethod
def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "InferencePipeline":
nodes = [InferenceNode.from_config(cfg) for cfg in nodes_configs]
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


def setup_environment() -> "tuple[Path, Path, Path]":
    """Create unique per-test directories for logs, vector DB, and module dumps.

    A fresh ``uuid4`` component in the logs path isolates concurrent /
    repeated test runs from one another.

    Fix: the annotation previously claimed ``tuple[str, str]`` while the
    function returns THREE values. NOTE(review): the returned objects are
    importlib.resources path objects (Traversable), not plain ``Path`` —
    annotated as ``Path`` since callers treat them path-like; confirm.
    """
    logs_dir = ires.files("tests").joinpath("logs") / str(uuid4())
    db_dir = logs_dir / "db"
    dump_dir = logs_dir / "modules_dump"
    return db_dir, dump_dir, logs_dir

Expand Down
109 changes: 109 additions & 0 deletions tests/pipeline/test_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import importlib.resources as ires
from pathlib import Path
from typing import Literal

import pytest

from autointent.configs.inference_cli import InferenceConfig
from autointent.configs.optimization_cli import (
EmbedderConfig,
LoggingConfig,
VectorIndexConfig,
)
from autointent.pipeline.inference import InferencePipeline
from autointent.pipeline.inference.cli_endpoint import main as inference_pipeline
from autointent.pipeline.optimization import PipelineOptimizer
from autointent.pipeline.optimization.utils import load_config
from tests.conftest import setup_environment

TaskType = Literal["multiclass", "multilabel", "description"]


def get_search_space_path(task_type: TaskType):
    """Locate the YAML search-space config bundled with the test assets."""
    return ires.files("tests.assets.configs") / f"{task_type}.yaml"


def get_search_space(task_type: TaskType):
    """Load the search-space config for *task_type*, multilabel-aware."""
    is_multilabel = task_type == "multilabel"
    config_path = get_search_space_path(task_type)
    return load_config(str(config_path), multilabel=is_multilabel)


@pytest.mark.parametrize(
    "task_type",
    ["multiclass", "multilabel", "description"],
)
def test_inference_config(dataset, task_type):
    """Optimize a pipeline, rebuild it from exported node configs, check prediction shape.

    Fix: the local variable was named ``inference_pipeline``, shadowing the
    module-level import alias of the CLI entry point
    (``main as inference_pipeline``); renamed to ``pipeline``.
    """
    db_dir, dump_dir, logs_dir = setup_environment()
    search_space = get_search_space(task_type)

    pipeline_optimizer = PipelineOptimizer.from_dict_config(search_space)

    pipeline_optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=True))
    pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device="cpu", save_db=True))
    pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))

    context = pipeline_optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == "multilabel"))
    inference_config = context.optimization_info.get_inference_nodes_config()

    pipeline = InferencePipeline.from_config(inference_config)
    prediction = pipeline.predict(["123", "hello world"])
    # multilabel predictions are (n_samples, n_intents); others are flat (n_samples,)
    if task_type == "multilabel":
        assert prediction.shape == (2, len(dataset.intents))
    else:
        assert prediction.shape == (2,)

    context.dump()


@pytest.mark.parametrize(
    "task_type",
    ["multiclass", "multilabel", "description"],
)
def test_inference_context(dataset, task_type):
    """Optimize a pipeline and build an InferencePipeline directly from the context."""
    db_dir, dump_dir, logs_dir = setup_environment()

    optimizer = PipelineOptimizer.from_dict_config(get_search_space(task_type))
    optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=True))
    optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device="cpu", save_db=True))
    optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))

    context = optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == "multilabel"))
    prediction = InferencePipeline.from_context(context).predict(["123", "hello world"])

    # multilabel predictions are (n_samples, n_intents); others are flat (n_samples,)
    expected_shape = (2, len(dataset.intents)) if task_type == "multilabel" else (2,)
    assert prediction.shape == expected_shape

    context.dump()


@pytest.mark.parametrize(
    "task_type",
    ["multiclass", "multilabel", "description"],
)
def test_inference_pipeline_cli(dataset, task_type):
    """Run optimization, dump the context, then exercise the CLI inference entry point.

    Fix: the vector index was configured with ``device="cuda"`` while every
    sibling test uses ``"cpu"``; on CUDA-less CI machines the test would fail
    for environment reasons rather than real defects.
    """
    db_dir, dump_dir, logs_dir = setup_environment()
    search_space = get_search_space(task_type)

    pipeline_optimizer = PipelineOptimizer.from_dict_config(search_space)

    # keep a handle on the logging config: the CLI config below reuses its paths
    pipeline_optimizer.set_config(
        logging_config := LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_dir=dump_dir, dump_modules=True)
    )
    pipeline_optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device="cpu", save_db=True))
    pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))
    context = pipeline_optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == "multilabel"))

    # persist artifacts so the CLI can load the optimized pipeline from disk
    context.dump()

    config = InferenceConfig(
        data_path=ires.files("tests.assets.data").joinpath("clinc_subset.json"),
        source_dir=logging_config.dirpath,
        output_path=logging_config.dump_dir,
        log_level="CRITICAL",
    )
    inference_pipeline(config)
41 changes: 41 additions & 0 deletions tests/pipeline/test_optimization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib.resources as ires
import os
from pathlib import Path
from typing import Literal

Expand Down Expand Up @@ -47,6 +48,46 @@ def test_no_context_optimization(dataset, task_type):
context.dump()


@pytest.mark.parametrize(
    "task_type",
    ["multiclass", "multilabel", "description"],
)
def test_save_db(dataset, task_type):
    """Optimization with ``save_db=True`` must leave artifacts in the vector-index dir."""
    db_dir, _dump_dir, logs_dir = setup_environment()

    optimizer = PipelineOptimizer.from_dict_config(get_search_space(task_type))
    optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_modules=False))
    optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), save_db=True, device="cpu"))
    optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))

    ctx = optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == "multilabel"))
    ctx.dump()

    # the DB directory must exist and be non-empty after dumping
    assert os.listdir(db_dir)


@pytest.mark.parametrize(
    "task_type",
    ["multiclass", "multilabel", "description"],
)
def test_dump_modules(dataset, task_type):
    """Optimization with ``dump_modules=True`` must write module dumps to ``dump_dir``."""
    db_dir, dump_dir, logs_dir = setup_environment()

    optimizer = PipelineOptimizer.from_dict_config(get_search_space(task_type))
    optimizer.set_config(LoggingConfig(dirpath=Path(logs_dir).resolve(), dump_dir=dump_dir, dump_modules=True))
    optimizer.set_config(VectorIndexConfig(db_dir=Path(db_dir).resolve(), device="cpu"))
    optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32))

    ctx = optimizer.optimize_from_dataset(dataset, force_multilabel=(task_type == "multilabel"))
    ctx.dump()

    # the dump directory must exist and be non-empty after dumping
    assert os.listdir(dump_dir)


@pytest.mark.parametrize(
"task_type",
["multiclass", "multilabel", "description"],
Expand Down

0 comments on commit c487363

Please sign in to comment.