[rag evals][1/n] refactor base scoring fn & data schema check #664

Merged: 12 commits, Jan 2, 2025
4 changes: 2 additions & 2 deletions llama_stack/apis/scoring/scoring.py
@@ -47,13 +47,13 @@ class Scoring(Protocol):
    async def score_batch(
        self,
        dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse: ...

    @webmethod(route="/scoring/score")
    async def score(
        self,
        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: Dict[str, Optional[ScoringFnParams]],
    ) -> ScoreResponse: ...
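
Note on this API change: with the "= None" default removed from both score_batch and score, callers must always pass the scoring_functions mapping explicitly; a None value for an individual entry still means "use that scoring function's default parameters". A minimal caller-side sketch, assuming a scoring_impl object that implements this protocol and using basic::equality as an illustrative scoring function id:

# Usage sketch (not part of this PR). The row contents and the scoring function
# id are illustrative; a None value per entry means "use the fn's default params".
from typing import Any, Dict


async def score_one_row(scoring_impl) -> Dict[str, Any]:
    rows = [
        {
            "input_query": "What is 2 + 2?",
            "generated_answer": "4",
            "expected_answer": "4",
        }
    ]
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions={"basic::equality": None},  # explicit dict, no default
    )
    return response.results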
72 changes: 27 additions & 45 deletions llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -3,23 +3,24 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from enum import Enum
from typing import Any, Dict, List, Optional

from tqdm import tqdm

-from llama_stack.apis.agents import Agents
-from llama_stack.apis.common.type_system import (
-    ChatCompletionInputType,
-    CompletionInputType,
-    StringType,
-)
+from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate

+from llama_stack.providers.utils.common.data_schema_validator import (
+    ColumnName,
+    DataSchemaValidatorMixin,
+    get_valid_schemas,
+)
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
@@ -30,15 +31,7 @@
EVAL_TASKS_PREFIX = "eval_tasks:"


-class ColumnName(Enum):
-    input_query = "input_query"
-    expected_answer = "expected_answer"
-    chat_completion_input = "chat_completion_input"
-    completion_input = "completion_input"
-    generated_answer = "generated_answer"
-
-
-class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
+class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
    def __init__(
        self,
        config: MetaReferenceEvalConfig,
@@ -82,29 +75,6 @@ async def register_eval_task(self, task_def: EvalTask) -> None:
        )
        self.eval_tasks[task_def.identifier] = task_def

-    async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
-
-        expected_schemas = [
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
-            },
-            {
-                ColumnName.input_query.value: StringType(),
-                ColumnName.expected_answer.value: StringType(),
-                ColumnName.completion_input.value: CompletionInputType(),
-            },
-        ]
-
-        if dataset_def.dataset_schema not in expected_schemas:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
-            )
-
    async def run_eval(
        self,
        task_id: str,
@@ -114,8 +84,10 @@ async def run_eval(
        dataset_id = task_def.dataset_id
        candidate = task_config.eval_candidate
        scoring_functions = task_def.scoring_functions
-
-        await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
+        )
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=(
@@ -167,11 +139,21 @@ async def _run_agent_generation(
                )
            ]
            final_event = turn_response[-1].event.payload
-            generations.append(
-                {
-                    ColumnName.generated_answer.value: final_event.turn.output_message.content
-                }
-            )
+
+            # check if there's a memory retrieval step and extract the context
+            memory_rag_context = None
+            for step in final_event.turn.steps:
+                if step.step_type == StepType.memory_retrieval.value:
+                    memory_rag_context = " ".join(x.text for x in step.inserted_context)
+
+            agent_generation = {}
+            agent_generation[ColumnName.generated_answer.value] = (
+                final_event.turn.output_message.content
+            )
+            if memory_rag_context:
+                agent_generation[ColumnName.context.value] = memory_rag_context
+
+            generations.append(agent_generation)

        return generations

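The schema check that used to live in each provider now comes from llama_stack.providers.utils.common.data_schema_validator, a new module that is not shown in this view. A rough sketch of its shape, reconstructed from the call sites above and from the removed per-provider validators; the names, signatures, and column sets below are assumptions, not the actual implementation:

# Sketch of the shared validator module (assumed layout, inferred from usage).
from enum import Enum
from typing import Any, Dict, List

from llama_stack.apis.common.type_system import (
    ChatCompletionInputType,
    CompletionInputType,
    StringType,
)
from llama_stack.distribution.datatypes import Api


class ColumnName(Enum):
    input_query = "input_query"
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"
    context = "context"


def get_valid_schemas(api_str: str) -> List[Dict[str, Any]]:
    # Mirrors the removed validators: eval datasets carry the model/agent input,
    # scoring datasets carry the generated answer to be graded.
    if api_str == Api.eval.value:
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.chat_completion_input.value: ChatCompletionInputType(),
            },
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.completion_input.value: CompletionInputType(),
            },
        ]
    if api_str == Api.scoring.value:
        return [
            {
                ColumnName.input_query.value: StringType(),
                ColumnName.expected_answer.value: StringType(),
                ColumnName.generated_answer.value: StringType(),
            },
        ]
    raise ValueError(f"No valid schemas defined for API {api_str}")


class DataSchemaValidatorMixin:
    def validate_dataset_schema(
        self,
        dataset_schema: Dict[str, Any],
        valid_schemas: List[Dict[str, Any]],
    ) -> None:
        if not dataset_schema:
            raise ValueError("Dataset does not have a schema defined.")
        if dataset_schema not in valid_schemas:
            raise ValueError(
                f"Dataset schema {dataset_schema} is not one of {valid_schemas}"
            )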
34 changes: 14 additions & 20 deletions llama_stack/providers/inline/scoring/basic/scoring.py
@@ -14,8 +14,13 @@
    ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate

+from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.utils.common.data_schema_validator import (
+    DataSchemaValidatorMixin,
+    get_valid_schemas,
+)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@@ -24,7 +29,9 @@
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]


-class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+class BasicScoringImpl(
+    Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
+):
    def __init__(
        self,
        config: BasicScoringConfig,
@@ -61,30 +68,17 @@ async def list_scoring_functions(self) -> List[ScoringFn]:
    async def register_scoring_function(self, function_def: ScoringFn) -> None:
        raise NotImplementedError("Register scoring function not implemented yet")

-    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
-            raise ValueError(
-                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
-            )
-
-        for required_column in ["generated_answer", "expected_answer", "input_query"]:
-            if required_column not in dataset_def.dataset_schema:
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column."
-                )
-            if dataset_def.dataset_schema[required_column].type != "string":
-                raise ValueError(
-                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
-                )
-
    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
-        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        self.validate_dataset_schema(
+            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
+        )
+
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=-1,
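Caller-side the flow is unchanged, except that the dataset schema is now validated against get_valid_schemas(Api.scoring.value) before any rows are fetched; per the removed validator, a scoring dataset still needs string-typed input_query, expected_answer, and generated_answer columns. A small usage sketch (dataset id and scoring function ids are illustrative):

# Usage sketch, not part of this PR. Assumes "my_rag_eval_results" was registered
# with string-typed input_query / expected_answer / generated_answer columns.
async def score_dataset(scoring_impl):
    response = await scoring_impl.score_batch(
        dataset_id="my_rag_eval_results",
        scoring_functions={
            "basic::equality": None,  # None -> use the fn's default params
            "basic::subset_of": None,
        },
        save_results_dataset=False,
    )
    # response.results maps scoring function id -> ScoringResult
    return response.results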
@@ -9,12 +9,12 @@
from llama_stack.apis.scoring import ScoringResultRow

from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.equality import equality


-class EqualityScoringFn(BaseScoringFn):
+class EqualityScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
    """
@@ -9,14 +9,14 @@

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.regex_parser_multiple_choice_answer import (
    regex_parser_multiple_choice_answer,
)


-class RegexParserScoringFn(BaseScoringFn):
+class RegexParserScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
    """
@@ -8,12 +8,12 @@

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
+from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.subset_of import subset_of


-class SubsetOfScoringFn(BaseScoringFn):
+class SubsetOfScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
    """
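The three hunks above only swap the base class name; the base_scoring_fn.py change introducing RegisteredBaseScoringFn is not part of this view. Assuming the renamed class keeps the old BaseScoringFn interface, a custom provider-registered scoring function would look roughly like this (the function and its behavior are hypothetical):

# Hypothetical example, not from this PR. The score_row signature mirrors the
# existing EqualityScoringFn; a real fn would likely also register a ScoringFn
# definition in supported_fn_defs_registry.
from typing import Any, Dict, Optional

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn


class ExactPrefixScoringFn(RegisteredBaseScoringFn):
    """
    A scoring_fn that assigns a score of 1.0 if the generated answer starts with
    the expected answer, and 0.0 otherwise.
    """

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        expected = input_row["expected_answer"].strip()
        generated = input_row["generated_answer"].strip()
        return {"score": 1.0 if generated.startswith(expected) else 0.0}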