Skip to content

Commit

Permalink
feat: Use any embedding model with ragas
Browse files Browse the repository at this point in the history
  • Loading branch information
undo76 committed Dec 4, 2024
1 parent 0fd1970 commit c9317e8
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/raglite/_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,33 @@ def evaluate(
try:
from datasets import Dataset
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import LlamaCppEmbeddings
from langchain_community.llms import LlamaCpp
from ragas import RunConfig
from ragas import evaluate as ragas_evaluate
from ragas.embeddings import BaseRagasEmbeddings

from raglite._config import RAGLiteConfig
from raglite._embed import embed_sentences
from raglite._litellm import LlamaCppPythonLLM

except ImportError as import_error:
error_message = "To use the `evaluate` function, please install the `ragas` extra."
raise ImportError(error_message) from import_error

class RAGLiteRagasEmbeddings(BaseRagasEmbeddings):
    """Adapter exposing RAGLite's embedding function through the Ragas embeddings interface."""

    def __init__(self, config: RAGLiteConfig | None = None):
        """Initialise the adapter.

        Parameters
        ----------
        config
            The RAGLite config specifying the embedder; defaults to
            ``RAGLiteConfig()`` when omitted or ``None``.
        """
        # NOTE(review): BaseRagasEmbeddings defines its own __init__ (it sets up
        # run_config in ragas) — call it so the adapter is fully initialised;
        # confirm against the installed ragas version.
        super().__init__()
        self.config = config or RAGLiteConfig()

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string using RAGLite's embedding function."""
        embeddings = embed_sentences([text], config=self.config)
        return embeddings[0].tolist()  # type: ignore[no-any-return]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of documents using RAGLite's embedding function."""
        embeddings = embed_sentences(texts, config=self.config)
        return embeddings.tolist()  # type: ignore[no-any-return]

# Create a set of answered evals if not provided.
config = config or RAGLiteConfig()
answered_evals_df = (
Expand All @@ -239,23 +256,12 @@ def evaluate(
)
else:
lc_llm = ChatLiteLLM(model=config.llm) # type: ignore[call-arg]
# Load the embedder.
if not config.embedder.startswith("llama-cpp-python"):
error_message = "Currently, only `llama-cpp-python` embedders are supported."
raise NotImplementedError(error_message)
embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True)
lc_embedder = LlamaCppEmbeddings( # type: ignore[call-arg]
model_path=embedder.model_path,
n_batch=embedder.n_batch,
n_ctx=embedder.n_ctx(),
n_gpu_layers=-1,
verbose=embedder.verbose,
)
embedder = RAGLiteRagasEmbeddings(config=config)
# Evaluate the answered evals with Ragas.
evaluation_df = ragas_evaluate(
dataset=Dataset.from_pandas(answered_evals_df),
llm=lc_llm,
embeddings=lc_embedder,
embeddings=embedder,
run_config=RunConfig(max_workers=1),
).to_pandas()
return evaluation_df

0 comments on commit c9317e8

Please sign in to comment.