diff --git a/core/quivr_core/rag/quivr_rag.py b/core/quivr_core/rag/quivr_rag.py
index 38502eb4ddd7..dea9ace4011d 100644
--- a/core/quivr_core/rag/quivr_rag.py
+++ b/core/quivr_core/rag/quivr_rag.py
@@ -12,9 +12,9 @@
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_core.vectorstores import VectorStore
 
+from quivr_core.llm import LLMEndpoint
 from quivr_core.rag.entities.chat import ChatHistory
 from quivr_core.rag.entities.config import RetrievalConfig
-from quivr_core.llm import LLMEndpoint
 from quivr_core.rag.entities.models import (
     ParsedRAGChunkResponse,
     ParsedRAGResponse,
@@ -24,6 +24,7 @@
 )
 from quivr_core.rag.prompts import custom_prompts
 from quivr_core.rag.utils import (
+    LangfuseService,
     combine_documents,
     format_file_list,
     get_chunk_metadata,
@@ -32,6 +33,8 @@
 )
 
 logger = logging.getLogger("quivr_core")
+langfuse_service = LangfuseService()
+langfuse_handler = langfuse_service.get_handler()
 
 
 class IdempotentCompressor(BaseDocumentCompressor):
@@ -173,7 +176,7 @@ def answer(
                 "chat_history": history,
                 "custom_instructions": (self.retrieval_config.prompt),
             },
-            config={"metadata": metadata},
+            config={"metadata": metadata, "callbacks": [langfuse_handler]},
         )
         response = parse_response(
             raw_llm_response, self.retrieval_config.llm_config.model
@@ -206,7 +209,7 @@ async def answer_astream(
                 "chat_history": history,
                 "custom_personality": (self.retrieval_config.prompt),
             },
-            config={"metadata": metadata},
+            config={"metadata": metadata, "callbacks": [langfuse_handler]},
         ):
             # Could receive this anywhere so we need to save it for the last chunk
             if "docs" in chunk:
diff --git a/core/quivr_core/rag/quivr_rag_langgraph.py b/core/quivr_core/rag/quivr_rag_langgraph.py
index 3fd3349bd9a9..66069a4b1076 100644
--- a/core/quivr_core/rag/quivr_rag_langgraph.py
+++ b/core/quivr_core/rag/quivr_rag_langgraph.py
@@ -29,8 +29,6 @@
 from langgraph.types import Send
 from pydantic import BaseModel, Field
 
-from langfuse.callback import CallbackHandler
-
 from quivr_core.llm import LLMEndpoint
 from quivr_core.llm_tools.llm_tools import LLMToolFactory
 from quivr_core.rag.entities.chat import ChatHistory
@@ -41,6 +39,7 @@
 )
 from quivr_core.rag.prompts import custom_prompts
 from quivr_core.rag.utils import (
+    LangfuseService,
     collect_tools,
     combine_documents,
     format_file_list,
@@ -50,8 +49,8 @@
 
 logger = logging.getLogger("quivr_core")
 
-# Initialize Langfuse CallbackHandler for Langchain (tracing)
-langfuse_handler = CallbackHandler()
+langfuse_service = LangfuseService()
+langfuse_handler = langfuse_service.get_handler()
 
 
 class SplittedInput(BaseModel):
@@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
         # Replace each question with its condensed version
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             tasks.set_definition(task_id, response.content)
 
         return {**state, "tasks": tasks}
@@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
         )
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             tasks.set_completion(task_id, response.is_task_completable)
             if not response.is_task_completable and response.tool:
                 tasks.set_tool(task_id, response.tool)
@@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
        )
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             _docs = tool_wrapper.format_output(response)
             _docs = self.filter_chunks_by_relevance(_docs)
             tasks.set_docs(task_id, _docs)
@@ -652,7 +651,7 @@ async def retrieve(self, state: AgentState) -> AgentState:
         task_ids = [task[1] for task in async_jobs] if async_jobs else []
 
         # Process responses and associate docs with tasks
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             _docs = self.filter_chunks_by_relevance(response)
             tasks.set_docs(task_id, _docs)  # Associate docs with the specific task
 
@@ -715,7 +714,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
         _n = []
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             _docs = self.filter_chunks_by_relevance(response)
             _n.append(len(_docs))
             tasks.set_docs(task_id, _docs)
diff --git a/core/quivr_core/rag/utils.py b/core/quivr_core/rag/utils.py
index f9c053555475..0643d25e8735 100644
--- a/core/quivr_core/rag/utils.py
+++ b/core/quivr_core/rag/utils.py
@@ -4,6 +4,7 @@
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.prompts import format_document
+from langfuse.callback import CallbackHandler
 
 from quivr_core.rag.entities.config import WorkflowConfig
 from quivr_core.rag.entities.models import (
@@ -195,3 +196,11 @@ def collect_tools(workflow_config: WorkflowConfig):
         activated_tools += f"Tool {i+1} description: {tool.description}\n\n"
 
     return validated_tools, activated_tools
+
+
+class LangfuseService:
+    def __init__(self):
+        self.langfuse_handler = CallbackHandler()
+
+    def get_handler(self):
+        return self.langfuse_handler
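
The Langfuse changes above centralize tracing: instead of each module constructing its own `langfuse.callback.CallbackHandler`, both `quivr_rag.py` and `quivr_rag_langgraph.py` now obtain a shared handler from the new `LangfuseService` in `quivr_core.rag.utils`, and `answer()`/`answer_astream()` attach it per invocation through the `callbacks` key of the LangChain runnable config. A minimal sketch of that wiring, outside the Quivr codebase; the `prompt | llm` chain and the `ChatOpenAI` model are stand-ins for illustration, not part of this PR:

```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # stand-in; any chat model works here

from quivr_core.rag.utils import LangfuseService

# One handler per process, mirroring the module-level globals in the diff.
langfuse_handler = LangfuseService().get_handler()

chain = ChatPromptTemplate.from_template("Answer briefly: {question}") | ChatOpenAI()

# "callbacks" routes this single invocation through Langfuse tracing;
# "metadata" is forwarded alongside it, as answer()/answer_astream() do above.
result = chain.invoke(
    {"question": "What does Quivr do?"},
    config={"metadata": {"chat_id": "demo"}, "callbacks": [langfuse_handler]},
)
```

Since `CallbackHandler()` picks up its credentials from the `LANGFUSE_*` environment variables, the wrapper also gives a single place to later add configuration or swap in a no-op handler for tests.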
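The remaining changes add `strict=False` to every `zip(responses, task_ids)` loop. `strict` is the keyword argument added to `zip()` in Python 3.10; `strict=False` is the long-standing default (silently truncate to the shortest iterable), so runtime behavior is unchanged and the intent is simply made explicit, the pattern that linters such as Ruff (rule B905) flag. A small illustration of the difference:

```python
responses = ["r1", "r2", "r3"]
task_ids = [101, 102]  # one shorter than responses

# strict=False (the default) truncates silently to the shortest input.
assert list(zip(responses, task_ids, strict=False)) == [("r1", 101), ("r2", 102)]

# strict=True would instead raise once the shorter iterable is exhausted.
try:
    list(zip(responses, task_ids, strict=True))
except ValueError:
    print("lengths differ")
```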