Skip to content

Commit

Permalink
CrateDB: Vector Store -- make it work using CrateDB's vector_similarity
Browse files Browse the repository at this point in the history
Before, the adapter used CrateDB's built-in `_score` field for ranking.
Now, it uses the dedicated `vector_similarity()` function to compute the
similarity between two vectors.
  • Loading branch information
amotl committed Nov 4, 2024
1 parent ffda5c8 commit 476d718
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
27 changes: 17 additions & 10 deletions libs/community/langchain_community/vectorstores/cratedb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
page_content=result.EmbeddingStore.document,
metadata=result.EmbeddingStore.cmetadata,
),
result._score if self.embedding_function is not None else None,
result.similarity if self.embedding_function is not None else None,
)
for result in results
]
Expand Down Expand Up @@ -324,15 +324,22 @@ def _query_collection_multi(
results: List[Any] = (
session.query( # type: ignore[attr-defined]
self.EmbeddingStore,
# FIXME: Using `_score` is definitively the wrong choice.
# - https://github.com/crate-workbench/langchain/issues/19
# - https://github.com/crate/crate/issues/15835
# TODO: Original pgvector code uses `self.distance_strategy`.
# CrateDB currently only supports EUCLIDEAN.
# self.distance_strategy(embedding).label("distance")
sqlalchemy.literal_column(
f"{self.EmbeddingStore.__tablename__}._score"
).label("_score"),
sqlalchemy.func.vector_similarity(
self.EmbeddingStore.embedding,
# TODO: Just reference the `embedding` symbol here, don't
# serialize its value prematurely.
# https://github.com/crate/crate/issues/16912
#
# Until that is fixed, marshal the arguments to
# `vector_similarity()` manually to work around this
# edge-case bug. JSON marshalling is not needed, because
# Python's string representation of a list already has
# the required form.
sqlalchemy.text(str(embedding)),
).label("similarity"),
)
.filter(filter_by)
# CrateDB applies `KNN_MATCH` within the `WHERE` clause.
Expand All @@ -341,7 +348,7 @@ def _query_collection_multi(
self.EmbeddingStore.embedding, embedding, k
)
)
.order_by(sqlalchemy.desc("_score"))
.order_by(sqlalchemy.desc("similarity"))
.join(
self.CollectionStore,
self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
Expand Down Expand Up @@ -450,7 +457,7 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
)

@staticmethod
def _euclidean_relevance_score_fn(score: float) -> float:
def _euclidean_relevance_score_fn(similarity: float) -> float:
"""Return a similarity score on a scale [0, 1]."""
# The 'correct' relevance function
# may differ depending on a few things, including:
Expand All @@ -465,4 +472,4 @@ def _euclidean_relevance_score_fn(score: float) -> float:

# Original:
# return 1.0 - distance / math.sqrt(2)
return score / math.sqrt(2)
return similarity / math.sqrt(2)
28 changes: 12 additions & 16 deletions libs/community/tests/integration_tests/vectorstores/test_cratedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_cratedb_with_metadatas_with_scores() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search_with_score("foo", k=1)
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 2.0)]
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]


def test_cratedb_with_filter_match() -> None:
Expand All @@ -250,9 +250,7 @@ def test_cratedb_with_filter_match() -> None:
# TODO: Original:
# assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.3))
]
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]


def test_cratedb_with_filter_distant_match() -> None:
Expand All @@ -269,9 +267,7 @@ def test_cratedb_with_filter_distant_match() -> None:
)
output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"})
# Original score value: 0.0013003906671379406
assert output == [
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(1.5, 0.2))
]
assert output == [(Document(page_content="baz", metadata={"page": "2"}), 0.2)]


def test_cratedb_with_filter_no_match() -> None:
Expand Down Expand Up @@ -425,8 +421,8 @@ def test_cratedb_with_filter_in_set() -> None:
)
# Original score values: 0.0, 0.0013003906671379406
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(3.0, 0.1)),
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(2.2, 0.1)),
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
(Document(page_content="baz", metadata={"page": "2"}), 0.2),
]


Expand Down Expand Up @@ -474,9 +470,9 @@ def test_cratedb_relevance_score() -> None:
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
# Original score values: 1.0, 0.9996744261675065, 0.9986996093328621
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(1.4, 0.1)),
(Document(page_content="bar", metadata={"page": "1"}), pytest.approx(1.1, 0.1)),
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(0.8, 0.1)),
(Document(page_content="foo", metadata={"page": "0"}), 0.7071067811865475),
(Document(page_content="bar", metadata={"page": "1"}), 0.35355339059327373),
(Document(page_content="baz", metadata={"page": "2"}), 0.1414213562373095),
]


Expand All @@ -495,9 +491,9 @@ def test_cratedb_retriever_search_threshold() -> None:

retriever = docsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": 3, "score_threshold": 0.999},
search_kwargs={"k": 3, "score_threshold": 0.35}, # Original value: 0.999
)
output = retriever.get_relevant_documents("summer")
output = retriever.invoke("summer")
assert output == [
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
Expand All @@ -522,7 +518,7 @@ def test_cratedb_retriever_search_threshold_custom_normalization_fn() -> None:
search_type="similarity_score_threshold",
search_kwargs={"k": 3, "score_threshold": 0.5},
)
output = retriever.get_relevant_documents("foo")
output = retriever.invoke("foo")
assert output == []


Expand Down Expand Up @@ -551,7 +547,7 @@ def test_cratedb_max_marginal_relevance_search_with_score() -> None:
pre_delete_collection=True,
)
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
assert output == [(Document(page_content="foo"), 2.0)]
assert output == [(Document(page_content="foo"), 1.0)]


def test_cratedb_multicollection_search_success() -> None:
Expand Down

0 comments on commit 476d718

Please sign in to comment.