Skip to content

Commit

Permalink
implement any and all metadata filters for weaviate vector store (run…
Browse files Browse the repository at this point in the history
  • Loading branch information
brenkehoe authored May 24, 2024
1 parent d6934bc commit 4eab119
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 44 deletions.
50 changes: 40 additions & 10 deletions llama-index-core/llama_index/core/vector_stores/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
BasePydanticVectorStore,
MetadataFilters,
FilterCondition,
FilterOperator,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
Expand All @@ -47,27 +48,56 @@ def _build_metadata_filter_fn(
metadata_filters: Optional[MetadataFilters] = None,
) -> Callable[[str], bool]:
"""Build metadata filter function."""
filter_list = metadata_filters.legacy_filters() if metadata_filters else []
filter_list = metadata_filters.filters if metadata_filters else []
if not filter_list:
return lambda _: True

filter_condition = cast(MetadataFilters, metadata_filters.condition)

def filter_fn(node_id: str) -> bool:
def _process_filter_match(
operator: FilterOperator, value: Any, metadata_value: Any
) -> bool:
if metadata_value is None:
return False
if operator == FilterOperator.EQ:
return metadata_value == value
if operator == FilterOperator.NE:
return metadata_value != value
if operator == FilterOperator.GT:
return metadata_value > value
if operator == FilterOperator.GTE:
return metadata_value >= value
if operator == FilterOperator.LT:
return metadata_value < value
if operator == FilterOperator.LTE:
return metadata_value <= value
if operator == FilterOperator.IN:
return value in metadata_value
if operator == FilterOperator.NIN:
return value not in metadata_value
if operator == FilterOperator.CONTAINS:
return value in metadata_value
if operator == FilterOperator.TEXT_MATCH:
return value.lower() in metadata_value.lower()
if operator == FilterOperator.ALL:
return all(val in metadata_value for val in value)
if operator == FilterOperator.ANY:
return any(val in metadata_value for val in value)
raise ValueError(f"Invalid operator: {operator}")

metadata = metadata_lookup_fn(node_id)

filter_matches_list = []
for filter_ in filter_list:
filter_matches = True
metadata_value = metadata.get(filter_.key, None)
if metadata_value is None:
filter_matches = False
elif isinstance(metadata_value, list):
if filter_.value not in metadata_value:
filter_matches = False
elif isinstance(metadata_value, (int, float, str, bool)):
if metadata_value != filter_.value:
filter_matches = False

filter_matches = _process_filter_match(
operator=filter_.operator,
value=filter_.value,
metadata_value=metadata.get(filter_.key, None),
)

filter_matches_list.append(filter_matches)

if filter_condition == FilterCondition.AND:
Expand Down
14 changes: 5 additions & 9 deletions llama-index-core/llama_index/core/vector_stores/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Vector store index types."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -69,8 +68,10 @@ class FilterOperator(str, Enum):
NE = "!=" # not equal to (string, int, float)
GTE = ">=" # greater than or equal to (int, float)
LTE = "<=" # less than or equal to (int, float)
IN = "in" # metadata in value array (string or number)
NIN = "nin" # metadata not in value array (string or number)
IN = "in" # In array (string or number)
NIN = "nin" # Not in array (string or number)
ANY = "any" # Contains any (array of strings)
ALL = "all" # Contains all (array of strings)
TEXT_MATCH = "text_match" # full text match (allows you to search for a specific substring, token or phrase within the text field)
CONTAINS = "contains" # metadata array contains value (string or number)

Expand All @@ -93,12 +94,7 @@ class MetadataFilter(BaseModel):
"""

key: str
value: Union[
StrictInt,
StrictFloat,
StrictStr,
List[Union[StrictInt, StrictFloat, StrictStr]],
]
value: Union[StrictInt, StrictFloat, StrictStr, List[StrictStr]]
operator: FilterOperator = FilterOperator.EQ

@classmethod
Expand Down
224 changes: 221 additions & 3 deletions llama-index-core/tests/vector_stores/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
MetadataFilters,
VectorStoreQuery,
FilterCondition,
MetadataFilter,
FilterOperator,
)

_NODE_ID_WEIGHT_1_RANK_A = "AF3BE6C4-5F43-4D74-B075-6B0E07900DE8"
Expand All @@ -22,21 +24,36 @@ def _node_embeddings_for_test() -> List[TextNode]:
id_=_NODE_ID_WEIGHT_1_RANK_A,
embedding=[1.0, 0.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-0")},
metadata={"weight": 1.0, "rank": "a"},
metadata={
"weight": 1.0,
"rank": "a",
"quality": ["medium", "high"],
"identifier": "6FTR78Yun",
},
),
TextNode(
text="lorem ipsum",
id_=_NODE_ID_WEIGHT_2_RANK_C,
embedding=[0.0, 1.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-1")},
metadata={"weight": 2.0, "rank": "c"},
metadata={
"weight": 2.0,
"rank": "c",
"quality": ["medium"],
"identifier": "6FTR78Ygl",
},
),
TextNode(
text="lorem ipsum",
id_=_NODE_ID_WEIGHT_3_RANK_C,
embedding=[1.0, 1.0],
relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="test-2")},
metadata={"weight": 3.0, "rank": "c"},
metadata={
"weight": 3.0,
"rank": "c",
"quality": ["low", "medium", "high"],
"identifier": "6FTR78Ztl",
},
),
]

Expand Down Expand Up @@ -182,6 +199,207 @@ def test_query_with_filters_with_filter_condition(self) -> None:
result = simple_vector_store.query(query)
self.assertEqual(len(result.ids), 0)

def test_query_with_equal_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="weight", operator=FilterOperator.EQ, value=1.0)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 1)

def test_query_with_notequal_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="weight", operator=FilterOperator.NE, value=1.0)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 2)

def test_query_with_greaterthan_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="weight", operator=FilterOperator.GT, value=1.5)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 2)

def test_query_with_greaterthanequal_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="weight", operator=FilterOperator.GTE, value=1.0)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 3)

def test_query_with_lessthan_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="weight", operator=FilterOperator.LT, value=1.1)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None

def test_query_with_lessthanequal_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="weight", operator=FilterOperator.LTE, value=1.0)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 1)

def test_query_with_in_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="quality", operator=FilterOperator.IN, value="high")
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 2)

def test_query_with_notin_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(key="quality", operator=FilterOperator.NIN, value="high")
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 1)

def test_query_with_contains_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(
key="quality", operator=FilterOperator.CONTAINS, value="high"
)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 2)

def test_query_with_textmatch_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(
key="identifier",
operator=FilterOperator.TEXT_MATCH,
value="6FTR78Y",
)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 2)

def test_query_with_any_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(
key="quality", operator=FilterOperator.ANY, value=["high", "low"]
)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 2)

def test_query_with_all_filter_returns_matches(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())

filters = MetadataFilters(
filters=[
MetadataFilter(
key="quality", operator=FilterOperator.ALL, value=["medium", "high"]
)
]
)
query = VectorStoreQuery(
query_embedding=[1.0, 1.0], filters=filters, similarity_top_k=3
)
result = simple_vector_store.query(query)
assert result.ids is not None
self.assertEqual(len(result.ids), 2)

def test_clear(self) -> None:
simple_vector_store = SimpleVectorStore()
simple_vector_store.add(_node_embeddings_for_test())
Expand Down
Loading

0 comments on commit 4eab119

Please sign in to comment.