Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: get and set prompts mixin #1391

Merged
merged 6 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/ragas/experimental/testset/simulators/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import inspect
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum

from pydantic import BaseModel

from ragas.experimental.prompt import PydanticPrompt as Prompt
from ragas.experimental.testset.graph import KnowledgeGraph, Node
from ragas.llms import BaseRagasLLM, llm_factory

Expand Down Expand Up @@ -54,3 +56,23 @@ async def generate_scenarios(
@abstractmethod
async def generate_sample(self, scenario: Scenario) -> BaseEvalSample:
pass

def get_prompts(self) -> t.Dict[str, Prompt]:
prompts = {}
for name, value in inspect.getmembers(self):
if isinstance(value, Prompt):
prompts.update({name: value})
return prompts

def set_prompts(self, **prompts):
available_prompts = self.get_prompts()
for key, value in prompts.items():
if key not in available_prompts:
raise ValueError(
f"Prompt with name '{key}' does not exist in the simulator {self.name}. Use get_prompts() to see available prompts."
)
if not isinstance(value, Prompt):
raise ValueError(
f"Prompt with name '{key}' must be an instance of 'Prompt'"
)
setattr(self, key, value)
28 changes: 27 additions & 1 deletion src/ragas/experimental/testset/transforms/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import inspect
import logging
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from ragas.experimental.prompt import PydanticPrompt as Prompt
from ragas.experimental.testset.graph import KnowledgeGraph, Node, Relationship
from ragas.llms import BaseRagasLLM, llm_factory

Expand Down Expand Up @@ -90,7 +92,11 @@ class Extractor(BaseGraphTransformation):
Abstract method to extract a specific property from a node.
"""

filter_nodes: t.Callable[[Node], bool] = lambda _: True
@staticmethod
def default_filter(node: Node) -> bool:
return True

filter_nodes: t.Callable[[Node], bool] = default_filter

async def transform(
self, kg: KnowledgeGraph
Expand Down Expand Up @@ -184,6 +190,26 @@ class LLMBasedExtractor(Extractor):
llm: BaseRagasLLM = field(default_factory=llm_factory)
merge_if_possible: bool = True

def get_prompts(self) -> t.Dict[str, Prompt]:
prompts = {}
for name, value in inspect.getmembers(self):
if isinstance(value, Prompt):
prompts.update({name: value})
return prompts

def set_prompts(self, **prompts):
available_prompts = self.get_prompts()
for key, value in prompts.items():
if key not in available_prompts:
raise ValueError(
f"Prompt with name '{key}' does not exist in the extractor. Use get_prompts() to see available prompts."
)
if not isinstance(value, Prompt):
raise ValueError(
f"Prompt with name '{key}' must be an instance of 'Prompt'"
)
setattr(self, key, value)


class Splitter(BaseGraphTransformation):
"""
Expand Down
Loading