Bugfix Summarization Score (explodinggradients#1164)
Hi! 

I was experimenting with the Summarization Score metric and found an issue: when the LLM generates no questions, the metric throws a division-by-zero error because there are zero generated answers.
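
To make the failure mode concrete, here is a minimal sketch. It is a simplified stand-in for `_compute_qa_score` in `src/ragas/metrics/_summarization.py`, not the exact ragas implementation (the exact comparison against `"yes"` is an assumption):

```python
# Simplified stand-in for the QA-score computation: fraction of answers equal
# to "yes". With zero generated answers the denominator is 0.
import typing as t


def compute_qa_score(answers: t.List[str]) -> float:
    correct = sum(1 for a in answers if a.lower() == "yes")
    return correct / len(answers)  # len(answers) == 0 -> ZeroDivisionError


try:
    compute_qa_score([])  # what happens when the LLM generates no questions/answers
except ZeroDivisionError as exc:
    print(f"ZeroDivisionError: {exc}")
```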

## Notable Changes
- Added a guard (with error logging) to the keyphrase, question, and answer generation functions; see the sketch after this list.
- Renamed the LLM response variable from `answer` to `response` to prevent possible confusion with the score-related `answer`.
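
The guard in each of the three generation functions follows the same pattern (the real code lives in `src/ragas/metrics/_summarization.py` and is shown in the diff below). A self-contained toy sketch of the pattern; `QuestionResponse` and `extract_questions` are illustrative names, not part of ragas:

```python
# Toy sketch of the guard pattern: if the parsed LLM response is missing or
# empty, log an error and return an empty list instead of propagating bad data.
import logging
import typing as t
from dataclasses import dataclass, field


@dataclass
class QuestionResponse:  # stand-in for the parsed LLM output
    questions: t.List[str] = field(default_factory=list)


def extract_questions(response: t.Optional[QuestionResponse]) -> t.List[str]:
    if not response or not response.questions:
        logging.error("No questions generated, unable to calculate the score.")
        return []
    return response.questions


print(extract_questions(None))                       # [] (error logged)
print(extract_questions(QuestionResponse(["Q1?"])))  # ['Q1?']
```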

Please let me know if there is something I can add. 

Cheers!

---------

Co-authored-by: jjmachan <[email protected]>
emreds and jjmachan authored Aug 10, 2024
1 parent fdba9a2 commit b7c1bdc
Showing 7 changed files with 68 additions and 51 deletions.
9 changes: 5 additions & 4 deletions docs/howtos/customisations/run_config.ipynb
@@ -53,15 +53,16 @@
"\n",
"# load the dataset\n",
"from datasets import load_dataset\n",
"\n",
"amnesty_qa = load_dataset(\"explodinggradients/amnesty_qa\", \"english_v2\")\n",
"\n",
"# configure RunConfig\n",
"from ragas.run_config import RunConfig\n",
"\n",
"_ = evaluate(\n",
" dataset=amnesty_qa[\"eval\"], \n",
" dataset=amnesty_qa[\"eval\"],\n",
" metrics=[faithfulness],\n",
" run_config=RunConfig(max_workers=64), # increasing max_workers from default 16\n",
" run_config=RunConfig(max_workers=64), # increasing max_workers from default 16\n",
")"
]
},
@@ -94,9 +95,9 @@
],
"source": [
"_ = evaluate(\n",
" dataset=amnesty_qa[\"eval\"], \n",
" dataset=amnesty_qa[\"eval\"],\n",
" metrics=[faithfulness],\n",
" run_config=RunConfig(max_workers=2), # increasing max_workers from default 16\n",
" run_config=RunConfig(max_workers=2), # increasing max_workers from default 16\n",
")"
]
},
27 changes: 18 additions & 9 deletions src/ragas/metrics/_summarization.py
@@ -186,7 +186,7 @@ def _compute_score(self, scores) -> float:
"""Returns average score of the different scores."""
return sum(scores) / len(scores)

def _compute_qa_score(self, answers: t.List) -> float:
def _compute_qa_score(self, answers: t.List[str]) -> float:
"""Returns a score between 0 and 1 reflecting the fraction of
correct answers, ie with a value 'yes'
"""
@@ -209,10 +209,15 @@ async def _extract_keyphrases(self, text: str, callbacks: Callbacks) -> t.List[s
callbacks=callbacks,
)
result_text = result.generations[0][0].text
answer = await _output_parser_keyphrase_extraction.aparse(
response = await _output_parser_keyphrase_extraction.aparse(
result_text, p_value, self.llm, self.max_retries
)
return answer.keyphrases if answer else []

if not response or not response.keyphrases:
logging.error("No keyphrases generated, unable to calculate the score.")
return []

return response.keyphrases

async def _get_questions(
self, text: str, keyphrases: list[str], callbacks: Callbacks
@@ -225,13 +230,15 @@ async def _get_questions(
)

result_text = result.generations[0][0].text
answer = await _output_parser_question_generation.aparse(
response = await _output_parser_question_generation.aparse(
result_text, p_value, self.llm, self.max_retries
)
if answer is None:

if not response or not response.questions:
logging.error("No questions generated, unable to calculate the score.")
return []

return answer.questions
return response.questions

async def _get_answers(
self, questions: t.List[str], summary: str, callbacks: Callbacks
@@ -244,13 +251,15 @@ async def _get_answers(
)

result_text = result.generations[0][0].text
answer = await _output_parser_answer_generation.aparse(
response = await _output_parser_answer_generation.aparse(
result_text, p_value, self.llm, self.max_retries
)
if answer is None:

if not response or not response.answers:
logger.error("No answers generated, unable to calculate the score.")
return []

return answer.answers
return response.answers


def adapt(self, language: str, cache_dir: str | None = None) -> None:
9 changes: 6 additions & 3 deletions src/ragas/metrics/base.py
@@ -63,11 +63,13 @@ def get_required_columns(
class Metric(ABC):
@property
@abstractmethod
def name(self) -> str: ...
def name(self) -> str:
...

@property
@abstractmethod
def evaluation_mode(self) -> EvaluationMode: ...
def evaluation_mode(self) -> EvaluationMode:
...

@abstractmethod
def init(self, run_config: RunConfig):
@@ -130,7 +132,8 @@ async def ascore(
return score

@abstractmethod
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: ...
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
...


@dataclass
4 changes: 2 additions & 2 deletions src/ragas/testset/prompts.py
@@ -509,5 +509,5 @@ class EvolutionElimination(BaseModel):
question_rewrite_prompt,
context_scoring_prompt,
filter_question_prompt,
evolution_elimination_prompt
]
evolution_elimination_prompt,
]
4 changes: 3 additions & 1 deletion tests/unit/test_analytics.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import typing as t

import pytest


@@ -130,7 +132,7 @@ def test_testset_generation_tracking(monkeypatch):


def test_was_completed(monkeypatch):
from ragas._analytics import track_was_completed, IsCompleteEvent
from ragas._analytics import IsCompleteEvent, track_was_completed

event_properties_list: t.List[IsCompleteEvent] = []

6 changes: 4 additions & 2 deletions tests/unit/test_executor_in_jupyter.ipynb
@@ -33,7 +33,7 @@
"\n",
"exec = Executor(raise_exceptions=True)\n",
"for i in range(10):\n",
" exec.submit(sleep, i/10)\n",
" exec.submit(sleep, i / 10)\n",
"\n",
"assert exec.results(), \"didn't get anything from results\""
]
@@ -140,16 +140,18 @@
"source": [
"from ragas.metrics.base import Metric, EvaluationMode\n",
"\n",
"\n",
"class FakeMetric(Metric):\n",
" name = \"fake_metric\"\n",
" evaluation_mode = EvaluationMode.qa\n",
"\n",
" def init(self):\n",
" pass\n",
"\n",
" async def _ascore(self, row, callbacks)->float:\n",
" async def _ascore(self, row, callbacks) -> float:\n",
" return 0\n",
"\n",
"\n",
"fm = FakeMetric()"
]
},
60 changes: 30 additions & 30 deletions tests/unit/test_run_config.py
@@ -1,36 +1,42 @@
import sys, importlib
from packaging.version import parse as parse_version
import importlib
import sys
from platform import python_version

import pytest
from numpy.random import default_rng, Generator
from numpy.random import Generator, default_rng
from packaging.version import parse as parse_version

from ragas.run_config import RunConfig

if parse_version(python_version()) < parse_version("3.10"):
from typing import NewType, Callable
RandomComparison = NewType("RandomComparison", Callable[[Generator, Generator], bool])
from typing import Callable, NewType

RandomComparison = NewType(
"RandomComparison", Callable[[Generator, Generator], bool]
)
elif parse_version(python_version()) >= parse_version("3.10"):
from typing import TypeAlias, Callable
from typing import Callable, TypeAlias

RandomComparison: TypeAlias = Callable[[Generator, Generator], bool]


@pytest.fixture(scope="function")
def compare_rng() -> Callable[[Generator, Generator], bool]:
"""Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence.
"""Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence."""

"""
def _compare_rng(rng_0:Generator, rng_1:Generator) -> bool:
def _compare_rng(rng_0: Generator, rng_1: Generator) -> bool:
"""Compare two :py:cls:`numpy.random.Generator`object.
Args:
rng_0 (numpy.random.Generator) : The first generator to compare with.
rng_1 (numpy.random.Generator) : The second generator to compare with.
Returns:
bool: Whether the two generators are at the same state.
"""
return rng_0.random() == rng_1.random()

return _compare_rng


@@ -39,9 +45,11 @@ def _compare_rng(rng_0:Generator, rng_1:Generator) -> bool:
(
[42, True],
[None, False],
)
),
)
def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equivalence):
def test_random_num_generator(
seed, compare_rng: RandomComparison, expected_equivalence
):
"""Check :py:mod:`numpy.random` functionality and seed behaviour control."""
rc = RunConfig(seed=seed)

@@ -53,7 +61,7 @@ def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equiv
assert compare_rng(rc.rng, rng) == expected_equivalence

# Check generation consistency
importlib.reload(sys.modules['numpy.random'])
importlib.reload(sys.modules["numpy.random"])
new_rc = RunConfig(seed=seed)
new_rng = default_rng(seed=seed)

@@ -63,22 +71,14 @@ def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equiv

# Check equivalence
if expected_equivalence:
assert all(
list(
map(
compare_rng,
[rc.rng, new_rc.rng],
[new_rng, rng]
)
)
)
assert all(list(map(compare_rng, [rc.rng, new_rc.rng], [new_rng, rng])))
else:
assert all(
list(
map(
lambda x, y:not compare_rng(x, y),
[rc.rng, new_rc.rng],
[new_rng, rng]
)
list(
map(
lambda x, y: not compare_rng(x, y),
[rc.rng, new_rc.rng],
[new_rng, rng],
)
)
)
