diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py
index 1f3eeeb..f3e15d9 100644
--- a/src/raglite/_chainlit.py
+++ b/src/raglite/_chainlit.py
@@ -1,27 +1,17 @@
 """Chainlit frontend for RAGLite."""
 
+import json
 import os
 from pathlib import Path
 
 import chainlit as cl
 from chainlit.input_widget import Switch, TextInput
 
-from raglite import (
-    RAGLiteConfig,
-    async_rag,
-    create_rag_instruction,
-    hybrid_search,
-    insert_document,
-    rerank_chunks,
-    retrieve_chunk_spans,
-    retrieve_chunks,
-)
+from raglite import RAGLiteConfig, async_rag, hybrid_search, insert_document, rerank_chunks
 from raglite._markdown import document_to_markdown
 
 async_insert_document = cl.make_async(insert_document)
 async_hybrid_search = cl.make_async(hybrid_search)
-async_retrieve_chunks = cl.make_async(retrieve_chunks)
-async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
 async_rerank_chunks = cl.make_async(rerank_chunks)
 
 
@@ -93,31 +83,27 @@ async def handle_message(user_message: cl.Message) -> None:
             for i, attachment in enumerate(inline_attachments)
         )
         + f"\n\n{user_message.content}"
-    )
-    # Search for relevant contexts for RAG.
-    async with cl.Step(name="search", type="retrieval") as step:
-        step.input = user_message.content
-        chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
-        chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
-        step.output = chunks
-        step.elements = [  # Show the top chunks inline.
-            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
-        ]
-        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
-    # Rerank the chunks and group them into chunk spans.
-    async with cl.Step(name="rerank", type="rerank") as step:
-        step.input = chunks
-        chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
-        chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
-        step.output = chunk_spans
-        step.elements = [  # Show the top chunk spans inline.
-            cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
-        ]
-        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
+    ).strip()
     # Stream the LLM response.
     assistant_message = cl.Message(content="")
     messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1]  # type: ignore[no-untyped-call]
-    messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
+    messages.append({"role": "user", "content": user_prompt})
     async for token in async_rag(messages, config=config):
         await assistant_message.stream_token(token)
+    # Append RAG sources if any.
+    if messages[-2]["role"] == "tool":
+        rag_context = json.loads(messages[-2]["content"])
+        rag_sources: dict[str, list[str]] = {}
+        for document in rag_context["documents"]:
+            rag_sources.setdefault(document["source"], [])
+            rag_sources[document["source"]].append(
+                document["span"]["headings"] + "\n" + document["span"]["content"]
+            )
+        assistant_message.content += "\n\nSources: " + ", ".join(  # Rendered as hyperlinks.
+            f"[{i + 1}]" for i in range(len(rag_sources))
+        )
+        assistant_message.elements = [  # Markdown content is rendered in sidebar.
+            cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side")  # type: ignore[misc]
+            for i, (_, content) in enumerate(rag_sources.items())
+        ]
     await assistant_message.update()  # type: ignore[no-untyped-call]
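
Note on the new sources block: it assumes async_rag leaves a tool message in
messages whose JSON content exposes exactly the keys the patch reads
(documents[i].source, documents[i].span.headings, documents[i].span.content).
Below is a minimal, self-contained sketch of that assumed shape and of the
grouping the new code performs; the sample message content is illustrative,
not taken from RAGLite:

    import json

    # Hypothetical tool message, shaped after the keys the patch accesses.
    tool_message = {
        "role": "tool",
        "content": json.dumps({
            "documents": [
                {"source": "a.pdf", "span": {"headings": "# Intro", "content": "First span."}},
                {"source": "a.pdf", "span": {"headings": "# Usage", "content": "Second span."}},
                {"source": "b.md", "span": {"headings": "# FAQ", "content": "Third span."}},
            ]
        }),
    }

    # Group spans by source document, mirroring the patch's rag_sources loop.
    rag_sources: dict[str, list[str]] = {}
    for document in json.loads(tool_message["content"])["documents"]:
        rag_sources.setdefault(document["source"], []).append(
            document["span"]["headings"] + "\n" + document["span"]["content"]
        )

    # Two distinct sources, so the assistant message gets "[1], [2]" links,
    # each backed by a sidebar element joining that source's spans.
    assert list(rag_sources) == ["a.pdf", "b.md"]
    assert len(rag_sources["a.pdf"]) == 2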