feat: add tool use to Chainlit
lsorber committed Dec 9, 2024
1 parent aa5cc80 commit 93a75b9
Showing 1 changed file with 20 additions and 34 deletions.
54 changes: 20 additions & 34 deletions src/raglite/_chainlit.py
```diff
@@ -1,27 +1,17 @@
 """Chainlit frontend for RAGLite."""
 
+import json
 import os
 from pathlib import Path
 
 import chainlit as cl
 from chainlit.input_widget import Switch, TextInput
 
-from raglite import (
-    RAGLiteConfig,
-    async_rag,
-    create_rag_instruction,
-    hybrid_search,
-    insert_document,
-    rerank_chunks,
-    retrieve_chunk_spans,
-    retrieve_chunks,
-)
+from raglite import RAGLiteConfig, async_rag, hybrid_search, insert_document, rerank_chunks
 from raglite._markdown import document_to_markdown
 
 async_insert_document = cl.make_async(insert_document)
 async_hybrid_search = cl.make_async(hybrid_search)
-async_retrieve_chunks = cl.make_async(retrieve_chunks)
-async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
 async_rerank_chunks = cl.make_async(rerank_chunks)
```
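For context on the `cl.make_async` wrappers kept above: Chainlit's `make_async` turns a blocking function into an awaitable that runs in a worker thread, so long-running calls such as `insert_document` don't stall the event loop. A minimal sketch, where `slow_lookup` is a made-up stand-in for illustration:

```python
import time

import chainlit as cl


def slow_lookup(query: str) -> list[str]:
    """Blocking stand-in for work like insert_document or hybrid_search."""
    time.sleep(1.0)  # Simulate blocking I/O.
    return [f"result for {query!r}"]


# Wrap once at module load, then await the wrapper inside async Chainlit handlers.
async_slow_lookup = cl.make_async(slow_lookup)


@cl.on_message
async def handle(message: cl.Message) -> None:
    results = await async_slow_lookup(message.content)  # Runs off the event loop.
    await cl.Message(content="\n".join(results)).send()
```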


```diff
@@ -93,31 +83,27 @@ async def handle_message(user_message: cl.Message) -> None:
             for i, attachment in enumerate(inline_attachments)
         )
         + f"\n\n{user_message.content}"
-    )
-    # Search for relevant contexts for RAG.
-    async with cl.Step(name="search", type="retrieval") as step:
-        step.input = user_message.content
-        chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
-        chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
-        step.output = chunks
-        step.elements = [  # Show the top chunks inline.
-            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
-        ]
-        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
-    # Rerank the chunks and group them into chunk spans.
-    async with cl.Step(name="rerank", type="rerank") as step:
-        step.input = chunks
-        chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
-        chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
-        step.output = chunk_spans
-        step.elements = [  # Show the top chunk spans inline.
-            cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
-        ]
-        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
+    ).strip()
     # Stream the LLM response.
     assistant_message = cl.Message(content="")
     messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1]  # type: ignore[no-untyped-call]
-    messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
+    messages.append({"role": "user", "content": user_prompt})
     async for token in async_rag(messages, config=config):
         await assistant_message.stream_token(token)
+    # Append RAG sources if any.
+    if messages[-2]["role"] == "tool":
+        rag_context = json.loads(messages[-2]["content"])
+        rag_sources: dict[str, list[str]] = {}
+        for document in rag_context["documents"]:
+            rag_sources.setdefault(document["source"], [])
+            rag_sources[document["source"]].append(
+                document["span"]["headings"] + "\n" + document["span"]["content"]
+            )
+        assistant_message.content += "\n\nSources: " + ", ".join(  # Rendered as hyperlinks.
+            f"[{i + 1}]" for i in range(len(rag_sources))
+        )
+        assistant_message.elements = [  # Markdown content is rendered in sidebar.
+            cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side")  # type: ignore[misc]
+            for i, (_, content) in enumerate(rag_sources.items())
+        ]
     await assistant_message.update()  # type: ignore[no-untyped-call]
```
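To make the new sources block concrete: the handler assumes that when the LLM used the retrieval tool during `async_rag`, `messages[-2]` is a `tool` message whose content is JSON. The field names (`documents`, `source`, `span`, `headings`, `content`) come from the diff above; the sample payload below is invented to show how spans get grouped per source document:

```python
import json

# Invented sample payload, shaped like the tool message the handler parses.
tool_message = {
    "role": "tool",
    "content": json.dumps({
        "documents": [
            {"source": "manual.pdf", "span": {"headings": "# Setup", "content": "Install RAGLite."}},
            {"source": "manual.pdf", "span": {"headings": "# Usage", "content": "Start the frontend."}},
            {"source": "faq.md", "span": {"headings": "# FAQ", "content": "See the docs."}},
        ]
    }),
}

# Same grouping as the handler: one entry per source document, each holding
# that document's retrieved spans as "headings\ncontent" strings.
rag_context = json.loads(tool_message["content"])
rag_sources: dict[str, list[str]] = {}
for document in rag_context["documents"]:
    rag_sources.setdefault(document["source"], []).append(
        document["span"]["headings"] + "\n" + document["span"]["content"]
    )

print(list(rag_sources))               # ['manual.pdf', 'faq.md'] -> cited as "[1], [2]"
print(len(rag_sources["manual.pdf"]))  # 2 spans, joined with "\n\n---\n\n" in the sidebar
```

Naming each `cl.Text` element `[1]`, `[2]`, … to match the bracketed tokens appended to the message content is what lets Chainlit render those tokens as clickable links that open the corresponding `display="side"` element in the sidebar.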
