Commit

Use MetadataFilter in LanceDB (#461)
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Nick Byrne <[email protected]>
3 people authored Aug 8, 2024
1 parent 1934b72 commit 13b1c0f
Showing 2 changed files with 215 additions and 81 deletions.
179 changes: 122 additions & 57 deletions ragna/source_storages/_lancedb.py
@@ -1,18 +1,25 @@
from __future__ import annotations

import uuid
from typing import cast
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, cast

import ragna
from ragna.core import (
Document,
MetadataFilter,
MetadataOperator,
PackageRequirement,
RagnaException,
Requirement,
Source,
)

from ._vector_database import VectorDatabaseSourceStorage

if TYPE_CHECKING:
import lancedb


class LanceDB(VectorDatabaseSourceStorage):
"""[LanceDB vector database](https://lancedb.com/)
@@ -40,64 +47,115 @@ def __init__(self) -> None:
super().__init__()

import lancedb
import pyarrow as pa

self._db = lancedb.connect(ragna.local_root() / "lancedb")
# Create schema at runtime!
self._schema = pa.schema(
[
pa.field("id", pa.string()),
pa.field("document_id", pa.string()),
pa.field("page_numbers", pa.string()),
pa.field("text", pa.string()),
pa.field(
self._VECTOR_COLUMN_NAME,
pa.list_(pa.float32(), self._embedding_dimensions),
),
pa.field("num_tokens", pa.int32()),
]
)

_VECTOR_COLUMN_NAME = "embedded_text"

def _get_table(self, corpus_name: Optional[str] = None) -> lancedb.table.Table:
if corpus_name is None:
corpus_name = self._embedding_id

if corpus_name in self._db.table_names():
return self._db.open_table(corpus_name)
else:
import pyarrow as pa

return self._db.create_table(
name=corpus_name,
schema=pa.schema(
[
pa.field("id", pa.string()),
pa.field("document_id", pa.string()),
pa.field("document_name", pa.string()),
pa.field("page_numbers", pa.string()),
pa.field("text", pa.string()),
pa.field(
self._VECTOR_COLUMN_NAME,
pa.list_(pa.float32(), self._embedding_dimensions),
),
pa.field("num_tokens", pa.int32()),
]
),
)
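
For context, the open-or-create pattern in _get_table() can be reproduced standalone. A minimal sketch, assuming lancedb and pyarrow are installed; the database path, table name, and embedding size are illustrative, not taken from this commit:

    import lancedb
    import pyarrow as pa

    db = lancedb.connect("/tmp/lancedb-demo")  # illustrative location

    schema = pa.schema(
        [
            pa.field("id", pa.string()),
            pa.field("text", pa.string()),
            # Fixed-size list column holding the embedding vector.
            pa.field("embedded_text", pa.list_(pa.float32(), 384)),
        ]
    )

    name = "demo-corpus"
    # Open the table if it already exists, otherwise create it with the schema.
    table = (
        db.open_table(name)
        if name in db.table_names()
        else db.create_table(name=name, schema=schema)
    )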

_PYTHON_TO_LANCE_TYPE_MAP = {
bool: "boolean",
int: "int",
float: "float",
str: "string",
}

def store(
self,
documents: list[Document],
*,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
table = self._db.create_table(name=self._embedding_id, schema=self._schema)
table = self._get_table(corpus_name)

document_field_types = defaultdict(set)
for document in documents:
for chunk in self._chunk_pages(
document.extract_pages(),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
):
table.add(
[
{
"id": str(uuid.uuid4()),
"document_id": str(document.id),
"document_name": str(document.name),
"page_numbers": self._page_numbers_to_str(
chunk.page_numbers
),
"text": chunk.text,
self._VECTOR_COLUMN_NAME: self._embedding_function(
[chunk.text]
)[0],
"num_tokens": chunk.num_tokens,
}
]
for field, value in document.metadata.items():
document_field_types[field].add(type(value))

document_fields = {}
for field, types in document_field_types.items():
if len(types) > 1:
raise RagnaException(
"Multiple types for metadata value", key=field, types=sorted(types)
)
document_fields[field] = self._PYTHON_TO_LANCE_TYPE_MAP[types.pop()]

schema_fields = set(table.schema.names)

missing_fields = document_fields.keys() - schema_fields
if missing_fields:
# Unfortunately, LanceDB does not support adding columns with a specific
# type; instead, the type is automatically inferred from the value.
table.add_columns(
{
field: f"CAST(NULL as {document_fields[field]})"
for field in missing_fields
}
)

default_metadata = {
field: None for field in document_fields.keys() | schema_fields
}

table.add(
[
{
# Unpacking the default metadata first so it can be
# overridden by concrete values if present
**default_metadata,
**document.metadata,
"id": str(uuid.uuid4()),
"document_id": str(document.id),
"document_name": str(document.name),
"page_numbers": self._page_numbers_to_str(chunk.page_numbers),
"text": chunk.text,
self._VECTOR_COLUMN_NAME: self._embedding_function([chunk.text])[0],
"num_tokens": chunk.num_tokens,
}
for document in documents
for chunk in self._chunk_pages(
document.extract_pages(),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
]
)
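
The schema evolution in store() above works in three steps: collect the Python type of every metadata field across all documents, reject fields that appear with more than one type, and add any columns missing from the table via a NULL cast so that LanceDB infers the column type. A minimal sketch of the type-collection step, with made-up metadata for illustration:

    from collections import defaultdict

    python_to_lance = {bool: "boolean", int: "int", float: "float", str: "string"}
    metadatas = [{"author": "Ada", "year": 1843}, {"author": "Alan"}]  # illustrative

    field_types: dict[str, set[type]] = defaultdict(set)
    for metadata in metadatas:
        for field, value in metadata.items():
            field_types[field].add(type(value))

    # A field observed with more than one type would raise RagnaException here.
    fields = {field: python_to_lance[types.pop()] for field, types in field_types.items()}
    print(fields)  # {'author': 'string', 'year': 'int'}
    # Missing columns would then be added as, e.g.,
    # table.add_columns({"year": "CAST(NULL as int)"})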

# https://lancedb.github.io/lancedb/sql/
_METADATA_OPERATOR_MAP = {
MetadataOperator.AND: "AND",
MetadataOperator.OR: "OR",
MetadataOperator.EQ: "=",
MetadataOperator.NE: "!=",
MetadataOperator.LT: "<",
MetadataOperator.LE: "<=",
MetadataOperator.GT: ">",
@@ -112,53 +170,60 @@ def _translate_metadata_filter(self, metadata_filter: MetadataFilter) -> str:
MetadataOperator.AND,
MetadataOperator.OR,
}:
return f" {self._METADATA_OPERATOR_MAP[metadata_filter.operator]} ".join(
operator = f" {self._METADATA_OPERATOR_MAP[metadata_filter.operator]} "
return operator.join(
f"({self._translate_metadata_filter(child)})"
for child in metadata_filter.value
)
elif metadata_filter.operator is MetadataOperator.NE:
return f"NOT ({self._translate_metadata_filter(MetadataFilter.eq(metadata_filter.key, metadata_filter.value))})"
elif metadata_filter.operator is MetadataOperator.NOT_IN:
return f"NOT ({self._translate_metadata_filter(MetadataFilter.in_(metadata_filter.key, metadata_filter.value))})"
in_ = self._translate_metadata_filter(
MetadataFilter.in_(metadata_filter.key, metadata_filter.value)
)
return f"NOT ({in_})"
else:
key = metadata_filter.key
operator = self._METADATA_OPERATOR_MAP[metadata_filter.operator]
value = (
tuple(metadata_filter.value)
if metadata_filter.operator is MetadataOperator.IN
else metadata_filter.value
)
return f"{metadata_filter.key} {self._METADATA_OPERATOR_MAP[metadata_filter.operator]} {value!r}"
return f"{key} {operator} {value!r}"

def retrieve(
self,
documents: list[Document],
metadata_filter: Optional[MetadataFilter],
prompt: str,
*,
chat_id: uuid.UUID,
corpus_name: Optional[str] = None,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
table = self._db.open_table(str(chat_id))
table = self._get_table(corpus_name)

# We cannot retrieve source by a maximum number of tokens. Thus, we estimate how
# many sources we have to query. We overestimate by a factor of two to avoid
# retrieving to few sources and needed to query again.
# retrieving too few sources and needing to query again.
limit = int(num_tokens * 2 / chunk_size)
results = (
table.search(
self._embedding_function([prompt])[0],
vector_column_name=self._VECTOR_COLUMN_NAME,
)
.limit(limit)
.to_arrow()

search = table.search(
self._embedding_function([prompt])[0],
vector_column_name=self._VECTOR_COLUMN_NAME,
)

document_map = {str(document.id): document for document in documents}
if metadata_filter:
search = search.where(
self._translate_metadata_filter(metadata_filter), prefilter=True
)

results = search.limit(limit).to_arrow()

return self._take_sources_up_to_max_tokens(
(
Source(
id=result["id"],
document_id=uuid.UUID(result["document_id"]),
document_name=document_map[result["document_id"]].name,
document_id=result["document_id"],
document_name=result["document_name"],
# For some reason adding an empty string during store() results
# in this field being None. Thus, we need to parse it back here.
# TODO: See if there is a configuration option for this
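
With the defaults above (num_tokens=1024, chunk_size=500), the limit works out to int(1024 * 2 / 500) = 4 chunks per query. A usage sketch, assuming documents were previously stored into the default corpus; the filter and prompt are illustrative:

    storage = LanceDB()
    sources = storage.retrieve(
        MetadataFilter.eq("document_name", "report.txt"),  # illustrative filter
        "What changed in the latest revision?",            # illustrative prompt
        num_tokens=1024,
    )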
117 changes: 93 additions & 24 deletions tests/source_storages/test_source_storages.py
@@ -1,43 +1,112 @@
import pytest

from ragna.core import LocalDocument, MetadataFilter
from ragna.core import LocalDocument, MetadataFilter, PlainTextDocumentHandler
from ragna.source_storages import Chroma, LanceDB

METADATAS = {
0: {"key": "value"},
1: {"key": "value", "other_key": "other_value"},
2: {"key": "other_value"},
3: {"other_key": "value"},
4: {"other_key": "other_value"},
5: {"key": "foo"},
6: {"key": "bar"},
}

@pytest.mark.parametrize(
"source_storage_cls",
metadata_filters = pytest.mark.parametrize(
("metadata_filter", "expected_idcs"),
[
Chroma,
# FIXME: remove after LanceDB is fixed
pytest.param(LanceDB, marks=pytest.mark.xfail()),
pytest.param(
MetadataFilter.and_(
[
MetadataFilter.eq("key", "value"),
MetadataFilter.eq("other_key", "other_value"),
]
),
[1],
id="and",
),
pytest.param(
MetadataFilter.or_(
[
MetadataFilter.eq("key", "value"),
MetadataFilter.eq("key", "other_value"),
]
),
[0, 1, 2],
id="or",
),
pytest.param(
MetadataFilter.and_(
[
MetadataFilter.eq("key", "value"),
MetadataFilter.or_(
[
MetadataFilter.eq("key", "other_value"),
MetadataFilter.eq("other_key", "other_value"),
]
),
]
),
[1],
id="and-nested",
),
pytest.param(
MetadataFilter.or_(
[
MetadataFilter.eq("key", "value"),
MetadataFilter.and_(
[
MetadataFilter.eq("key", "other_value"),
MetadataFilter.ne("other_key", "other_value"),
]
),
]
),
[0, 1],
id="or-nested",
),
pytest.param(MetadataFilter.eq("key", "value"), [0, 1], id="eq"),
pytest.param(MetadataFilter.ne("key", "value"), [2, 5, 6], id="ne"),
pytest.param(MetadataFilter.in_("key", ["foo", "bar"]), [5, 6], id="in"),
pytest.param(
MetadataFilter.not_in("key", ["foo", "bar"]), [0, 1, 2], id="not_in"
),
pytest.param(None, [0, 1, 2, 3, 4, 5, 6], id="none"),
],
)
def test_smoke(tmp_local_root, source_storage_cls):


@metadata_filters
@pytest.mark.parametrize("source_storage_cls", [Chroma, LanceDB])
def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idcs):
document_root = tmp_local_root / "documents"
document_root.mkdir()
documents = []
for idx in range(10):
path = document_root / f"irrelevant{idx}.txt"
for idx, meta_dict in METADATAS.items():
path = document_root / str(idx)
with open(path, "w") as file:
file.write(f"This is irrelevant information for the {idx}. time!\n")

documents.append(LocalDocument.from_path(path))

secret = "Ragna"
path = document_root / "secret.txt"
with open(path, "w") as file:
file.write(f"The secret is {secret}!\n")
file.write(f"The secret number is {idx}!\n")

documents.insert(len(documents) // 2, LocalDocument.from_path(path))
documents.append(
LocalDocument.from_path(
path,
metadata=meta_dict | {"idx": idx},
handler=PlainTextDocumentHandler(),
)
)

source_storage = source_storage_cls()

source_storage.store(documents)

metadata_filter = MetadataFilter.or_(
[MetadataFilter.eq("document_id", str(document.id)) for document in documents]
prompt = "What is the secret number?"
num_tokens = 4096
sources = source_storage.retrieve(
metadata_filter=metadata_filter, prompt=prompt, num_tokens=num_tokens
)
prompt = "What is the secret?"
sources = source_storage.retrieve(metadata_filter, prompt)

assert secret in sources[0].content
actual_idcs = sorted(map(int, (source.document_name for source in sources)))
assert actual_idcs == expected_idcs

# Should be able to call .store() multiple times
source_storage.store(documents)
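
A note on the assertions above: each document is written to a path named after its integer index, so source.document_name round-trips back to the index that expected_idcs references. The trailing store() call guards against regressions where re-storing into an existing corpus fails, e.g. by attempting to recreate an existing table.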
