Add elastic #155

Merged: 11 commits, Nov 14, 2024

This PR swaps the pgvector/Supabase Postgres vector store in the llm-complete-guide example for Elasticsearch: it adds the elasticsearch dependency, introduces a SECRET_NAME_ELASTICSEARCH constant, and rewrites the retrieval and indexing steps to target an Elasticsearch index.
1 change: 1 addition & 0 deletions llm-complete-guide/configs/dev/rag.yaml
@@ -17,6 +17,7 @@ settings:
- pygithub
- rerankers[flashrank]
- matplotlib
- elasticsearch

environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
1 change: 1 addition & 0 deletions llm-complete-guide/constants.py
@@ -78,3 +78,4 @@
USE_ARGILLA_ANNOTATIONS = False

SECRET_NAME = os.getenv("ZENML_PROJECT_SECRET_NAME", "llm-complete")
SECRET_NAME_ELASTICSEARCH = "elasticsearch-zenml"
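The new constant names a ZenML secret that the Elasticsearch helpers read at runtime. The diff below only shows an elasticsearch_host value being read from it; here is a minimal sketch of how the secret is consumed, assuming it also carries an elasticsearch_api_key entry (that key name is an assumption, inferred from the masked api_key field logged in populate_index.py):

```python
from zenml.client import Client

from constants import SECRET_NAME_ELASTICSEARCH

# Read Elasticsearch connection details from the ZenML secret store.
# "elasticsearch_host" is read later in this PR (populate_index.py);
# "elasticsearch_api_key" is assumed, based on the masked "api_key"
# field in the logged connection details.
secret = Client().get_secret(SECRET_NAME_ELASTICSEARCH)
es_host = secret.secret_values["elasticsearch_host"]
es_api_key = secret.secret_values["elasticsearch_api_key"]
```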
3 changes: 2 additions & 1 deletion llm-complete-guide/requirements.txt
@@ -1,4 +1,4 @@
zenml[server]>=0.68.1
zenml[server]==0.68.1
ratelimit
pgvector
psycopg2-binary
@@ -20,6 +20,7 @@ datasets
torch
gradio
huggingface-hub
elasticsearch

# optional requirements for S3 artifact store
# s3fs>2022.3.0
5 changes: 3 additions & 2 deletions llm-complete-guide/steps/eval_retrieval.py
@@ -21,6 +21,7 @@
from utils.llm_utils import (
get_db_conn,
get_embeddings,
get_es_client,
get_topn_similar_docs,
rerank_documents,
)
@@ -76,11 +77,11 @@ def query_similar_docs(
Tuple containing the question, URL ending, and retrieved URLs.
"""
embedded_question = get_embeddings(question)
db_conn = get_db_conn()
es_client = get_es_client()
num_docs = 20 if use_reranking else returned_sample_size
# get (content, url) tuples for the top n similar documents
top_similar_docs = get_topn_similar_docs(
embedded_question, db_conn, n=num_docs, include_metadata=True
embedded_question, es_client, n=num_docs, include_metadata=True
)

if use_reranking:
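query_similar_docs now obtains an Elasticsearch client via get_es_client() and hands it to get_topn_similar_docs() instead of a Postgres connection. The updated helpers live in utils/llm_utils.py and are not part of this diff; the sketch below shows one plausible shape for them, assuming the secret keys above and an elasticsearch-py 8.x client running a kNN search against the zenml_docs index created in populate_index.py:

```python
from elasticsearch import Elasticsearch
from zenml.client import Client

from constants import SECRET_NAME_ELASTICSEARCH


def get_es_client() -> Elasticsearch:
    # Hypothetical implementation: build a client from the ZenML secret.
    secret = Client().get_secret(SECRET_NAME_ELASTICSEARCH)
    return Elasticsearch(
        hosts=[secret.secret_values["elasticsearch_host"]],
        api_key=secret.secret_values["elasticsearch_api_key"],
    )


def get_topn_similar_docs(query_embedding, es_client, n=20, include_metadata=False):
    # Hypothetical kNN retrieval; field names follow the "zenml_docs"
    # mapping defined in populate_index.py in this PR.
    response = es_client.search(
        index="zenml_docs",
        knn={
            "field": "embedding",
            "query_vector": query_embedding,
            "k": n,
            "num_candidates": max(50, n * 5),
        },
    )
    hits = response["hits"]["hits"]
    if include_metadata:
        return [(hit["_source"]["content"], hit["_source"]["url"]) for hit in hits]
    return [(hit["_source"]["content"],) for hit in hits]
```

The call site only relies on (content, url) tuples coming back when include_metadata=True, so any equivalent query shape would work.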
180 changes: 93 additions & 87 deletions llm-complete-guide/steps/populate_index.py
@@ -29,12 +29,13 @@
CHUNK_SIZE,
EMBEDDING_DIMENSIONALITY,
EMBEDDINGS_MODEL,
SECRET_NAME_ELASTICSEARCH,
)
from pgvector.psycopg2 import register_vector
from PIL import Image, ImageDraw, ImageFont
from sentence_transformers import SentenceTransformer
from structures import Document
from utils.llm_utils import get_db_conn, split_documents
from utils.llm_utils import get_db_conn, get_es_client, split_documents
from zenml import ArtifactConfig, log_artifact_metadata, step, log_model_metadata
from zenml.metadata.metadata_types import Uri
from zenml.client import Client
@@ -592,7 +593,7 @@ def generate_embeddings(
raise


@step
@step(enable_cache=False)
def index_generator(
documents: str,
) -> None:
@@ -609,83 +610,89 @@ def index_generator(
Raises:
Exception: If an error occurs during the index generation.
"""
conn = None
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
import hashlib

try:
conn = get_db_conn()
with conn.cursor() as cur:
# Install pgvector if not already installed
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
conn.commit()

# Create the embeddings table if it doesn't exist
table_create_command = f"""
CREATE TABLE IF NOT EXISTS embeddings (
id SERIAL PRIMARY KEY,
content TEXT,
token_count INTEGER,
embedding VECTOR({EMBEDDING_DIMENSIONALITY}),
filename TEXT,
parent_section TEXT,
url TEXT
);
"""
cur.execute(table_create_command)
conn.commit()

register_vector(conn)

# Parse the JSON string into a list of Document objects
document_list = [Document(**doc) for doc in json.loads(documents)]

# Insert data only if it doesn't already exist
for doc in document_list:
content = doc.page_content
token_count = doc.token_count
embedding = doc.embedding
filename = doc.filename
parent_section = doc.parent_section
url = doc.url

cur.execute(
"SELECT COUNT(*) FROM embeddings WHERE content = %s",
(content,),
)
count = cur.fetchone()[0]
if count == 0:
cur.execute(
"INSERT INTO embeddings (content, token_count, embedding, filename, parent_section, url) VALUES (%s, %s, %s, %s, %s, %s)",
(
content,
token_count,
embedding,
filename,
parent_section,
url,
),
)
conn.commit()

cur.execute("SELECT COUNT(*) as cnt FROM embeddings;")
num_records = cur.fetchone()[0]
logger.info(f"Number of vector records in table: {num_records}")

# calculate the index parameters according to best practices
num_lists = max(num_records / 1000, 10)
if num_records > 1000000:
num_lists = math.sqrt(num_records)

# use the cosine distance measure, which is what we'll later use for querying
cur.execute(
f"CREATE INDEX IF NOT EXISTS embeddings_idx ON embeddings USING ivfflat (embedding vector_cosine_ops) WITH (lists = {num_lists});"
)
conn.commit()
es = get_es_client()
index_name = "zenml_docs"

# Create index with mappings if it doesn't exist
if not es.indices.exists(index=index_name):
mappings = {
"mappings": {
"properties": {
"doc_id": {"type": "keyword"},
"content": {"type": "text"},
"token_count": {"type": "integer"},
"embedding": {
"type": "dense_vector",
"dims": EMBEDDING_DIMENSIONALITY,
"index": True,
"similarity": "cosine"
},
"filename": {"type": "text"},
"parent_section": {"type": "text"},
"url": {"type": "text"}
}
}
}
# TODO move to using mappings param directly
es.indices.create(index=index_name, body=mappings)

except Exception as e:
logger.error(f"Error in index_generator: {e}")
raise
finally:
if conn:
conn.close()
# Parse the JSON string into a list of Document objects
document_list = [Document(**doc) for doc in json.loads(documents)]

# Prepare bulk operations
operations = []
for doc in document_list[:500]:
# Create a unique identifier based on content and metadata
content_hash = hashlib.md5(
f"{doc.page_content}{doc.filename}{doc.parent_section}{doc.url}".encode()
).hexdigest()

# Check if document exists
exists_query = {
"query": {
"term": {
"doc_id": content_hash
}
}
}

if not es.count(index=index_name, body=exists_query)["count"]:
operations.append({
"index": {
"_index": index_name,
"_id": content_hash
}
})

operations.append({
"doc_id": content_hash,
"content": doc.page_content,
"token_count": doc.token_count,
"embedding": doc.embedding,
"filename": doc.filename,
"parent_section": doc.parent_section,
"url": doc.url
})

if operations:
response = es.bulk(operations=operations, timeout="10m")

success_count = sum(1 for item in response['items'] if 'index' in item and item['index']['status'] == 201)
failed_count = len(response['items']) - success_count

logger.info(f"Successfully indexed {success_count} documents")
if failed_count > 0:
logger.warning(f"Failed to index {failed_count} documents")
for item in response['items']:
if 'index' in item and item['index']['status'] != 201:
logger.warning(f"Failed to index document: {item['index']['error']}")
else:
logger.info("No new documents to index")

# Log the model metadata
prompt = """
@@ -700,12 +707,10 @@ def index_generator(
"""

client = Client()
es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"]
CONNECTION_DETAILS = {
"user": client.get_secret(SECRET_NAME).secret_values["supabase_user"],
"password": "**********",
"host": client.get_secret(SECRET_NAME).secret_values["supabase_host"],
"port": client.get_secret(SECRET_NAME).secret_values["supabase_port"],
"dbname": "postgres",
"host": es_host,
"api_key": "*********",
}

log_model_metadata(
@@ -721,12 +726,13 @@ def index_generator(
"content": prompt,
},
"vector_store": {
"name": "pgvector",
"name": "elasticsearch",
"connection_details": CONNECTION_DETAILS,
# TODO: Hard-coded for now
"database_url": Uri(
"https://supabase.com/dashboard/project/rkoiacgkeiwpwceahtlp/editor/29505?schema=public"
),
"index_name": index_name
},
},
)

except Exception as e:
logger.error(f"Error in index_generator: {e}")
raise
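The index creation in index_generator still wraps the mapping in a body= request and leaves a TODO to switch to explicit parameters. A possible follow-up, assuming an elasticsearch-py 8.x client (where the body parameter is deprecated in favor of keyword arguments), could look like the sketch below; ensure_index is a hypothetical helper name:

```python
from elasticsearch import Elasticsearch

from constants import EMBEDDING_DIMENSIONALITY


def ensure_index(es: Elasticsearch, index_name: str = "zenml_docs") -> None:
    # Same mapping as index_generator, passed via the `mappings` keyword
    # instead of the legacy `body` wrapper.
    if es.indices.exists(index=index_name):
        return
    es.indices.create(
        index=index_name,
        mappings={
            "properties": {
                "doc_id": {"type": "keyword"},
                "content": {"type": "text"},
                "token_count": {"type": "integer"},
                "embedding": {
                    "type": "dense_vector",
                    "dims": EMBEDDING_DIMENSIONALITY,
                    "index": True,
                    "similarity": "cosine",
                },
                "filename": {"type": "text"},
                "parent_section": {"type": "text"},
                "url": {"type": "text"},
            }
        },
    )
```

This keeps the step's create-if-missing behavior while avoiding the deprecated request-body path.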