Skip to content

Commit

Permalink
feat: add config pydantic classes for chroma and qdrant, refactor qdr…
Browse files Browse the repository at this point in the history
…ant init method for readability
  • Loading branch information
Sai krishna committed Oct 20, 2024
1 parent ef56a7e commit 472c6b3
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 53 deletions.
7 changes: 0 additions & 7 deletions backend/modules/query_controllers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import async_timeout
import requests
from fastapi import HTTPException
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever
from langchain.schema.vectorstore import VectorStoreRetriever
from langchain_core.language_models.chat_models import BaseChatModel
Expand Down Expand Up @@ -40,12 +39,6 @@ class BaseQueryController:
"relevance_score",
]

def _get_prompt_template(self, input_variables, template):
"""
Get the prompt template
"""
return PromptTemplate(input_variables=input_variables, template=template)

def _format_docs(self, docs):
return "\n\n".join([doc.page_content for doc in docs])

Expand Down
3 changes: 2 additions & 1 deletion backend/modules/query_controllers/example/controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from backend.modules.query_controllers.base import BaseQueryController
Expand Down Expand Up @@ -35,7 +36,7 @@ async def answer(
vector_store = await self._get_vector_store(request.collection_name)

# Create the QA prompt templates
QA_PROMPT = self._get_prompt_template(
QA_PROMPT = PromptTemplate(
input_variables=["context", "question"],
template=request.prompt_template,
)
Expand Down
53 changes: 41 additions & 12 deletions backend/modules/vector_db/chroma.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from typing import List, Optional, Union
from typing import List, Optional

from chromadb import HttpClient, PersistentClient
from chromadb.api import ClientAPI
from chromadb.api.models.Collection import Collection
from fastapi import HTTPException
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import Chroma

from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
from backend.modules.vector_db.base import BaseVectorDB
from backend.types import VectorDBConfig
from backend.types import ChromaVectorDBConfig


class ChromaVectorDB(BaseVectorDB):
def __init__(self, db_config: VectorDBConfig):
def __init__(self, db_config: ChromaVectorDBConfig):
self.db_config = db_config
self.client = self.get_client()
self.client = self._create_client()

def get_client(self) -> ClientAPI:
## Initialization utility

def _create_client(self) -> ClientAPI:
# For local development, we use a persistent client that saves data to a temporary directory
if self.db_config.local:
return PersistentClient()
Expand All @@ -28,11 +32,20 @@ def get_client(self) -> ClientAPI:
config=self.db_config.config,
)

def get_vector_client(self) -> Union[PersistentClient, HttpClient]:
## Client
def get_vector_client(self) -> ClientAPI:
return self.client

def get_vector_store(self, collection_name: str, embeddings: Embeddings):
pass
## Vector store

def get_vector_store(self, collection_name: str, **kwargs):
return Chroma(
client=self.client,
collection_name=collection_name,
**kwargs,
)

## Collections

def create_collection(self, collection_name: str, **kwargs):
try:
Expand All @@ -48,6 +61,9 @@ def create_collection(self, collection_name: str, **kwargs):
status_code=400, detail=f"Unable to create collection: {e}"
)

def get_collection(self, collection_name: str):
return self.client.get_collection(name=collection_name)

def delete_collection(self, collection_name: str):
try:
return self.client.delete_collection(name=collection_name)
Expand All @@ -59,9 +75,16 @@ def delete_collection(self, collection_name: str):

def list_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
):
) -> List[Collection]:
return self.client.list_collections(limit=limit, offset=offset)

def get_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> List[Collection]:
return self.client.list_collections(limit=limit, offset=offset)

## Documents

def upsert_documents(
self,
collection_name: str,
Expand All @@ -73,15 +96,21 @@ def upsert_documents(
collection_name, documents, embeddings, incremental
)

def delete_documents(self, collection_name: str, document_ids: List[str]):
# Fetch the collection
collection: Collection = self.client.get_collection(collection_name)
# Delete the documents in the collection by ids
collection.delete(ids=document_ids)

## Data point vectors

def list_data_point_vectors(
self,
collection_name: str,
data_source_fqn: str,
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
):
return super().list_data_point_vectors(
collection_name, data_source_fqn, batch_size
)
pass

def delete_data_point_vectors(
self,
Expand Down
68 changes: 36 additions & 32 deletions backend/modules/vector_db/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from urllib.parse import urlparse

from langchain.embeddings.base import Embeddings
Expand All @@ -9,36 +9,42 @@
from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
from backend.types import DataPointVector, QdrantClientConfig, VectorDBConfig
from backend.types import DataPointVector, QdrantVectorDBConfig

MAX_SCROLL_LIMIT = int(1e6)
BATCH_SIZE = 1000


class QdrantVectorDB(BaseVectorDB):
def __init__(self, config: VectorDBConfig):
logger.debug(f"Connecting to qdrant using config: {config.model_dump()}")
if config.local is True:
# TODO: make this path customizable
self.qdrant_client = QdrantClient(
path="./qdrant_db",
)
else:
url = config.url
api_key = config.api_key
if not api_key:
api_key = None
qdrant_kwargs = QdrantClientConfig.model_validate(config.config or {})
if url.startswith("http://") or url.startswith("https://"):
if qdrant_kwargs.port is None:
parsed_port = urlparse(url).port
if parsed_port:
qdrant_kwargs.port = parsed_port
else:
qdrant_kwargs.port = 443 if url.startswith("https://") else 6333
self.qdrant_client = QdrantClient(
url=url, api_key=api_key, **qdrant_kwargs.model_dump()
)
def __init__(self, db_config: QdrantVectorDBConfig):
logger.debug(f"Connecting to qdrant using config: {db_config.model_dump()}")
self.qdrant_client = self._create_client(db_config)

def _create_client(self, db_config: QdrantVectorDBConfig) -> QdrantClient:
# Local
if db_config.local:
return QdrantClient(path=db_config.path)

url = db_config.url

if url.startswith(("http://", "https://")):
db_config.config.port = self._get_port(url, db_config.config.port)

# If the Qdrant server is hosted on a remote server, create an http client
return QdrantClient(
url=url, api_key=db_config.api_key, **db_config.config.model_dump()
)

@staticmethod
def _get_port(url: str, existing_port: Optional[int]) -> int:
if existing_port:
return existing_port

parsed_port = urlparse(url).port
if parsed_port:
return parsed_port

return 443 if url.startswith("https://") else 6333

def create_collection(self, collection_name: str, embeddings: Embeddings):
logger.debug(f"[Qdrant] Creating new collection {collection_name}")
Expand Down Expand Up @@ -113,11 +119,11 @@ def _get_records_to_be_upserted(
def upsert_documents(
self,
collection_name: str,
documents,
documents: list,
embeddings: Embeddings,
incremental: bool = True,
):
if len(documents) == 0:
if not documents:
logger.warning("No documents to index")
return
# get record IDs to be upserted
Expand Down Expand Up @@ -219,11 +225,9 @@ def list_data_point_vectors(
offset=offset,
)
for record in records:
metadata: dict = record.payload.get("metadata")
if (
metadata
and metadata.get(DATA_POINT_FQN_METADATA_KEY)
and metadata.get(DATA_POINT_HASH_METADATA_KEY)
metadata: dict = record.payload.get("metadata", {})
if metadata.get(DATA_POINT_FQN_METADATA_KEY) and metadata.get(
DATA_POINT_HASH_METADATA_KEY
):
data_point_vectors.append(
DataPointVector(
Expand Down
29 changes: 28 additions & 1 deletion backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,28 @@ class VectorDBConfig(ConfiguredBaseModel):
"""

provider: str
path: Optional[str] = None
local: bool = False
url: Optional[str] = None
api_key: Optional[str] = None
config: Optional[Dict[str, Any]] = Field(default_factory=dict)
config: Dict[str, Any] = Field(default_factory=dict)


class ChromaClientConfig(ConfiguredBaseModel):
"""
Chroma client configuration
"""

model_config = ConfigDict(extra="allow")


class ChromaVectorDBConfig(VectorDBConfig):
"""
Chroma-specific vector db configuration
"""

path: Optional[str] = "./chroma_db"
config: ChromaClientConfig = Field(default_factory=ChromaClientConfig)


class QdrantClientConfig(ConfiguredBaseModel):
Expand All @@ -196,6 +214,15 @@ class QdrantClientConfig(ConfiguredBaseModel):
timeout: int = 300


class QdrantVectorDBConfig(VectorDBConfig):
"""
Qdrant-specific vector db configuration
"""

path: Optional[str] = "./qdrant_db"
config: QdrantClientConfig = Field(default_factory=QdrantClientConfig)


class MetadataStoreConfig(ConfiguredBaseModel):
"""
Metadata store configuration
Expand Down

0 comments on commit 472c6b3

Please sign in to comment.