Skip to content

Commit

Permalink
[BUG]: Metada DB cleanup upon collection delete (#1320)
Browse files Browse the repository at this point in the history
Refs: #1289

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Implemented removal upon

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
N/A
  • Loading branch information
tazarov authored Dec 6, 2023
1 parent eea76bf commit 8adb20a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 17 deletions.
3 changes: 3 additions & 0 deletions chromadb/segment/impl/manager/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value:
instance = self.get_segment(collection_id, VectorReader)
instance.delete()
elif segment["type"] == SegmentType.SQLITE.value:
instance = self.get_segment(collection_id, MetadataReader)
instance.delete()
del self._instances[segment["id"]]
if collection_id in self._segment_cache:
if segment["scope"] in self._segment_cache[collection_id]:
Expand Down
61 changes: 44 additions & 17 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def count(self) -> int:
self._db.querybuilder()
.from_(embeddings_t)
.where(
embeddings_t.segment_id == ParameterValue(
self._db.uuid_to_db(self._id))
embeddings_t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
.select(fn.Count(embeddings_t.id))
)
Expand Down Expand Up @@ -140,19 +139,16 @@ def get_metadata(
metadata_t.bool_value,
)
.where(
embeddings_t.segment_id == ParameterValue(
self._db.uuid_to_db(self._id))
embeddings_t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
.orderby(embeddings_t.id)
)

if where:
q = q.where(self._where_map_criterion(
q, where, embeddings_t, metadata_t))
q = q.where(self._where_map_criterion(q, where, embeddings_t, metadata_t))
if where_document:
q = q.where(
self._where_doc_criterion(
q, where_document, embeddings_t, fulltext_t)
self._where_doc_criterion(q, where_document, embeddings_t, fulltext_t)
)

if ids:
Expand Down Expand Up @@ -231,8 +227,7 @@ def _insert_record(
if upsert:
return self._update_record(cur, record)
else:
logger.warning(
f"Insert of existing embedding ID: {record['id']}")
logger.warning(f"Insert of existing embedding ID: {record['id']}")
# We are trying to add for a record that already exists. Fail the call.
# We don't throw an exception since this is in principal an async path
return
Expand Down Expand Up @@ -366,8 +361,7 @@ def _delete_record(self, cur: Cursor, record: EmbeddingRecord) -> None:
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if result is None:
logger.warning(
f"Delete of nonexisting embedding ID: {record['id']}")
logger.warning(f"Delete of nonexisting embedding ID: {record['id']}")
else:
id = result[0]

Expand Down Expand Up @@ -398,8 +392,7 @@ def _update_record(self, cur: Cursor, record: EmbeddingRecord) -> None:
sql = sql + " RETURNING id"
result = cur.execute(sql, params).fetchone()
if result is None:
logger.warning(
f"Update of nonexisting embedding ID: {record['id']}")
logger.warning(f"Update of nonexisting embedding ID: {record['id']}")
else:
id = result[0]
if record["metadata"]:
Expand Down Expand Up @@ -454,8 +447,7 @@ def _where_map_criterion(
]
clause.append(reduce(lambda x, y: x | y, criteria))
else:
expr = cast(
Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v)
expr = cast(Union[LiteralValue, Dict[WhereOperator, LiteralValue]], v)
sq = (
self._db.querybuilder()
.from_(metadata_t)
Expand Down Expand Up @@ -504,9 +496,44 @@ def _where_doc_criterion(
raise ValueError(f"Unknown where_doc operator {k}")
raise ValueError("Empty where_doc")

@trace_method("SqliteMetadataSegment.delete", OpenTelemetryGranularity.ALL)
@override
def delete(self) -> None:
raise NotImplementedError()
t = Table("embeddings")
t1 = Table("embedding_metadata")
q0 = (
self._db.querybuilder()
.from_(t1)
.delete()
.where(
t1.id.isin(
self._db.querybuilder()
.from_(t)
.select(t.id)
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
)
)
)
q = (
self._db.querybuilder()
.from_(t)
.delete()
.where(
t.id.isin(
self._db.querybuilder()
.from_(t)
.select(t.id)
.where(
t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
)
)
)
with self._db.tx() as cur:
cur.execute(*get_sql(q0))
cur.execute(*get_sql(q))


def _encode_seq_id(seq_id: SeqId) -> bytes:
Expand Down
40 changes: 40 additions & 0 deletions chromadb/test/segment/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pytest
from typing import Generator, List, Callable, Iterator, Dict, Optional, Union, Sequence
from chromadb.config import System, Settings
from chromadb.db.base import ParameterValue, get_sql
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
SubmitEmbeddingRecord,
Expand All @@ -14,6 +16,7 @@
SegmentScope,
SeqId,
)
from pypika import Table
from chromadb.ingest import Producer
from chromadb.segment import MetadataReader
import uuid
Expand Down Expand Up @@ -530,3 +533,40 @@ def _test_update(
assert results[0]["metadata"] == {"baz": 42}
results = segment.get_metadata(where_document={"$contains": "biz"})
assert len(results) == 0


def test_delete_segment(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
produce_fns: ProducerFn,
) -> None:
producer = system.instance(Producer)
system.reset_state()
topic = str(segment_definition["topic"])

segment = SqliteMetadataSegment(system, segment_definition)
segment.start()

embeddings, seq_ids = produce_fns(producer, topic, sample_embeddings, 10)
max_id = seq_ids[-1]

sync(segment, max_id)

assert segment.count() == 10
results = segment.get_metadata(ids=["embedding_0"])
assert_equiv_records(embeddings[:1], results)
_id = segment._id
segment.delete()
_db = system.instance(SqliteDB)
t = Table("embeddings")
q = (
_db.querybuilder()
.from_(t)
.select(t.id)
.where(t.segment_id == ParameterValue(_db.uuid_to_db(_id)))
)
sql, params = get_sql(q)
with _db.tx() as cur:
res = cur.execute(sql, params)
# assert that the segment is gone
assert len(res.fetchall()) == 0

0 comments on commit 8adb20a

Please sign in to comment.