diff --git a/src/ragas/experimental/prompt.py b/src/ragas/experimental/prompt.py index 3e2b7bfcd..6c3ddd917 100644 --- a/src/ragas/experimental/prompt.py +++ b/src/ragas/experimental/prompt.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import typing as t from abc import ABC, abstractmethod @@ -188,3 +189,25 @@ async def generate( callbacks=callbacks, ) return llm_result.generations[0][0].text + + +class PromptMixin: + def get_prompts(self) -> t.Dict[str, PydanticPrompt]: + prompts = {} + for name, value in inspect.getmembers(self): + if isinstance(value, PydanticPrompt): + prompts.update({name: value}) + return prompts + + def set_prompts(self, **prompts): + available_prompts = self.get_prompts() + for key, value in prompts.items(): + if key not in available_prompts: + raise ValueError( + f"Prompt with name '{key}' does not exist. Use get_prompts() to see available prompts." + ) + if not isinstance(value, PydanticPrompt): + raise ValueError( + f"Prompt with name '{key}' must be an instance of 'ragas.prompt.PydanticPrompt'" + ) + setattr(self, key, value) diff --git a/src/ragas/experimental/testset/simulators/base.py b/src/ragas/experimental/testset/simulators/base.py index dd35383eb..12b86a4f3 100644 --- a/src/ragas/experimental/testset/simulators/base.py +++ b/src/ragas/experimental/testset/simulators/base.py @@ -7,6 +7,7 @@ from pydantic import BaseModel +from ragas.experimental.prompt import PromptMixin from ragas.callbacks import new_group from ragas.experimental.testset.graph import KnowledgeGraph, Node from ragas.llms import BaseRagasLLM, llm_factory @@ -40,7 +41,7 @@ class BaseScenario(BaseModel): @dataclass -class BaseSimulator(ABC, t.Generic[Scenario]): +class BaseSimulator(ABC, t.Generic[Scenario], PromptMixin): name: str = "" llm: BaseRagasLLM = field(default_factory=llm_factory) diff --git a/src/ragas/experimental/testset/transforms/base.py b/src/ragas/experimental/testset/transforms/base.py index 67e1b08b8..25ce7384a 100644 --- a/src/ragas/experimental/testset/transforms/base.py +++ b/src/ragas/experimental/testset/transforms/base.py @@ -3,12 +3,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field +from ragas.experimental.prompt import PromptMixin from ragas.experimental.testset.graph import KnowledgeGraph, Node, Relationship from ragas.llms import BaseRagasLLM, llm_factory logger = logging.getLogger(__name__) +def default_filter(node: Node) -> bool: + return True + + @dataclass class BaseGraphTransformation(ABC): """ @@ -90,7 +95,9 @@ class Extractor(BaseGraphTransformation): Abstract method to extract a specific property from a node. """ - filter_nodes: t.Callable[[Node], bool] = lambda _: True + filter_nodes: t.Callable[[Node], bool] = field( + default_factory=lambda: default_filter + ) async def transform( self, kg: KnowledgeGraph @@ -180,7 +187,7 @@ def filter(self, kg: KnowledgeGraph) -> KnowledgeGraph: @dataclass -class LLMBasedExtractor(Extractor): +class LLMBasedExtractor(Extractor, PromptMixin): llm: BaseRagasLLM = field(default_factory=llm_factory) merge_if_possible: bool = True diff --git a/src/ragas/metrics/base.py b/src/ragas/metrics/base.py index a1ac702f2..dedbe612a 100644 --- a/src/ragas/metrics/base.py +++ b/src/ragas/metrics/base.py @@ -27,12 +27,10 @@ from ragas.embeddings import BaseRagasEmbeddings from ragas.llms import BaseRagasLLM -import inspect - from pysbd import Segmenter from pysbd.languages import LANGUAGE_CODES -from ragas.experimental.prompt import PydanticPrompt as Prompt +from ragas.experimental.prompt import PromptMixin logger = logging.getLogger(__name__) @@ -60,7 +58,8 @@ class Metric(ABC): @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property def required_columns(self) -> t.Dict[str, t.Set[str]]: @@ -147,11 +146,12 @@ async def ascore( return score @abstractmethod - async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: ... + async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: + ... @dataclass -class MetricWithLLM(Metric): +class MetricWithLLM(Metric, PromptMixin): llm: t.Optional[BaseRagasLLM] = None def init(self, run_config: RunConfig): @@ -166,26 +166,6 @@ def init(self, run_config: RunConfig): ) self.llm.set_run_config(run_config) - def get_prompts(self) -> t.Dict[str, Prompt]: - prompts = {} - for name, value in inspect.getmembers(self): - if isinstance(value, Prompt): - prompts.update({name: value}) - return prompts - - def set_prompts(self, **prompts): - available_prompts = self.get_prompts() - for key, value in prompts.items(): - if key not in available_prompts: - raise ValueError( - f"Prompt with name '{key}' does not exist in the metric {self.name}. Use get_prompts() to see available prompts." - ) - if not isinstance(value, Prompt): - raise ValueError( - f"Prompt with name '{key}' must be an instance of 'Prompt'" - ) - setattr(self, key, value) - @dataclass class MetricWithEmbeddings(Metric): @@ -254,7 +234,8 @@ async def _single_turn_ascore( self, sample: SingleTurnSample, callbacks: Callbacks, - ) -> float: ... + ) -> float: + ... class MultiTurnMetric(Metric): @@ -306,7 +287,8 @@ async def _multi_turn_ascore( self, sample: MultiTurnSample, callbacks: Callbacks, - ) -> float: ... + ) -> float: + ... class Ensember: