-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy path __init__.py
44 lines (37 loc) · 1.26 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import Optional
from dotenv import load_dotenv
from semantic_router.encoders import BaseEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from models.vector_database import VectorDatabase
from vectordbs.astra import AstraService
from vectordbs.base import BaseVectorDatabase
from vectordbs.pinecone import PineconeService
from vectordbs.qdrant import QdrantService
from vectordbs.weaviate import WeaviateService
from vectordbs.pgvector import PGVectorService
# Load variables from a .env file into the process environment at import time,
# so environment-based settings (e.g. provider API keys) are available to the
# code in this module.
load_dotenv()
def get_vector_service(
    *,
    index_name: str,
    credentials: VectorDatabase,
    encoder: Optional[BaseEncoder] = None,
    dimensions: Optional[int] = 384,
) -> BaseVectorDatabase:
    """Construct the vector-database service selected by ``credentials.type``.

    Args:
        index_name: Index/collection name forwarded to the service constructor.
        credentials: Provider selection plus connection settings.
            ``credentials.type.value`` picks the provider, and
            ``credentials.config`` is converted with ``dict()`` and forwarded
            as ``credentials``.
        encoder: Encoder forwarded to the service. When omitted (``None``),
            a new ``OpenAIEncoder`` is created for this call.
        dimensions: Forwarded to the service as ``dimension``.

    Returns:
        An instance of the provider's service class.

    Raises:
        ValueError: If ``credentials.type.value`` is not a registered provider.
    """
    services = {
        "pinecone": PineconeService,
        "qdrant": QdrantService,
        "weaviate": WeaviateService,
        "astra": AstraService,
        "pgvector": PGVectorService,
        # Register additional providers here, keyed by credentials.type.value.
    }
    provider = credentials.type.value
    service = services.get(provider)
    if service is None:
        raise ValueError(f"Unsupported provider: {provider}")
    # Build the default encoder lazily. A default argument of `OpenAIEncoder()`
    # would be evaluated once when this module is imported (which may fail if
    # the encoder cannot initialise, e.g. no API key) and shared by every call.
    if encoder is None:
        encoder = OpenAIEncoder()
    return service(
        index_name=index_name,
        dimension=dimensions,
        credentials=dict(credentials.config),
        encoder=encoder,
    )