-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use MetadataFilter in LanceDB (#461)
Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Nick Byrne <[email protected]>
- Loading branch information
1 parent 1934b72, commit 13b1c0f
Showing 2 changed files with 215 additions and 81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,112 @@ | ||
import pytest | ||
|
||
from ragna.core import LocalDocument, MetadataFilter | ||
from ragna.core import LocalDocument, MetadataFilter, PlainTextDocumentHandler | ||
from ragna.source_storages import Chroma, LanceDB | ||
|
||
METADATAS = { | ||
0: {"key": "value"}, | ||
1: {"key": "value", "other_key": "other_value"}, | ||
2: {"key": "other_value"}, | ||
3: {"other_key": "value"}, | ||
4: {"other_key": "other_value"}, | ||
5: {"key": "foo"}, | ||
6: {"key": "bar"}, | ||
} | ||
|
||
@pytest.mark.parametrize( | ||
"source_storage_cls", | ||
metadata_filters = pytest.mark.parametrize( | ||
("metadata_filter", "expected_idcs"), | ||
[ | ||
Chroma, | ||
# FIXME: remove after LanceDB is fixed | ||
pytest.param(LanceDB, marks=pytest.mark.xfail()), | ||
pytest.param( | ||
MetadataFilter.and_( | ||
[ | ||
MetadataFilter.eq("key", "value"), | ||
MetadataFilter.eq("other_key", "other_value"), | ||
] | ||
), | ||
[1], | ||
id="and", | ||
), | ||
pytest.param( | ||
MetadataFilter.or_( | ||
[ | ||
MetadataFilter.eq("key", "value"), | ||
MetadataFilter.eq("key", "other_value"), | ||
] | ||
), | ||
[0, 1, 2], | ||
id="or", | ||
), | ||
pytest.param( | ||
MetadataFilter.and_( | ||
[ | ||
MetadataFilter.eq("key", "value"), | ||
MetadataFilter.or_( | ||
[ | ||
MetadataFilter.eq("key", "other_value"), | ||
MetadataFilter.eq("other_key", "other_value"), | ||
] | ||
), | ||
] | ||
), | ||
[1], | ||
id="and-nested", | ||
), | ||
pytest.param( | ||
MetadataFilter.or_( | ||
[ | ||
MetadataFilter.eq("key", "value"), | ||
MetadataFilter.and_( | ||
[ | ||
MetadataFilter.eq("key", "other_value"), | ||
MetadataFilter.ne("other_key", "other_value"), | ||
] | ||
), | ||
] | ||
), | ||
[0, 1], | ||
id="or-nested", | ||
), | ||
pytest.param(MetadataFilter.eq("key", "value"), [0, 1], id="eq"), | ||
pytest.param(MetadataFilter.ne("key", "value"), [2, 5, 6], id="ne"), | ||
pytest.param(MetadataFilter.in_("key", ["foo", "bar"]), [5, 6], id="in"), | ||
pytest.param( | ||
MetadataFilter.not_in("key", ["foo", "bar"]), [0, 1, 2], id="not_in" | ||
), | ||
pytest.param(None, [0, 1, 2, 3, 4, 5, 6], id="none"), | ||
], | ||
) | ||
def test_smoke(tmp_local_root, source_storage_cls): | ||
|
||
|
||
@metadata_filters | ||
@pytest.mark.parametrize("source_storage_cls", [Chroma, LanceDB]) | ||
def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idcs): | ||
document_root = tmp_local_root / "documents" | ||
document_root.mkdir() | ||
documents = [] | ||
for idx in range(10): | ||
path = document_root / f"irrelevant{idx}.txt" | ||
for idx, meta_dict in METADATAS.items(): | ||
path = document_root / str(idx) | ||
with open(path, "w") as file: | ||
file.write(f"This is irrelevant information for the {idx}. time!\n") | ||
|
||
documents.append(LocalDocument.from_path(path)) | ||
|
||
secret = "Ragna" | ||
path = document_root / "secret.txt" | ||
with open(path, "w") as file: | ||
file.write(f"The secret is {secret}!\n") | ||
file.write(f"The secret number is {idx}!\n") | ||
|
||
documents.insert(len(documents) // 2, LocalDocument.from_path(path)) | ||
documents.append( | ||
LocalDocument.from_path( | ||
path, | ||
metadata=meta_dict | {"idx": idx}, | ||
handler=PlainTextDocumentHandler(), | ||
) | ||
) | ||
|
||
source_storage = source_storage_cls() | ||
|
||
source_storage.store(documents) | ||
|
||
metadata_filter = MetadataFilter.or_( | ||
[MetadataFilter.eq("document_id", str(document.id)) for document in documents] | ||
prompt = "What is the secret number?" | ||
num_tokens = 4096 | ||
sources = source_storage.retrieve( | ||
metadata_filter=metadata_filter, prompt=prompt, num_tokens=num_tokens | ||
) | ||
prompt = "What is the secret?" | ||
sources = source_storage.retrieve(metadata_filter, prompt) | ||
|
||
assert secret in sources[0].content | ||
actual_idcs = sorted(map(int, (source.document_name for source in sources))) | ||
assert actual_idcs == expected_idcs | ||
|
||
# Should be able to call .store() multiple times | ||
source_storage.store(documents) |