Skip to content

Commit

Permalink
Merge pull request #191 from nulib/4585-hybrid-search
Browse files Browse the repository at this point in the history
  • Loading branch information
mbklein authored Mar 12, 2024
2 parents f9f0f4c + dddb6db commit c619c82
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
5 changes: 3 additions & 2 deletions chat/src/handlers/opensearch_neural_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class OpenSearchNeuralSearch(VectorStore):

def __init__(
self,
client: None,
endpoint: str,
index: str,
model_id: str,
Expand All @@ -17,7 +18,7 @@ def __init__(
text_field: str = "id",
**kwargs: Any,
):
self.client = OpenSearch(
self.client = client or OpenSearch(
hosts=[{"host": endpoint, "port": "443", "use_ssl": True}], **kwargs
)
self.index = index
Expand Down Expand Up @@ -64,7 +65,7 @@ def similarity_search_with_score(
for key, value in kwargs.items():
dsl[key] = value

response = self.client.search(index=self.index, body=dsl)
response = self.client.search(index=self.index, body=dsl, params={"search_pipeline": self.search_pipeline} if self.search_pipeline else None)

documents_with_scores = [
(
Expand Down
3 changes: 2 additions & 1 deletion chat/src/helpers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def extract_prompt_value(v):

def prepare_response(config):
try:
subquery = {"match": {"all_text": {"query": config.question}}}
docs = config.opensearch.similarity_search(
query=config.question, k=config.k
query=config.question, k=config.k, subquery=subquery, _source={"excludes": ["embedding"]}
)
original_question = get_and_send_original_question(config, docs)
response = config.chain({"question": config.question, "input_documents": docs})
Expand Down
43 changes: 43 additions & 0 deletions chat/test/handlers/test_opensearch_neural_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# ruff: noqa: E402
import sys
sys.path.append('./src')

from unittest import TestCase
from handlers.opensearch_neural_search import OpenSearchNeuralSearch
from langchain_core.documents import Document

class MockClient():
    """Stand-in for an OpenSearch client that returns one canned search hit."""

    def search(self, index, body, params):
        """Return a fixed response with a single hit, ignoring all arguments.

        The hit has ``_source`` ``{"id": "test"}`` and ``_score`` ``0.12345``,
        which are the values the tests in this file assert against.
        """
        canned_hit = {"_source": {"id": "test"}, "_score": 0.12345}
        return {"hits": {"hits": [canned_hit]}}

class TestOpenSearchNeuralSearch(TestCase):
    """Exercise OpenSearchNeuralSearch against a MockClient instead of a live cluster."""

    def _store(self):
        """Build a store wired to the mock client with placeholder settings."""
        return OpenSearchNeuralSearch(
            client=MockClient(), endpoint="test", index="test", model_id="test"
        )

    def test_similarity_search(self):
        """similarity_search returns the mock hit as a bare Document."""
        docs = self._store().similarity_search(
            query="test",
            subquery={"_source": {"excludes": ["embedding"]}},
            size=10,
        )
        self.assertEqual(
            docs, [Document(page_content='test', metadata={'id': 'test'})]
        )

    def test_similarity_search_with_score(self):
        """similarity_search_with_score pairs the Document with the hit's score."""
        docs = self._store().similarity_search_with_score(query="test")
        self.assertEqual(
            docs,
            [(Document(page_content='test', metadata={'id': 'test'}), 0.12345)],
        )

    def test_add_texts(self):
        """add_texts must not raise when given texts and metadatas."""
        try:
            self._store().add_texts(texts=["test"], metadatas=[{"id": "test"}])
        except Exception as e:
            # Name the method actually under test so a failure is not
            # misattributed to from_texts.
            self.fail(f"add_texts raised an exception: {e}")

    def test_from_texts(self):
        """from_texts must not raise when given texts and metadatas."""
        try:
            # NOTE(review): `clas` is passed by keyword here; confirm this
            # matches the from_texts signature in the handler module.
            OpenSearchNeuralSearch.from_texts(
                clas="test", texts=["test"], metadatas=[{"id": "test"}]
            )
        except Exception as e:
            self.fail(f"from_texts raised an exception: {e}")

0 comments on commit c619c82

Please sign in to comment.