feat: support for callbacks and traces in testset generation (explodi…
jjmachan authored and shahules786 committed Oct 2, 2024
1 parent 8985620 commit f4a22c9
Showing 7 changed files with 157 additions and 37 deletions.
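
At a high level, this commit threads LangChain callbacks through testset generation and wraps each scenario- and sample-generation step in a named trace group. Below is a minimal sketch of how a caller could observe those traces; the `TraceLogger` handler is an illustrative assumption, and only the `callbacks=` parameters come from this commit:

```python
# Sketch: observing the new trace groups with a LangChain callback handler.
# Only the optional `callbacks=` parameters are introduced by this commit;
# the handler below is an illustrative assumption, not ragas code.
from langchain_core.callbacks import BaseCallbackHandler


class TraceLogger(BaseCallbackHandler):
    """Prints every chain group opened/closed during testset generation."""

    def on_chain_start(self, serialized, inputs, **kwargs):
        print("trace start:", list(inputs))

    def on_chain_end(self, outputs, **kwargs):
        print("trace end:", list(outputs))


# Usage (inside an async context, with a simulator and knowledge graph
# constructed elsewhere):
#   scenarios = await simulator.generate_scenarios(
#       n=10, knowledge_graph=kg, callbacks=[TraceLogger()]
#   )
#   sample = await simulator.generate_sample(scenarios[0], callbacks=[TraceLogger()])
```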
3 changes: 2 additions & 1 deletion docs/getstarted/rag_testset_generation.ipynb
@@ -100,7 +100,8 @@
    "source": [
     "from langchain_openai import ChatOpenAI\n",
     "from ragas.llms.base import LangchainLLMWrapper\n",
-    "openai_model = LangchainLLMWrapper(ChatOpenAI(model_name=\"gpt-4o\"))\n"
+    "\n",
+    "openai_model = LangchainLLMWrapper(ChatOpenAI(model_name=\"gpt-4o\"))"
    ]
   },
   {
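
For reference, the cell after this change corresponds to the following runnable snippet (requires the `langchain-openai` package and an `OPENAI_API_KEY` in the environment):

```python
# Wrap a LangChain chat model so ragas can drive testset generation with it.
from langchain_openai import ChatOpenAI
from ragas.llms.base import LangchainLLMWrapper

openai_model = LangchainLLMWrapper(ChatOpenAI(model_name="gpt-4o"))
```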
46 changes: 29 additions & 17 deletions src/ragas/experimental/testset/simulators/abstract_qa.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import math
 import random
@@ -24,6 +26,9 @@
     Themes,
 )

+if t.TYPE_CHECKING:
+    from langchain_core.callbacks import Callbacks
+
 logger = logging.getLogger(__name__)


@@ -41,8 +46,8 @@ def __post_init__(self):
         super().__post_init__()
         self.common_theme_prompt = CommonThemeFromSummaries()

-    async def generate_scenarios(
-        self, n: int, knowledge_graph: KnowledgeGraph
+    async def _generate_scenarios(
+        self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks
     ) -> t.List[AbstractQuestionScenario]:
         node_clusters = knowledge_graph.find_clusters(
             relationship_condition=lambda rel: (
@@ -89,7 +94,7 @@ async def generate_scenarios(
                 summaries=summaries,
                 num_themes=num_themes,
             )
-            kw_list.append({"data": summaries, "llm": self.llm})
+            kw_list.append({"data": summaries, "llm": self.llm, "callbacks": callbacks})

         themes: t.List[Themes] = run_async_batch(
             desc="Generating common themes",
@@ -129,14 +134,14 @@ async def generate_scenarios(
         )
         return distributions

-    async def generate_sample(
-        self, scenario: AbstractQuestionScenario
+    async def _generate_sample(
+        self, scenario: AbstractQuestionScenario, callbacks: Callbacks
     ) -> SingleTurnSample:
-        user_input = await self.generate_user_input(scenario)
-        if await self.critic_question(user_input):
-            user_input = await self.modify_question(user_input, scenario)
+        user_input = await self.generate_user_input(scenario, callbacks)
+        if await self.critic_question(user_input, callbacks):
+            user_input = await self.modify_question(user_input, scenario, callbacks)

-        reference = await self.generate_answer(user_input, scenario)
+        reference = await self.generate_answer(user_input, scenario, callbacks)

         reference_contexts = []
         for node in scenario.nodes:
@@ -149,13 +154,16 @@ async def generate_sample(
             reference_contexts=reference_contexts,
         )

-    async def generate_user_input(self, scenario: AbstractQuestionScenario) -> str:
+    async def generate_user_input(
+        self, scenario: AbstractQuestionScenario, callbacks: Callbacks
+    ) -> str:
         question = await self.generate_user_input_prompt.generate(
             data=ThemeAndContext(
                 theme=scenario.theme,
                 context=self.make_source_text(scenario),
             ),
             llm=self.llm,
+            callbacks=callbacks,
         )
         return question.text

@@ -173,8 +181,8 @@ class ComparativeAbstractQuestionSimulator(QASimulator):
         default_factory=ComparativeAbstractQuestion
     )

-    async def generate_scenarios(
-        self, n: int, knowledge_graph: KnowledgeGraph
+    async def _generate_scenarios(
+        self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks
     ) -> t.List[ComparativeAbstractQuestionScenario]:
         node_clusters = knowledge_graph.find_clusters(
             relationship_condition=lambda rel: (
@@ -220,6 +228,7 @@ async def generate_scenarios(
                         num_concepts=num_concepts,
                     ),
                     "llm": self.llm,
+                    "callbacks": callbacks,
                 }
             )

@@ -264,8 +273,8 @@ async def generate_scenarios(
         )
         return scenarios

-    async def generate_sample(
-        self, scenario: ComparativeAbstractQuestionScenario
+    async def _generate_sample(
+        self, scenario: ComparativeAbstractQuestionScenario, callbacks: Callbacks
     ) -> SingleTurnSample:
         # generate the user input
         keyphrases = []
@@ -285,15 +294,18 @@ async def generate_sample(
                 summaries=summaries,
             ),
             llm=self.llm,
+            callbacks=callbacks,
         )
         question = question.text

         # critic the question
-        if not await self.critic_question(question):
-            question = await self.modify_question(question, scenario)
+        if not await self.critic_question(question, callbacks):
+            question = await self.modify_question(question, scenario, callbacks)

         # generate the answer
-        answer = await self.generate_answer(question, scenario, "summary")
+        answer = await self.generate_answer(
+            question, scenario, callbacks, reference_property_name="summary"
+        )

         # make the reference contexts
         # TODO: make this more efficient. Right now we are taking only the summary
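
The pattern in this file is uniform: every `PydanticPrompt.generate(...)` call now receives the `callbacks` object handed down from the public wrapper, so each LLM call is recorded under its parent trace group. A minimal sketch of that plumbing, with hypothetical names (`my_step`, `prompt`) and assuming only the `generate(data=..., llm=..., callbacks=...)` signature visible in the diff:

```python
# Sketch of the callback plumbing used throughout this file: an optional
# `callbacks` argument, normalized to a list, then forwarded to the prompt.
# `my_step` and its arguments are hypothetical stand-ins, not ragas APIs.
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    from langchain_core.callbacks import Callbacks


async def my_step(prompt, llm, data, callbacks: t.Optional[Callbacks] = None) -> str:
    callbacks = callbacks or []
    # the same callbacks flow into the LLM call, so it is nested
    # under whichever trace group the caller opened
    result = await prompt.generate(data=data, llm=llm, callbacks=callbacks)
    return result.text
```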
45 changes: 42 additions & 3 deletions src/ragas/experimental/testset/simulators/base.py
@@ -7,10 +7,13 @@

 from pydantic import BaseModel

+from ragas.callbacks import new_group
 from ragas.experimental.testset.graph import KnowledgeGraph, Node
 from ragas.llms import BaseRagasLLM, llm_factory

 if t.TYPE_CHECKING:
+    from langchain_core.callbacks import Callbacks
+
     from ragas.dataset_schema import BaseEvalSample


@@ -45,12 +48,48 @@ def __post_init__(self):
         if not self.name:
             self.name = self.__class__.__name__

-    @abstractmethod
     async def generate_scenarios(
-        self, n: int, knowledge_graph: KnowledgeGraph
+        self,
+        n: int,
+        knowledge_graph: KnowledgeGraph,
+        callbacks: t.Optional[Callbacks] = None,
     ) -> t.List[Scenario]:
+        callbacks = callbacks or []
+        scenario_generation_rm, scenario_generation_group = new_group(
+            name=self.name,
+            inputs={"n": n, "knowledge_graph": str(knowledge_graph)},
+            callbacks=callbacks,
+        )
+        scenarios = await self._generate_scenarios(
+            n, knowledge_graph, scenario_generation_group
+        )
+        scenario_generation_rm.on_chain_end(outputs={"scenarios": scenarios})
+        return scenarios
+
+    @abstractmethod
+    async def _generate_scenarios(
+        self, n: int, knowledge_graph: KnowledgeGraph, callbacks: Callbacks
+    ) -> t.List[Scenario]:
         pass

+    async def generate_sample(
+        self, scenario: Scenario, callbacks: t.Optional[Callbacks] = None
+    ) -> BaseEvalSample:
+        callbacks = callbacks or []
+
+        # new group for Sample Generation
+        sample_generation_rm, sample_generation_grp = new_group(
+            name=self.name,
+            inputs={"scenario": scenario},
+            callbacks=callbacks,
+        )
+        sample = await self._generate_sample(scenario, sample_generation_grp)
+        sample_generation_rm.on_chain_end(outputs={"sample": sample})
+
+        return sample
+
     @abstractmethod
-    async def generate_sample(self, scenario: Scenario) -> BaseEvalSample:
+    async def _generate_sample(
+        self, scenario: Scenario, callbacks: Callbacks
+    ) -> BaseEvalSample:
         pass
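
The refactor above is a template-method pattern: the public `generate_scenarios` and `generate_sample` open a trace group with `new_group`, delegate to the underscore-prefixed abstract hooks with the group's callbacks, then close the group with the outputs. A self-contained sketch of the same shape follows; `fake_new_group` only imitates what the diff implies about `ragas.callbacks.new_group` (that it returns a run manager plus child callbacks) and is not its real implementation:

```python
# Self-contained sketch of the template-method tracing pattern in base.py.
# `fake_new_group` is an assumption standing in for ragas.callbacks.new_group.
from abc import ABC, abstractmethod
import typing as t


class _RunManager:
    def on_chain_end(self, outputs: dict) -> None:
        print("group closed:", list(outputs))


def fake_new_group(name: str, inputs: dict, callbacks: list):
    print(f"group opened: {name}", list(inputs))
    return _RunManager(), callbacks  # real new_group derives child callbacks


class TracedGenerator(ABC):
    name = "traced-generator"

    async def generate(self, inputs: dict, callbacks: t.Optional[list] = None):
        callbacks = callbacks or []
        rm, group = fake_new_group(self.name, inputs, callbacks)
        result = await self._generate(inputs, group)  # hook runs inside the group
        rm.on_chain_end(outputs={"result": result})
        return result

    @abstractmethod
    async def _generate(self, inputs: dict, callbacks: list):
        """Subclasses do the actual work; tracing is already wired up."""
```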
22 changes: 19 additions & 3 deletions src/ragas/experimental/testset/simulators/base_qa.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
 from dataclasses import dataclass, field

 from ragas.experimental.prompt import StringIO
@@ -13,6 +16,9 @@
     extend_modify_input_prompt,
 )

+if t.TYPE_CHECKING:
+    from langchain_core.callbacks import Callbacks
+

 @dataclass
 class QASimulator(BaseSimulator[Scenario]):
@@ -22,13 +28,19 @@ class QASimulator(BaseSimulator[Scenario]):
     )
     generate_reference_prompt: PydanticPrompt = field(default_factory=GenerateReference)

-    async def critic_question(self, question: str) -> bool:
+    async def critic_question(
+        self, question: str, callbacks: t.Optional[Callbacks] = None
+    ) -> bool:
+        callbacks = callbacks or []
         critic = await self.critic_user_input_prompt.generate(
-            data=StringIO(text=question), llm=self.llm
+            data=StringIO(text=question), llm=self.llm, callbacks=callbacks
         )
         return critic.independence > 1 and critic.clear_intent > 1

-    async def modify_question(self, question: str, scenario: Scenario) -> str:
+    async def modify_question(
+        self, question: str, scenario: Scenario, callbacks: t.Optional[Callbacks] = None
+    ) -> str:
+        callbacks = callbacks or []
         prompt = extend_modify_input_prompt(
             question_modification_prompt=self.user_input_modification_prompt,
             style=scenario.style,
@@ -41,21 +53,25 @@ async def modify_question(self, question: str, scenario: Scenario) -> str:
                 length=scenario.length,
             ),
             llm=self.llm,
+            callbacks=callbacks,
         )
         return modified_question.text

     async def generate_answer(
         self,
         question: str,
         scenario: Scenario,
+        callbacks: t.Optional[Callbacks] = None,
         reference_property_name: str = "page_content",
     ) -> str:
+        callbacks = callbacks or []
         reference = await self.generate_reference_prompt.generate(
             data=UserInputAndContext(
                 user_input=question,
                 context=self.make_source_text(scenario, reference_property_name),
             ),
             llm=self.llm,
+            callbacks=callbacks,
         )
         return reference.text
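
Because every new `callbacks` parameter defaults to `None` and is normalized with `callbacks or []`, existing call sites keep working unchanged; passing handlers is opt-in. A hedged usage sketch, assuming a `QASimulator` instance, a `scenario`, and a `handlers` list built elsewhere:

```python
# Opt-in usage of the new optional `callbacks` parameters on QASimulator.
# `simulator`, `scenario`, and `handlers` are assumed to come from elsewhere.

async def critique_and_answer(simulator, question: str, scenario, handlers=None):
    # pre-existing call style still works: callbacks is optional everywhere
    if not await simulator.critic_question(question):
        question = await simulator.modify_question(question, scenario)

    # new call style threads LangChain handlers into the generated trace
    return await simulator.generate_answer(
        question, scenario, callbacks=handlers, reference_property_name="summary"
    )
```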
[Diffs for the remaining 3 changed files did not load on this page.]
