From 4eab119528b74eabcfbeff33822b171cae94005f Mon Sep 17 00:00:00 2001 From: Brendan Kehoe Date: Fri, 24 May 2024 19:55:53 +0100 Subject: [PATCH] implement any and all metadata filters for weaviate vector store (#13365) --- .../llama_index/core/vector_stores/simple.py | 50 +++- .../llama_index/core/vector_stores/types.py | 14 +- .../tests/vector_stores/test_simple.py | 224 +++++++++++++++++- .../vector_stores/weaviate/base.py | 32 +-- 4 files changed, 276 insertions(+), 44 deletions(-) diff --git a/llama-index-core/llama_index/core/vector_stores/simple.py b/llama-index-core/llama_index/core/vector_stores/simple.py index 9b7272447a446..c9343c255b7b1 100644 --- a/llama-index-core/llama_index/core/vector_stores/simple.py +++ b/llama-index-core/llama_index/core/vector_stores/simple.py @@ -22,6 +22,7 @@ BasePydanticVectorStore, MetadataFilters, FilterCondition, + FilterOperator, VectorStoreQuery, VectorStoreQueryMode, VectorStoreQueryResult, @@ -47,27 +48,56 @@ def _build_metadata_filter_fn( metadata_filters: Optional[MetadataFilters] = None, ) -> Callable[[str], bool]: """Build metadata filter function.""" - filter_list = metadata_filters.legacy_filters() if metadata_filters else [] + filter_list = metadata_filters.filters if metadata_filters else [] if not filter_list: return lambda _: True filter_condition = cast(MetadataFilters, metadata_filters.condition) def filter_fn(node_id: str) -> bool: + def _process_filter_match( + operator: FilterOperator, value: Any, metadata_value: Any + ) -> bool: + if metadata_value is None: + return False + if operator == FilterOperator.EQ: + return metadata_value == value + if operator == FilterOperator.NE: + return metadata_value != value + if operator == FilterOperator.GT: + return metadata_value > value + if operator == FilterOperator.GTE: + return metadata_value >= value + if operator == FilterOperator.LT: + return metadata_value < value + if operator == FilterOperator.LTE: + return metadata_value <= value + if operator == FilterOperator.IN: + return value in metadata_value + if operator == FilterOperator.NIN: + return value not in metadata_value + if operator == FilterOperator.CONTAINS: + return value in metadata_value + if operator == FilterOperator.TEXT_MATCH: + return value.lower() in metadata_value.lower() + if operator == FilterOperator.ALL: + return all(val in metadata_value for val in value) + if operator == FilterOperator.ANY: + return any(val in metadata_value for val in value) + raise ValueError(f"Invalid operator: {operator}") + metadata = metadata_lookup_fn(node_id) filter_matches_list = [] for filter_ in filter_list: filter_matches = True - metadata_value = metadata.get(filter_.key, None) - if metadata_value is None: - filter_matches = False - elif isinstance(metadata_value, list): - if filter_.value not in metadata_value: - filter_matches = False - elif isinstance(metadata_value, (int, float, str, bool)): - if metadata_value != filter_.value: - filter_matches = False + + filter_matches = _process_filter_match( + operator=filter_.operator, + value=filter_.value, + metadata_value=metadata.get(filter_.key, None), + ) + filter_matches_list.append(filter_matches) if filter_condition == FilterCondition.AND: diff --git a/llama-index-core/llama_index/core/vector_stores/types.py b/llama-index-core/llama_index/core/vector_stores/types.py index 9345b5046a56a..961c8dc4bbd9e 100644 --- a/llama-index-core/llama_index/core/vector_stores/types.py +++ b/llama-index-core/llama_index/core/vector_stores/types.py @@ -1,5 +1,4 @@ """Vector store index types.""" - from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -69,8 +68,10 @@ class FilterOperator(str, Enum): NE = "!=" # not equal to (string, int, float) GTE = ">=" # greater than or equal to (int, float) LTE = "<=" # less than or equal to (int, float) - IN = "in" # metadata in value array (string or number) - NIN = "nin" # metadata not in value array (string or number) + IN = "in" # In array (string or number) + NIN = "nin" # Not in array (string or number) + ANY = "any" # Contains any (array of strings) + ALL = "all" # Contains all (array of strings) TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase within the text field) CONTAINS = "contains" # metadata array contains value (string or number) @@ -93,12 +94,7 @@ class MetadataFilter(BaseModel): """ key: str - value: Union[ - StrictInt, - StrictFloat, - StrictStr, - List[Union[StrictInt, StrictFloat, StrictStr]], - ] + value: Union[StrictInt, StrictFloat, StrictStr, List[StrictStr]] operator: FilterOperator = FilterOperator.EQ @classmethod diff --git a/llama-index-core/tests/vector_stores/test_simple.py b/llama-index-core/tests/vector_stores/test_simple.py index 00f098429b3b0..7c5721d403715 100644 --- a/llama-index-core/tests/vector_stores/test_simple.py +++ b/llama-index-core/tests/vector_stores/test_simple.py @@ -8,6 +8,8 @@ MetadataFilters, VectorStoreQuery, FilterCondition, + MetadataFilter, + FilterOperator, ) _NODE_ID_WEIGHT_1_RANK_A = "AF3BE6C4-5F43-4D74-B075-6B0E07900DE8" @@ -22,21 +24,36 @@ def _node_embeddings_for_test() -> List[TextNode]: id_=_NODE_ID_WEIGHT_1_RANK_A, embedding=[1.0, 0.0], relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-0")}, - metadata={"weight": 1.0, "rank": "a"}, + metadata={ + "weight": 1.0, + "rank": "a", + "quality": ["medium", "high"], + "identifier": "6FTR78Yun", + }, ), TextNode( text="lorem ipsum", id_=_NODE_ID_WEIGHT_2_RANK_C, embedding=[0.0, 1.0], relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-1")}, - metadata={"weight": 2.0, "rank": "c"}, + metadata={ + "weight": 2.0, + "rank": "c", + "quality": ["medium"], + "identifier": "6FTR78Ygl", + }, ), TextNode( text="lorem ipsum", id_=_NODE_ID_WEIGHT_3_RANK_C, embedding=[1.0, 1.0], relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-2")}, - metadata={"weight": 3.0, "rank": "c"}, + metadata={ + "weight": 3.0, + "rank": "c", + "quality": ["low", "medium", "high"], + "identifier": "6FTR78Ztl", + }, ), ] @@ -182,6 +199,207 @@ def test_query_with_filters_with_filter_condition(self) -> None: result = simple_vector_store.query(query) self.assertEqual(len(result.ids), 0) + def test_query_with_equal_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="weight", operator=FilterOperator.EQ, value=1.0) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 1) + + def test_query_with_notequal_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="weight", operator=FilterOperator.NE, value=1.0) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 2) + + def test_query_with_greaterthan_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="weight", operator=FilterOperator.GT, value=1.5) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 2) + + def test_query_with_greaterthanequal_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="weight", operator=FilterOperator.GTE, value=1.0) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 3) + + def test_query_with_lessthan_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="weight", operator=FilterOperator.LT, value=1.1) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + + def test_query_with_lessthanequal_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="weight", operator=FilterOperator.LTE, value=1.0) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 1) + + def test_query_with_in_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="quality", operator=FilterOperator.IN, value="high") + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 2) + + def test_query_with_notin_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter(key="quality", operator=FilterOperator.NIN, value="high") + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 1) + + def test_query_with_contains_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="quality", operator=FilterOperator.CONTAINS, value="high" + ) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 2) + + def test_query_with_textmatch_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="identifier", + operator=FilterOperator.TEXT_MATCH, + value="6FTR78Y", + ) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 2) + + def test_query_with_any_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="quality", operator=FilterOperator.ANY, value=["high", "low"] + ) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 2) + + def test_query_with_all_filter_returns_matches(self) -> None: + simple_vector_store = SimpleVectorStore() + simple_vector_store.add(_node_embeddings_for_test()) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="quality", operator=FilterOperator.ALL, value=["medium", "high"] + ) + ] + ) + query = VectorStoreQuery( + query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3 + ) + result = simple_vector_store.query(query) + assert result.ids is not None + self.assertEqual(len(result.ids), 2) + def test_clear(self) -> None: simple_vector_store = SimpleVectorStore() simple_vector_store.add(_node_embeddings_for_test()) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-weaviate/llama_index/vector_stores/weaviate/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-weaviate/llama_index/vector_stores/weaviate/base.py index 2dce3c4e12c3a..8d24c0b18894a 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-weaviate/llama_index/vector_stores/weaviate/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-weaviate/llama_index/vector_stores/weaviate/base.py @@ -24,7 +24,6 @@ create_default_schema, get_all_properties, get_node_similarity, - parse_get_response, to_node, ) @@ -59,6 +58,10 @@ def _transform_weaviate_filter_operator(operator: str) -> str: return "greater_or_equal" elif operator == "<=": return "less_or_equal" + elif operator == "any": + return "contains_any" + elif operator == "all": + return "contains_all" else: raise ValueError(f"Filter operator {operator} not supported") @@ -246,29 +249,14 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: ref_doc_id (str): The doc_id of the document to delete. """ - where_filter = { - "path": ["ref_doc_id"], - "operator": "Equal", - "valueText": ref_doc_id, - } + collection = self._client.collections.get(self.index_name) + + where_filter = wvc.query.Filter.by_property("ref_doc_id").equal(ref_doc_id) + if "filter" in delete_kwargs and delete_kwargs["filter"] is not None: - where_filter = { - "operator": "And", - "operands": [where_filter, delete_kwargs["filter"]], # type: ignore - } - - query = ( - self._client.query.get(self.index_name) - .with_additional(["id"]) - .with_where(where_filter) - .with_limit(10000) # 10,000 is the max weaviate can fetch - ) + where_filter = where_filter & _to_weaviate_filter(delete_kwargs["filter"]) - query_result = query.do() - parsed_result = parse_get_response(query_result) - entries = parsed_result[self.index_name] - for entry in entries: - self._client.data_object.delete(entry["_additional"]["id"], self.index_name) + collection.data.delete_many(where=where_filter) def delete_index(self) -> None: """Delete the index associated with the client.