[rag evals][2/n] add more braintrust scoring fns for RAG eval #666

Open · wants to merge 6 commits into base: rag_scoring_fn_1
57 changes: 55 additions & 2 deletions llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -7,7 +7,16 @@
from typing import Any, Dict, List, Optional

from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from autoevals.ragas import (
AnswerCorrectness,
AnswerRelevancy,
AnswerSimilarity,
ContextEntityRecall,
ContextPrecision,
ContextRecall,
ContextRelevancy,
Faithfulness,
)
from pydantic import BaseModel

from llama_stack.apis.datasetio import DatasetIO
@@ -19,7 +28,7 @@
ScoringResult,
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams

from llama_stack.distribution.datatypes import Api

@@ -33,7 +42,14 @@
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def
from .scoring_fn.fn_defs.context_precision import context_precision_fn_def
from .scoring_fn.fn_defs.context_recall import context_recall_fn_def
from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def


class BraintrustScoringFnEntry(BaseModel):
@@ -53,6 +69,41 @@ class BraintrustScoringFnEntry(BaseModel):
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-similarity",
evaluator=AnswerSimilarity(),
fn_def=answer_similarity_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::faithfulness",
evaluator=Faithfulness(),
fn_def=faithfulness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-entity-recall",
evaluator=ContextEntityRecall(),
fn_def=context_entity_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-precision",
evaluator=ContextPrecision(),
fn_def=context_precision_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-recall",
evaluator=ContextRecall(),
fn_def=context_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-relevancy",
evaluator=ContextRelevancy(),
fn_def=context_relevancy_fn_def,
),
]


@@ -143,6 +194,7 @@ async def score_batch(
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
@@ -154,6 +206,7 @@ async def score_row(
generated_answer,
expected_answer,
input=input_query,
context=input_row["context"] if "context" in input_row else None,
)
score = result.score
return {"score": score, "metadata": result.metadata}
26 changes: 26 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)

answer_relevancy_fn_def = ScoringFn(
identifier="braintrust::answer-relevancy",
description=(
"Test output relevancy against the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)
26 changes: 26 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)

answer_similarity_fn_def = ScoringFn(
identifier="braintrust::answer-similarity",
description=(
"Test output similarity against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-similarity",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)
26 changes: 26 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)

context_entity_recall_fn_def = ScoringFn(
identifier="braintrust::context-entity-recall",
description=(
"Evaluates how well the context captures the named entities present in the "
"reference answer. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-entity-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)
26 changes: 26 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)

context_precision_fn_def = ScoringFn(
identifier="braintrust::context-precision",
description=(
"Measures how much of the provided context is actually relevant to answering the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-precision",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)
26 changes: 26 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)

context_recall_fn_def = ScoringFn(
identifier="braintrust::context-recall",
description=(
"Evaluates how well the context covers the information needed to answer the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)
26 changes: 26 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)

context_relevancy_fn_def = ScoringFn(
identifier="braintrust::context-relevancy",
description=(
"Assesses how relevant the provided context is to the given question. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)
26 changes: 26 additions & 0 deletions llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)

faithfulness_fn_def = ScoringFn(
identifier="braintrust::faithfulness",
description=(
"Test output faithfulness to the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="faithfulness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)
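
All eight new fn_defs are identical apart from their identifier, provider_resource_id, and description. Not part of this PR, but a hypothetical factory along these lines could cut the repetition:

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
)


def _braintrust_ragas_fn_def(resource_id: str, description: str) -> ScoringFn:
    # Builds one of the near-identical Braintrust/RAGAS scoring function definitions.
    return ScoringFn(
        identifier=f"braintrust::{resource_id}",
        description=f"{description} See: github.com/braintrustdata/autoevals",
        provider_id="braintrust",
        provider_resource_id=resource_id,
        return_type=NumberType(),
        params=BasicScoringFnParams(
            aggregation_functions=[AggregationFunctionType.average]
        ),
    )


# Example: reproduces the answer-relevancy definition above.
answer_relevancy_fn_def = _braintrust_ragas_fn_def(
    "answer-relevancy",
    "Test output relevancy against the input query using Braintrust LLM scorer.",
)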
17 changes: 15 additions & 2 deletions llama_stack/providers/tests/datasetio/test_datasetio.py
@@ -38,9 +38,15 @@ def data_url_from_file(file_path: str) -> str:


async def register_dataset(
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
datasets_impl: Datasets,
for_generation=False,
for_rag=False,
dataset_id="test_dataset",
):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
if for_rag:
test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
else:
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))

if for_generation:
@@ -49,6 +55,13 @@ async def register_dataset(
"input_query": StringType(),
"chat_completion_input": ChatCompletionInputType(),
}
elif for_rag:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
"context": StringType(),
}
else:
dataset_schema = {
"expected_answer": StringType(),
6 changes: 6 additions & 0 deletions llama_stack/providers/tests/datasetio/test_rag_dataset.csv
@@ -0,0 +1,6 @@
input_query,context,generated_answer,expected_answer
What is the capital of France?,"France is a country in Western Europe with a population of about 67 million people. Its capital city has been a major European cultural center since the 17th century and is known for landmarks like the Eiffel Tower and the Louvre Museum.",London,Paris
Who is the CEO of Meta?,"Meta Platforms, formerly known as Facebook, is one of the world's largest technology companies. Founded by Mark Zuckerberg in 2004, the company has expanded to include platforms like Instagram, WhatsApp, and virtual reality technologies.",Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,"The solar system consists of eight planets orbiting around the Sun. These planets, in order from the Sun, are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune. Gas giants are significantly larger than terrestrial planets.",Jupiter,Jupiter
What is the smallest country in the world?,"Independent city-states and micronations are among the world's smallest sovereign territories. Some notable examples include Monaco, San Marino, and Vatican City, which is an enclave within Rome, Italy.",China,Vatican City
What is the currency of Japan?,"Japan is an island country in East Asia with a rich cultural heritage and one of the world's largest economies. Its financial system has been established since the Meiji period, with its modern currency being introduced in 1871.",Yen,Yen
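
The fixture's columns need to line up with the for_rag dataset_schema registered in test_datasetio.py; a quick illustrative check (not part of this PR):

import csv
from pathlib import Path

# Columns expected by the for_rag schema in test_datasetio.py.
EXPECTED_COLUMNS = {"input_query", "context", "generated_answer", "expected_answer"}

path = Path("llama_stack/providers/tests/datasetio/test_rag_dataset.csv")
with path.open(newline="") as f:
    rows = list(csv.DictReader(f))

assert len(rows) == 5
assert all(set(row.keys()) == EXPECTED_COLUMNS for row in rows)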
6 changes: 3 additions & 3 deletions llama_stack/providers/tests/scoring/test_scoring.py
@@ -60,7 +60,7 @@ async def test_scoring_score(self, scoring_stack):
f"{provider_id} provider does not support scoring without params"
)

await register_dataset(datasets_impl)
await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets()
assert len(response) == 1

@@ -112,7 +112,7 @@ async def test_scoring_score_with_params_llm_as_judge(
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl)
await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets()
assert len(response) == 1

@@ -173,7 +173,7 @@ async def test_scoring_score_with_aggregation_functions(
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl)
await register_dataset(datasets_impl, for_rag=True)
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,