[rag evals][1/n] refactor base scoring fn & data schema check #664

Merged · 12 commits · Jan 2, 2025
4 changes: 2 additions & 2 deletions llama_stack/apis/scoring/scoring.py
@@ -47,13 +47,13 @@ class Scoring(Protocol):
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...

@webmethod(route="/scoring/score")
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ...
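
The two signature changes above drop the `= None` default, so callers must now pass the `scoring_functions` mapping explicitly; a `None` value for an individual entry still means "use the parameters registered with that scoring function". A minimal caller sketch under the new signature — the rows, the `basic::equality` identifier, and the `scoring_impl` object are illustrative assumptions, not taken from this PR:

```python
from typing import Any, Dict, List, Optional

from llama_stack.apis.scoring_functions import ScoringFnParams

# Illustrative input rows; column names follow the schemas validated elsewhere in this PR.
input_rows: List[Dict[str, Any]] = [
    {
        "input_query": "What is the capital of France?",
        "generated_answer": "Paris",
        "expected_answer": "Paris",
    }
]

# None -> use the scoring function's registered params; pass ScoringFnParams to override.
scoring_functions: Dict[str, Optional[ScoringFnParams]] = {
    "basic::equality": None,  # assumed identifier, for illustration only
}

# Assuming `scoring_impl` is an object implementing the Scoring protocol:
# response = await scoring_impl.score(
#     input_rows=input_rows,
#     scoring_functions=scoring_functions,
# )
```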
52 changes: 12 additions & 40 deletions llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -3,23 +3,24 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from llama_stack.apis.agents import Agents
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
CompletionInputType,
StringType,
)
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate

from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
DataSchemaValidatorMixin,
get_valid_schemas,
)
from llama_stack.providers.utils.kvstore import kvstore_impl

from .....apis.common.job_types import Job
@@ -30,15 +31,7 @@
EVAL_TASKS_PREFIX = "eval_tasks:"


class ColumnName(Enum):
input_query = "input_query"
expected_answer = "expected_answer"
chat_completion_input = "chat_completion_input"
completion_input = "completion_input"
generated_answer = "generated_answer"


class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
def __init__(
self,
config: MetaReferenceEvalConfig,
@@ -82,29 +75,6 @@ async def register_eval_task(self, task_def: EvalTask) -> None:
)
self.eval_tasks[task_def.identifier] = task_def

async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

expected_schemas = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
]

if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)

async def run_eval(
self,
task_id: str,
@@ -114,8 +84,10 @@ async def run_eval(
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions

await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(
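
The inline `validate_eval_input_dataset_schema` helper removed above is replaced by the shared `DataSchemaValidatorMixin.validate_dataset_schema` plus `get_valid_schemas`, which this diff only imports from `llama_stack.providers.utils.common.data_schema_validator`. That module is not part of this excerpt; the sketch below is an assumption about its shape, inferred from the call sites and from the checks deleted in the eval and scoring providers:

```python
# Plausible sketch of the shared validator; the real module at
# llama_stack/providers/utils/common/data_schema_validator.py may differ in detail.
from enum import Enum
from typing import Any, Dict, List

from llama_stack.apis.common.type_system import (
    ChatCompletionInputType,
    CompletionInputType,
    StringType,
)
from llama_stack.distribution.datatypes import Api


class ColumnName(Enum):
    input_query = "input_query"
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"


def get_valid_schemas(api_str: str) -> List[Dict[str, Any]]:
    # Scoring expects a generated_answer column; eval expects one of the two
    # generation-input columns. This mirrors the checks removed from the providers.
    if api_str == Api.scoring.value:
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.generated_answer.value: StringType(),
            },
        ]
    elif api_str == Api.eval.value:
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
            },
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.completion_input.value: CompletionInputType(),
            },
        ]
    raise ValueError(f"Invalid API string: {api_str}")


class DataSchemaValidatorMixin:
    def validate_dataset_schema(
        self,
        dataset_schema: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ) -> None:
        if not dataset_schema:
            raise ValueError("Dataset does not have a schema defined.")
        if dataset_schema not in expected_schemas:
            raise ValueError(
                f"Dataset schema {dataset_schema} is not one of the expected schemas {expected_schemas}"
            )
```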
34 changes: 14 additions & 20 deletions llama_stack/providers/inline/scoring/basic/scoring.py
@@ -14,8 +14,13 @@
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@@ -24,7 +29,9 @@
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]


class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BasicScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__(
self,
config: BasicScoringConfig,
@@ -61,30 +68,17 @@ async def list_scoring_functions(self) -> List[ScoringFn]:
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")

async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)

for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)

async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)

all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
@@ -9,12 +9,12 @@
from llama_stack.apis.scoring import ScoringResultRow

from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.equality import equality


class EqualityScoringFn(BaseScoringFn):
class EqualityScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""
llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
@@ -9,14 +9,14 @@

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
)


class RegexParserScoringFn(BaseScoringFn):
class RegexParserScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
"""
llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
@@ -8,12 +8,12 @@

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.subset_of import subset_of


class SubsetOfScoringFn(BaseScoringFn):
class SubsetOfScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""
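In the three files above, the built-in equality, regex-parser, and subset-of scorers now subclass `RegisteredBaseScoringFn` instead of `BaseScoringFn`. The refactored `llama_stack/providers/utils/scoring/base_scoring_fn.py` is not included in this excerpt; a plausible reading of the rename, consistent with the PR title, is that `BaseScoringFn` shrinks to a bare row-scoring interface while `RegisteredBaseScoringFn` keeps the fn-def registry the inline providers rely on. A hedged sketch:

```python
# Sketch only: assumed shape of the refactored base classes; the real
# base_scoring_fn.py may use different method names or extra hooks.
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams


class BaseScoringFn(ABC):
    """Minimal interface: score one input row, no registration machinery."""

    @abstractmethod
    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        raise NotImplementedError()


class RegisteredBaseScoringFn(BaseScoringFn):
    """Adds a registry of ScoringFn definitions for providers that ship a
    fixed set of scoring functions (equality, subset_of, regex_parser, ...)."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        self.supported_fn_defs_registry: Dict[str, ScoringFn] = {}

    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
        return list(self.supported_fn_defs_registry.values())

    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
        if scoring_fn.identifier in self.supported_fn_defs_registry:
            raise ValueError(f"Scoring fn def {scoring_fn.identifier} already registered")
        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
```
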
94 changes: 63 additions & 31 deletions llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -8,6 +8,7 @@

from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from pydantic import BaseModel

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@@ -18,20 +19,48 @@
ScoringResult,
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFn
from llama_stack.apis.scoring_functions import ScoringFn

from llama_stack.distribution.datatypes import Api

from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)

from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average

from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def


class BraintrustScoringFnEntry(BaseModel):
identifier: str
evaluator: Any
fn_def: ScoringFn


SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
BraintrustScoringFnEntry(
identifier="braintrust::factuality",
evaluator=Factuality(),
fn_def=factuality_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-correctness",
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
]


class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
DataSchemaValidatorMixin,
):
def __init__(
self,
@@ -44,12 +73,12 @@ def __init__(
self.datasets_api = datasets_api

self.braintrust_evaluators = {
"braintrust::factuality": Factuality(),
"braintrust::answer-correctness": AnswerCorrectness(),
entry.identifier: entry.evaluator
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
self.supported_fn_defs_registry = {
factuality_fn_def.identifier: factuality_fn_def,
answer_correctness_fn_def.identifier: answer_correctness_fn_def,
entry.identifier: entry.fn_def
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}

async def initialize(self) -> None: ...
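
Deriving both `braintrust_evaluators` and `supported_fn_defs_registry` from the single `SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY` list keeps the two maps from drifting apart, and wiring up another autoevals scorer becomes a one-entry change. The entry below is purely illustrative: the `AnswerRelevancy` evaluator and the `answer_relevancy_fn_def` module are assumptions, not part of this PR.

```python
# Hypothetical extra entry, written as it might appear in braintrust.py.
# AnswerRelevancy and answer_relevancy_fn_def are assumed, not taken from this diff.
from autoevals.ragas import AnswerRelevancy  # assumed to exist in autoevals

from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def  # assumed module

SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
    # ... existing factuality and answer-correctness entries ...
    BraintrustScoringFnEntry(
        identifier="braintrust::answer-relevancy",
        evaluator=AnswerRelevancy(),
        fn_def=answer_relevancy_fn_def,
    ),
]
```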
@@ -70,23 +99,6 @@ async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
"Registering scoring function not allowed for braintrust provider"
)

async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)

for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)

async def set_api_key(self) -> None:
# api key is in the request headers
if not self.config.openai_api_key:
@@ -102,11 +114,16 @@ async def set_api_key(self) -> None:
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)

dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)

all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
@@ -133,12 +150,18 @@ async def score_row(
input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier]

result = evaluator(generated_answer, expected_answer, input=input_query)
result = evaluator(
generated_answer,
expected_answer,
input=input_query,
)
score = result.score
return {"score": score, "metadata": result.metadata}

async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse:
await self.set_api_key()
res = {}
@@ -150,8 +173,17 @@ async def score(
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = [AggregationFunctionType.average]
agg_results = aggregate_average(score_results)
aggregation_functions = self.supported_fn_defs_registry[
scoring_fn_id
].params.aggregation_functions

# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:
override_params = scoring_functions[scoring_fn_id]
if override_params.aggregation_functions:
aggregation_functions = override_params.aggregation_functions

agg_results = aggregate_metrics(score_results, aggregation_functions)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,
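
The `score` path above replaces the hard-coded `aggregate_average` call with `aggregate_metrics`, driven by the aggregation functions registered on each scoring fn and optionally overridden per request through `ScoringFnParams.aggregation_functions`. `aggregate_metrics` is only imported in this diff; the sketch below shows one way such a helper could dispatch, assuming `average` and `accuracy` members on `AggregationFunctionType` (only `average` actually appears in this excerpt):

```python
# Sketch of an aggregate_metrics-style helper; the real implementation in
# llama_stack/providers/utils/scoring/aggregation_utils.py may differ.
from typing import Any, Dict, List

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType


def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    scores = [r["score"] for r in scoring_results if r.get("score") is not None]
    return {"average": sum(scores) / len(scores) if scores else 0.0}


def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    # Treats each row's score as 0/1 correctness; assumed semantics.
    num_correct = sum(r["score"] for r in scoring_results)
    num_total = len(scoring_results)
    return {
        "accuracy": num_correct / num_total if num_total else 0.0,
        "num_correct": num_correct,
        "num_total": num_total,
    }


def aggregate_metrics(
    scoring_results: List[ScoringResultRow],
    metrics: List[AggregationFunctionType],
) -> Dict[str, Any]:
    # AggregationFunctionType.average appears in the code removed above;
    # the accuracy member is assumed for illustration.
    handlers = {
        AggregationFunctionType.average: aggregate_average,
        AggregationFunctionType.accuracy: aggregate_accuracy,
    }
    results: Dict[str, Any] = {}
    for metric in metrics:
        if metric not in handlers:
            raise NotImplementedError(f"Aggregation function {metric} not supported")
        results.update(handlers[metric](scoring_results))
    return results
```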