Skip to content

Commit

Permalink
[ENH] add .clean_log() to Producers (#2549)
Browse files Browse the repository at this point in the history
Depends on #2545.

Changes:

- Adds a `purge_log()` method to producers (not called automatically in
this PR).
- The existing table `max_seq_id` is now used to track the maximum seen
sequence ID for both metadata and vector segments (formerly only used by
metadata segments).
- Segments are expected to update the `max_seq_id` table themselves.
- Vector segments will automatically migrate the `max_seq_id` field from
the old pickled metadata file source into the database upon init.

In this PR, log entries are deleted on a per-collection basis. The next
PR in this stack deletes entries globally.
  • Loading branch information
codetheweb authored Jul 29, 2024
1 parent 8907e53 commit 7edbb5b
Show file tree
Hide file tree
Showing 15 changed files with 403 additions and 56 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/_python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ jobs:
"chromadb/test/property/test_embeddings.py",
"chromadb/test/property/test_filtering.py",
"chromadb/test/property/test_persist.py",
"chromadb/test/property/test_restart_persist.py"]
"chromadb/test/property/test_restart_persist.py",
"chromadb/test/property/test_clean_log.py"]
include:
- test-globs: "chromadb/test/property/test_embeddings.py"
parallelized: true
- test-globs: "chromadb/test/property/test_clean_log.py"
parallelized: true

runs-on: ${{ matrix.platform }}
steps:
Expand Down
27 changes: 26 additions & 1 deletion chromadb/db/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Tuple, Type
from typing import Any, Optional, Sequence, Tuple, Type, Union
from types import TracebackType
from typing_extensions import Protocol, Self, Literal
from abc import ABC, abstractmethod
Expand All @@ -10,6 +10,8 @@
from uuid import UUID
from itertools import islice, count

from chromadb.types import SeqId


class NotFoundError(Exception):
"""Raised when a delete or update operation affects no rows"""
Expand Down Expand Up @@ -117,6 +119,29 @@ def param(self, idx: int) -> pypika.Parameter:
"""Return a PyPika Parameter object for the given index"""
return pypika.Parameter(self.parameter_format().format(idx))

@staticmethod
def decode_seq_id(seq_id_bytes: Union[bytes, int]) -> SeqId:
    """Decode a byte array into a SeqID"""
    # Some callers already hold a plain integer; pass it through untouched.
    if isinstance(seq_id_bytes, int):
        return seq_id_bytes

    # Only the two fixed widths produced by encode_seq_id are valid.
    if len(seq_id_bytes) not in (8, 24):
        raise ValueError(f"Unknown SeqID type with length {len(seq_id_bytes)}")
    return int.from_bytes(seq_id_bytes, "big")

@staticmethod
def encode_seq_id(seq_id: SeqId) -> bytes:
    """Encode a SeqID into a byte array"""
    # Use the narrowest supported fixed width: 8 bytes for values fitting in
    # 64 bits, otherwise 24 bytes for values fitting in 192 bits.
    bits = seq_id.bit_length()
    if bits <= 64:
        return int.to_bytes(seq_id, 8, "big")
    if bits <= 192:
        return int.to_bytes(seq_id, 24, "big")
    raise ValueError(f"Unsupported SeqID: {seq_id}")


_context = local()

Expand Down
46 changes: 45 additions & 1 deletion chromadb/db/mixins/embeddings_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SqlEmbeddingsQueue(SqlDB, Producer, Consumer):
Note that this class is only suitable for use cases where the producer and consumer
are in the same process.
This is because notifiaction of new embeddings happens solely in-process: this
This is because notification of new embeddings happens solely in-process: this
implementation does not actively listen to the the database for new records added by
other processes.
"""
Expand Down Expand Up @@ -116,6 +116,50 @@ def delete_log(self, collection_id: UUID) -> None:
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)

@trace_method("SqlEmbeddingsQueue.purge_log", OpenTelemetryGranularity.ALL)
@override
def purge_log(self, collection_id: UUID) -> None:
    """Delete log entries for `collection_id` that every segment has consumed.

    Computes the minimum `max_seq_id` across all of the collection's segments
    (treating a segment with no `max_seq_id` row as -1) and deletes queue
    entries with a strictly smaller seq_id. If the collection has no
    segments, nothing is deleted.
    """
    topic_name = create_topic_name(
        self._tenant, self._topic_namespace, collection_id
    )

    segments_t = Table("segments")
    segment_ids_q = (
        self.querybuilder()
        .from_(segments_t)
        .where(
            segments_t.collection == ParameterValue(self.uuid_to_db(collection_id))
        )
        # This coalesce prevents a correctness bug when two segments exist and:
        # - one has written to the max_seq_id table
        # - the other has never written to the max_seq_id table
        # In that case, we should not delete any WAL entries as we can't be sure that the second segment is caught up.
        .select(functions.Coalesce(Table("max_seq_id").seq_id, -1))
        .left_join(Table("max_seq_id"))
        .on(segments_t.id == Table("max_seq_id").segment_id)
    )

    with self.tx() as cur:
        sql, params = get_sql(segment_ids_q, self.parameter_format())
        cur.execute(sql, params)
        results = cur.fetchall()
        if results:
            # Only entries below the slowest segment's offset are safe to drop.
            min_seq_id = min(self.decode_seq_id(row[0]) for row in results)
        else:
            # No segments exist for this collection: nothing is provably
            # consumed, so retain the entire log.
            return

        t = Table("embeddings_queue")
        q = (
            self.querybuilder()
            .from_(t)
            .where(t.topic == ParameterValue(topic_name))
            .where(t.seq_id < ParameterValue(min_seq_id))
            .delete()
        )

        sql, params = get_sql(q, self.parameter_format())
        cur.execute(sql, params)

@trace_method("SqlEmbeddingsQueue.submit_embedding", OpenTelemetryGranularity.ALL)
@override
def submit_embedding(
Expand Down
7 changes: 6 additions & 1 deletion chromadb/ingest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class Producer(Component):
def delete_log(self, collection_id: UUID) -> None:
pass

@abstractmethod
def purge_log(self, collection_id: UUID) -> None:
    """Truncates the log for the given collection, removing all seen records.

    Implementations should remove only records that every segment of the
    collection has already consumed; records not yet applied to all
    segments must be retained.
    """
    pass

@abstractmethod
def submit_embedding(
self, collection_id: UUID, embedding: OperationRecord
Expand Down Expand Up @@ -82,7 +87,7 @@ def subscribe(
end: Optional[SeqId] = None,
id: Optional[UUID] = None,
) -> UUID:
"""Register a function that will be called to recieve embeddings for a given
"""Register a function that will be called to receive embeddings for a given
collections log stream. The given function may be called any number of times, with any number of
records, and may be called concurrently.
Expand Down
30 changes: 30 additions & 0 deletions chromadb/ingest/impl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from typing import Tuple
from uuid import UUID

from chromadb.db.base import SqlDB
from chromadb.segment import SegmentManager, VectorReader

topic_regex = r"persistent:\/\/(?P<tenant>.+)\/(?P<namespace>.+)\/(?P<topic>.+)"


Expand All @@ -15,3 +18,30 @@ def parse_topic_name(topic_name: str) -> Tuple[str, str, str]:

def create_topic_name(tenant: str, namespace: str, collection_id: UUID) -> str:
return f"persistent://{tenant}/{namespace}/{str(collection_id)}"


def trigger_vector_segments_max_seq_id_migration(
    db: SqlDB, segment_manager: SegmentManager
) -> None:
    """
    Trigger the migration of vector segments' max_seq_id from the pickled metadata file to SQLite.

    Vector segments migrate this field automatically on init—so this should be
    used when we know segments are likely unmigrated and unloaded.

    Args:
        db: database handle used to find unmigrated vector segments.
        segment_manager: used to load each affected segment, which performs
            the migration as a side effect of initialization.
    """
    with db.tx() as cur:
        # A vector segment with no max_seq_id row has not yet migrated its
        # offset out of the pickled metadata file.
        cur.execute(
            """
            SELECT collection
            FROM "segments"
            WHERE "id" NOT IN (SELECT "segment_id" FROM "max_seq_id") AND
            "type" = 'urn:chroma:segment/vector/hnsw-local-persisted'
            """
        )
        collection_ids_with_unmigrated_segments = [row[0] for row in cur.fetchall()]

    # Idiomatic truthiness check (was `len(...) == 0`).
    if not collection_ids_with_unmigrated_segments:
        return

    for collection_id in collection_ids_with_unmigrated_segments:
        # Loading the segment triggers the migration on init
        segment_manager.get_segment(UUID(collection_id), VectorReader)
5 changes: 5 additions & 0 deletions chromadb/logservice/logservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def reset_state(self) -> None:
def delete_log(self, collection_id: UUID) -> None:
raise NotImplementedError("Not implemented")

@trace_method("LogService.purge_log", OpenTelemetryGranularity.ALL)
@override
def purge_log(self, collection_id: UUID) -> None:
    # Purging is not supported by this client; presumably the remote log
    # service manages log retention itself — NOTE(review): confirm.
    raise NotImplementedError("Not implemented")

@trace_method("LogService.submit_embedding", OpenTelemetryGranularity.ALL)
@override
def submit_embedding(
Expand Down
53 changes: 17 additions & 36 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class SqliteMetadataSegment(MetadataReader):
_id: UUID
_opentelemetry_client: OpenTelemetryClient
_collection_id: Optional[UUID]
_subscription: Optional[UUID]
_subscription: Optional[UUID] = None

def __init__(self, system: System, segment: Segment):
self._db = system.instance(SqliteDB)
Expand Down Expand Up @@ -89,7 +89,7 @@ def max_seqid(self) -> SeqId:
if result is None:
return self._consumer.min_seqid()
else:
return _decode_seq_id(result[0])
return self._db.decode_seq_id(result[0])

@trace_method("SqliteMetadataSegment.count", OpenTelemetryGranularity.ALL)
@override
Expand Down Expand Up @@ -269,7 +269,7 @@ def _insert_record(self, cur: Cursor, record: LogRecord, upsert: bool) -> None:
).insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(record["record"]["id"]),
ParameterValue(_encode_seq_id(record["log_offset"])),
ParameterValue(self._db.encode_seq_id(record["log_offset"])),
)
sql, params = get_sql(q)
sql = sql + "RETURNING id"
Expand Down Expand Up @@ -460,7 +460,7 @@ def _update_record(self, cur: Cursor, record: LogRecord) -> None:
q = (
self._db.querybuilder()
.update(t)
.set(t.seq_id, ParameterValue(_encode_seq_id(record["log_offset"])))
.set(t.seq_id, ParameterValue(self._db.encode_seq_id(record["log_offset"])))
.where(t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)))
.where(t.embedding_id == ParameterValue(record["record"]["id"]))
)
Expand All @@ -482,18 +482,6 @@ def _write_metadata(self, records: Sequence[LogRecord]) -> None:
records are append-only (that is, that seq-ids should increase monotonically)"""
with self._db.tx() as cur:
for record in records:
q = (
self._db.querybuilder()
.into(Table("max_seq_id"))
.columns("segment_id", "seq_id")
.insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(_encode_seq_id(record["log_offset"])),
)
)
sql, params = get_sql(q)
sql = sql.replace("INSERT", "INSERT OR REPLACE")
cur.execute(sql, params)
if record["record"]["operation"] == Operation.ADD:
self._insert_record(cur, record, False)
elif record["record"]["operation"] == Operation.UPSERT:
Expand All @@ -503,6 +491,19 @@ def _write_metadata(self, records: Sequence[LogRecord]) -> None:
elif record["record"]["operation"] == Operation.UPDATE:
self._update_record(cur, record)

q = (
self._db.querybuilder()
.into(Table("max_seq_id"))
.columns("segment_id", "seq_id")
.insert(
ParameterValue(self._db.uuid_to_db(self._id)),
ParameterValue(self._db.encode_seq_id(record["log_offset"])),
)
)
sql, params = get_sql(q)
sql = sql.replace("INSERT", "INSERT OR REPLACE")
cur.execute(sql, params)

@trace_method(
"SqliteMetadataSegment._where_map_criterion", OpenTelemetryGranularity.ALL
)
Expand Down Expand Up @@ -648,26 +649,6 @@ def delete(self) -> None:
cur.execute(*get_sql(q))


def _encode_seq_id(seq_id: SeqId) -> bytes:
    """Encode a SeqID into a byte array"""
    # Choose the narrowest supported fixed width: 8 bytes when the value fits
    # in 64 bits, 24 bytes when it fits in 192 bits.
    for max_bits, num_bytes in ((64, 8), (192, 24)):
        if seq_id.bit_length() <= max_bits:
            return int.to_bytes(seq_id, num_bytes, "big")
    raise ValueError(f"Unsupported SeqID: {seq_id}")


def _decode_seq_id(seq_id_bytes: bytes) -> SeqId:
    """Decode a byte array into a SeqID"""
    # Guard clause: only the two fixed widths written by _encode_seq_id are
    # valid; both decode as a big-endian unsigned integer.
    if len(seq_id_bytes) not in (8, 24):
        raise ValueError(f"Unknown SeqID type with length {len(seq_id_bytes)}")
    return int.from_bytes(seq_id_bytes, "big")


def _where_clause(
expr: Union[
LiteralValue,
Expand Down
3 changes: 2 additions & 1 deletion chromadb/segment/impl/vector/local_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class LocalHnswSegment(VectorReader):
_id: UUID
_consumer: Consumer
_collection: Optional[UUID]
_subscription: UUID
_subscription: Optional[UUID]
_settings: Settings
_params: HnswParams

Expand All @@ -60,6 +60,7 @@ def __init__(self, system: System, segment: Segment):
self._consumer = system.instance(Consumer)
self._id = segment["id"]
self._collection = segment["collection"]
self._subscription = None
self._settings = system.settings
self._params = HnswParams(segment["metadata"] or {})

Expand Down
Loading

0 comments on commit 7edbb5b

Please sign in to comment.