-
Notifications
You must be signed in to change notification settings - Fork 1
/
search.py
134 lines (110 loc) · 4.93 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import enum
from pathlib import Path
from typing import List, Dict, Union
import torch
import pandas as pd
from dotenv import load_dotenv
import pinecone
from pinecone.core.client.model.query_response import QueryResponse
from sentence_transformers import SentenceTransformer
from embeddings import get_single_embedding
from config import ST_EMBEDDING_MODEL
from encode import Verse
from models import EmbeddingType
# Import-time side effect: let python-dotenv populate the process environment
# (presumably from a local .env file) before the credentials below are read.
load_dotenv()
# Import-time side effect: configure the global Pinecone client. os.getenv
# returns None for an unset variable, so a missing PINECONE_API_KEY or
# PINECONE_ENV is passed through to pinecone.init as None rather than
# being rejected here.
pinecone.init(api_key=os.getenv("PINECONE_API_KEY"), environment=os.getenv("PINECONE_ENV"))
class Engine:
    """Base class for verse search engines.

    Provides the shared query-embedding step; subclasses implement
    :meth:`search`. A subclass that wants to serve SentenceTransformer
    queries must set ``self.model`` to an object exposing
    ``encode(query, convert_to_tensor=True)``.
    """

    def _get_query_vec(self, query: str, emb_type: EmbeddingType) -> torch.Tensor:
        """Embed ``query`` with the backend selected by ``emb_type``.

        Args:
            query: Free-text search query.
            emb_type: Which embedding backend to use.

        Returns:
            The query embedding as a tensor.

        Raises:
            ValueError: If ``emb_type`` is not a known embedding type, or if
                a SentenceTransformer embedding is requested but this engine
                has no ``model`` attribute set.
        """
        if emb_type == EmbeddingType.Ada:
            return torch.tensor(get_single_embedding(query))

        # NOTE: "SentenceTransfomer" is the member spelling this file uses for
        # EmbeddingType; do not "correct" it here without changing the enum.
        if emb_type == EmbeddingType.SentenceTransfomer:
            # Not every subclass loads a local encoder, so fail with a clear
            # message instead of an AttributeError on ``self.model``.
            model = getattr(self, "model", None)
            if model is None:
                raise ValueError(
                    f"{type(self).__name__} has no SentenceTransformer model "
                    f"loaded; cannot embed query with {emb_type}"
                )
            return model.encode(query, convert_to_tensor=True)

        # ValueError is still caught by callers using ``except Exception``.
        raise ValueError(f"No such embedding: {emb_type}")

    def search(
        self,
        query: str,
        emb_type: EmbeddingType = EmbeddingType.Ada,
        only_text: bool = False
    ) -> List[Verse]:
        """Return the verses most similar to ``query``.

        Args:
            query: Free-text search query.
            emb_type: Which embedding backend to use.
            only_text: Whether results should carry only the verse text.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError
class SearchEngine(Engine):
    """In-memory search over verse embeddings stored in per-translation DataFrames.

    Supports SentenceTransformer encoder model and OpenAI's Ada v2 embeddings.
    """

    def __init__(self, named_paths: Dict[str, Path]):
        """Load the verse tables and the local encoder.

        Args:
            named_paths: Maps a translation name (the key later passed as
                ``translation``) to a parquet file. Each table is read for
                the columns "book", "chapter", "verse", "text" and one
                embedding column named after each ``EmbeddingType`` value.
        """
        self.df = {
            name: pd.read_parquet(path, engine="fastparquet")
            for name, path in named_paths.items()
        }
        self.model = SentenceTransformer(ST_EMBEDDING_MODEL)
        self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _get_query_vec(self, query: str, emb_type: EmbeddingType):
        """Embed ``query`` and move the resulting tensor to the search device."""
        return super()._get_query_vec(query, emb_type).to(self._device)

    def _get_embeddings(self, emb_type: EmbeddingType, translation: str):
        """Return the stored embeddings for ``translation`` as a tensor on the search device.

        The tensor is rebuilt from the DataFrame column on every call.

        Raises:
            KeyError: If ``translation`` was not loaded, or the table has no
                column named ``emb_type.value``.
        """
        return torch.tensor(self.df[translation][emb_type.value]).to(self._device)

    def _get_search_results(
        self, query_vec: torch.Tensor, embeddings: torch.Tensor, translation: str, k: int = 10, only_text=False
    ) -> List[Verse]:
        """Rank verses by dot product with the query and return the best ``k``.

        Cosine similarity: Ada embeddings are L2 normalized, so only require a
        dot product between the query and embedding vectors.
        NOTE(review): the same dot product is used for every embedding type;
        confirm the SentenceTransformer vectors are normalized as well.

        Args:
            query_vec: Embedding of the query.
            embeddings: One embedding per row of the translation's table, in
                the same order as the table rows.
            translation: Key of the table to read the verses from.
            k: Maximum number of verses to return; fewer are returned when
                the table has fewer rows.
            only_text: If true, only the "text" column is read and the other
                ``Verse`` fields are ``None``.

        Returns:
            Matching verses, best match first.
        """
        n_rows = embeddings.shape[0]
        if n_rows == 0:
            return []

        scores = torch.matmul(embeddings, query_vec)
        # torch.topk raises when k exceeds the number of candidates.
        top = torch.topk(scores, min(k, n_rows))

        columns = ["text"] if only_text else ["book", "chapter", "verse", "text"]
        # The top-k indices are row positions in ``embeddings``, so select
        # positionally; label-based ``.loc`` only agrees with this when the
        # table happens to carry a default 0..n-1 index.
        rows = self.df[translation].iloc[top.indices.cpu().tolist()][columns]
        return [
            Verse(None, None, None, np_row[0]) if only_text else Verse(*np_row)
            for np_row in rows.values
        ]

    def search(
        self,
        query: str,
        emb_type: EmbeddingType = EmbeddingType.Ada,
        only_text: bool = False,
        translation: str = "NKJV",
        k: int = 10,
    ) -> List[Verse]:
        """Return the verses of ``translation`` most similar to ``query``.

        Args:
            query: Free-text search query.
            emb_type: Which embedding backend to use for both the query and
                the stored verse embeddings.
            only_text: If true, results carry only the verse text.
            translation: Key of the loaded table to search.
            k: Maximum number of verses to return.

        Returns:
            Up to ``k`` verses, best match first.

        Raises:
            KeyError: If ``translation`` was not loaded.
        """
        query_vec = self._get_query_vec(query, emb_type)
        _embeddings = self._get_embeddings(emb_type, translation)
        return self._get_search_results(
            query_vec, _embeddings, translation, k=k, only_text=only_text
        )
class PineconeSearchEngine(Engine):
    """Search engine that delegates the nearest-neighbour lookup to Pinecone.

    Verse text and references are kept locally in per-translation DataFrames;
    Pinecone returns match ids that are used as row positions in those tables.
    """

    def __init__(self, named_paths: Dict[str, Path], index: Union[str, List]) -> None:
        """Load the verse tables and open the Pinecone index handles.

        Args:
            named_paths: Maps a translation name (the key later passed as
                ``translation``) to a CSV file with at least the columns
                "book", "chapter", "verse" and "text".
            index: One Pinecone index name, or a list of names. Queries look
                the handle up under the name "ada" (Ada embeddings) or
                "mpnet" (any other embedding type).
        """
        self.df = {name: pd.read_csv(path) for name, path in named_paths.items()}
        self._get_index(index)
        # only Ada is present in Pinecone due to free tier limitations
        # self.model = SentenceTransformer(ST_EMBEDDING_MODEL)

    def _get_query_vec(self, query: str, emb_type: EmbeddingType) -> List[float]:
        """Embed ``query`` and return the vector as a plain list of floats.

        Raises:
            ValueError: If a SentenceTransformer embedding is requested while
                no local encoder model is loaded on this engine.
        """
        # __init__ does not load an encoder, so without this guard the base
        # class would fail with an AttributeError on ``self.model``.
        if (
            emb_type == EmbeddingType.SentenceTransfomer
            and getattr(self, "model", None) is None
        ):
            raise ValueError(
                "SentenceTransformer queries are unavailable: "
                "PineconeSearchEngine has no local encoder model loaded"
            )
        return super()._get_query_vec(query, emb_type).tolist()

    def _get_index(self, index: Union[str, List]):
        """Create a Pinecone index handle per name and store them in ``self.indices``."""
        names = [index] if isinstance(index, str) else index
        self.indices = {n: pinecone.Index(n) for n in names}

    def _get_search_results(
        self,
        query_vec: List[float],
        emb_type: EmbeddingType,
        translation: str,
        only_text: bool,
        k: int = 10,
    ) -> List[Verse]:
        """Query the Pinecone index for ``emb_type`` and map matches to verses.

        Args:
            query_vec: Embedding of the query.
            emb_type: Selects the index: "ada" for Ada, "mpnet" otherwise.
            translation: Used both as the Pinecone namespace and as the key
                of the local verse table.
            only_text: If true, results carry only the verse text.
            k: Number of matches requested from Pinecone.

        Raises:
            KeyError: If the selected index name was not passed to the
                constructor.
        """
        index_name = 'ada' if emb_type == EmbeddingType.Ada else "mpnet"
        index = self.indices[index_name]
        results = index.query(query_vec, namespace=translation, top_k=k)
        return self._convert(results, translation, only_text)

    def _convert(self, results: QueryResponse, translation: str, only_text: bool) -> List[Verse]:
        """Maps Pinecone results to the book, chapter and verse.

        Each match id is converted to ``int`` and used as a row position in
        the ``translation`` table. An empty or missing response yields ``[]``.
        """
        if not results or not results.matches:
            return []

        positions = [int(match['id']) for match in results.matches]
        columns = ["text"] if only_text else ["book", "chapter", "verse", "text"]
        res = self.df[translation].iloc[positions][columns]
        return [
            Verse(None, None, None, np_row[0]) if only_text else Verse(*np_row)
            for np_row in res.itertuples(index=None, name=None)
        ]

    def search(
        self,
        query: str,
        emb_type: EmbeddingType = EmbeddingType.Ada,
        only_text: bool = False,
        translation: str = "NKJV",
        k: int = 10,
    ) -> List[Verse]:
        """Return the verses of ``translation`` most similar to ``query``.

        Args:
            query: Free-text search query.
            emb_type: Which embedding backend (and Pinecone index) to use.
            only_text: If true, results carry only the verse text.
            translation: Pinecone namespace and key of the local verse table.
            k: Number of matches to request.

        Returns:
            The matched verses in the order Pinecone returned them.
        """
        query_vec = self._get_query_vec(query, emb_type)
        return self._get_search_results(query_vec, emb_type, translation, only_text, k=k)