diff --git a/libs/community/langchain_community/vectorstores/cratedb/base.py b/libs/community/langchain_community/vectorstores/cratedb/base.py index 77984daf2cf39e..e9109764ecfc7e 100644 --- a/libs/community/langchain_community/vectorstores/cratedb/base.py +++ b/libs/community/langchain_community/vectorstores/cratedb/base.py @@ -260,7 +260,7 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa page_content=result.EmbeddingStore.document, metadata=result.EmbeddingStore.cmetadata, ), - result._score if self.embedding_function is not None else None, + result.similarity if self.embedding_function is not None else None, ) for result in results ] @@ -324,15 +324,22 @@ def _query_collection_multi( results: List[Any] = ( session.query( # type: ignore[attr-defined] self.EmbeddingStore, - # FIXME: Using `_score` is definitively the wrong choice. - # - https://github.com/crate-workbench/langchain/issues/19 - # - https://github.com/crate/crate/issues/15835 # TODO: Original pgvector code uses `self.distance_strategy`. # CrateDB currently only supports EUCLIDEAN. # self.distance_strategy(embedding).label("distance") - sqlalchemy.literal_column( - f"{self.EmbeddingStore.__tablename__}._score" - ).label("_score"), + sqlalchemy.func.vector_similarity( + self.EmbeddingStore.embedding, + # TODO: Just reference the `embedding` symbol here, don't + # serialize its value prematurely. + # https://github.com/crate/crate/issues/16912 + # + # Until that got fixed, marshal the arguments to + # `vector_similarity()` manually, in order to work around + # this edge case bug. We don't need to use JSON marshalling, + # because Python's string representation of a list is just + # right. + sqlalchemy.text(str(embedding)), + ).label("similarity"), ) .filter(filter_by) # CrateDB applies `KNN_MATCH` within the `WHERE` clause. @@ -341,7 +348,7 @@ def _query_collection_multi( self.EmbeddingStore.embedding, embedding, k ) ) - .order_by(sqlalchemy.desc("_score")) + .order_by(sqlalchemy.desc("similarity")) .join( self.CollectionStore, self.EmbeddingStore.collection_id == self.CollectionStore.uuid, @@ -450,7 +457,7 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: ) @staticmethod - def _euclidean_relevance_score_fn(score: float) -> float: + def _euclidean_relevance_score_fn(similarity: float) -> float: """Return a similarity score on a scale [0, 1].""" # The 'correct' relevance function # may differ depending on a few things, including: @@ -465,4 +472,4 @@ def _euclidean_relevance_score_fn(score: float) -> float: # Original: # return 1.0 - distance / math.sqrt(2) - return score / math.sqrt(2) + return similarity / math.sqrt(2) diff --git a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py index c4f2b24835c62c..acc03547fe5650 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cratedb.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cratedb.py @@ -232,7 +232,7 @@ def test_cratedb_with_metadatas_with_scores() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1) - assert output == [(Document(page_content="foo", metadata={"page": "0"}), 2.0)] + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] def test_cratedb_with_filter_match() -> None: @@ -250,9 +250,7 @@ def test_cratedb_with_filter_match() -> None: # TODO: Original: # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) - assert output == [ - (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.3)) - ] + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] def test_cratedb_with_filter_distant_match() -> None: @@ -269,9 +267,7 @@ def test_cratedb_with_filter_distant_match() -> None: ) output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"}) # Original score value: 0.0013003906671379406 - assert output == [ - (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(1.5, 0.2)) - ] + assert output == [(Document(page_content="baz", metadata={"page": "2"}), 0.2)] def test_cratedb_with_filter_no_match() -> None: @@ -425,8 +421,8 @@ def test_cratedb_with_filter_in_set() -> None: ) # Original score values: 0.0, 0.0013003906671379406 assert output == [ - (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(3.0, 0.1)), - (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(2.2, 0.1)), + (Document(page_content="foo", metadata={"page": "0"}), 1.0), + (Document(page_content="baz", metadata={"page": "2"}), 0.2), ] @@ -474,9 +470,9 @@ def test_cratedb_relevance_score() -> None: output = docsearch.similarity_search_with_relevance_scores("foo", k=3) # Original score values: 1.0, 0.9996744261675065, 0.9986996093328621 assert output == [ - (Document(page_content="foo", metadata={"page": "0"}), pytest.approx(1.4, 0.1)), - (Document(page_content="bar", metadata={"page": "1"}), pytest.approx(1.1, 0.1)), - (Document(page_content="baz", metadata={"page": "2"}), pytest.approx(0.8, 0.1)), + (Document(page_content="foo", metadata={"page": "0"}), 0.7071067811865475), + (Document(page_content="bar", metadata={"page": "1"}), 0.35355339059327373), + (Document(page_content="baz", metadata={"page": "2"}), 0.1414213562373095), ] @@ -495,9 +491,9 @@ def test_cratedb_retriever_search_threshold() -> None: retriever = docsearch.as_retriever( search_type="similarity_score_threshold", - search_kwargs={"k": 3, "score_threshold": 0.999}, + search_kwargs={"k": 3, "score_threshold": 0.35}, # Original value: 0.999 ) - output = retriever.get_relevant_documents("summer") + output = retriever.invoke("summer") assert output == [ Document(page_content="foo", metadata={"page": "0"}), Document(page_content="bar", metadata={"page": "1"}), @@ -522,7 +518,7 @@ def test_cratedb_retriever_search_threshold_custom_normalization_fn() -> None: search_type="similarity_score_threshold", search_kwargs={"k": 3, "score_threshold": 0.5}, ) - output = retriever.get_relevant_documents("foo") + output = retriever.invoke("foo") assert output == [] @@ -551,7 +547,7 @@ def test_cratedb_max_marginal_relevance_search_with_score() -> None: pre_delete_collection=True, ) output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) - assert output == [(Document(page_content="foo"), 2.0)] + assert output == [(Document(page_content="foo"), 1.0)] def test_cratedb_multicollection_search_success() -> None: