From ec59917ea5586e501657ab72ac58ea7aa6334ea1 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Thu, 16 Nov 2023 17:22:15 -0800 Subject: [PATCH] [ENH] Worker Topic Assignment (#1376) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Adds worker topic assignment based on rendezvous hashing. - New functionality - ... ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* --- chromadb/ingest/__init__.py | 5 + chromadb/ingest/impl/simple_policy.py | 36 +++ chromadb/proto/chroma_pb2.pyi | 6 - chromadb/proto/chroma_pb2_grpc.py | 128 ---------- .../impl/distributed/segment_directory.py | 4 +- chromadb/segment/impl/distributed/server.py | 220 +++++++++++------- chromadb/segment/impl/manager/distributed.py | 40 ++-- .../distributed/test_memberlist_provider.py | 18 +- go/coordinator/cmd/grpccoordinator/cmd.go | 7 + .../internal/grpccoordinator/server.go | 60 ++++- .../memberlist_manager/memberlist_manager.go | 4 +- .../memberlist_manager_test.go | 12 +- .../memberlist_manager/memberlist_store.go | 27 ++- go/coordinator/internal/utils/kubernetes.go | 24 +- idl/chromadb/proto/chroma.proto | 13 -- k8s/crd/memberlist_crd.yaml | 4 +- k8s/deployment/kubernetes.yaml | 75 ------ k8s/deployment/segment-server.yaml | 97 ++++++++ k8s/test/test_memberlist_cr.yaml | 48 ++++ 19 files changed, 440 insertions(+), 388 deletions(-) create mode 100644 k8s/deployment/segment-server.yaml create mode 100644 k8s/test/test_memberlist_cr.yaml diff --git a/chromadb/ingest/__init__.py b/chromadb/ingest/__init__.py index 5a5abf1c99b..73f9cb065f2 100644 --- a/chromadb/ingest/__init__.py +++ b/chromadb/ingest/__init__.py @@ -127,3 +127,8 @@ class CollectionAssignmentPolicy(Component): def assign_collection(self, collection_id: UUID) -> str: """Return the topic that should be used for the given collection""" pass + + @abstractmethod + def get_topics(self) -> Sequence[str]: + """Return the list of topics that this policy is currently using""" + pass diff --git a/chromadb/ingest/impl/simple_policy.py b/chromadb/ingest/impl/simple_policy.py index 06ee2e001e0..f8068ee2046 100644 --- a/chromadb/ingest/impl/simple_policy.py +++ b/chromadb/ingest/impl/simple_policy.py @@ -1,3 +1,4 @@ +from typing import Sequence from uuid import UUID from overrides import overrides from chromadb.config import System @@ -23,3 +24,38 @@ def _topic(self, collection_id: UUID) -> str: @overrides def assign_collection(self, collection_id: UUID) -> str: return self._topic(collection_id) + + @overrides + def get_topics(self) -> Sequence[str]: + raise NotImplementedError( + "SimpleAssignmentPolicy does not support get_topics, each collection has its own topic" + ) + + +class RendezvousHashingAssignmentPolicy(CollectionAssignmentPolicy): + """The rendezvous hashing assignment policy assigns a collection to a topic based on the + rendezvous hashing algorithm. This is not actually used in the python sysdb. It is only used in the + go sysdb. However, it is useful here in order to provide a way to get the topic list used for the whole system. + """ + + _tenant_id: str + _topic_ns: str + + def __init__(self, system: System): + self._tenant_id = system.settings.tenant_id + self._topic_ns = system.settings.topic_namespace + super().__init__(system) + + @overrides + def assign_collection(self, collection_id: UUID) -> str: + raise NotImplementedError( + "RendezvousHashingAssignmentPolicy is not implemented" + ) + + @overrides + def get_topics(self) -> Sequence[str]: + # Mirrors go/coordinator/internal/coordinator/assignment_policy.go + return [ + f"persistent://{self._tenant_id}/{self._topic_ns}/chroma_log_{i}" + for i in range(16) + ] diff --git a/chromadb/proto/chroma_pb2.pyi b/chromadb/proto/chroma_pb2.pyi index 00128c82525..026bfac8821 100644 --- a/chromadb/proto/chroma_pb2.pyi +++ b/chromadb/proto/chroma_pb2.pyi @@ -170,12 +170,6 @@ class VectorQueryResults(_message.Message): results: _containers.RepeatedCompositeFieldContainer[VectorQueryResult] def __init__(self, results: _Optional[_Iterable[_Union[VectorQueryResult, _Mapping]]] = ...) -> None: ... -class SegmentServerResponse(_message.Message): - __slots__ = ["success"] - SUCCESS_FIELD_NUMBER: _ClassVar[int] - success: bool - def __init__(self, success: bool = ...) -> None: ... - class GetVectorsRequest(_message.Message): __slots__ = ["ids", "segment_id"] IDS_FIELD_NUMBER: _ClassVar[int] diff --git a/chromadb/proto/chroma_pb2_grpc.py b/chromadb/proto/chroma_pb2_grpc.py index 6d98cc34681..ccd53e449c0 100644 --- a/chromadb/proto/chroma_pb2_grpc.py +++ b/chromadb/proto/chroma_pb2_grpc.py @@ -5,134 +5,6 @@ from chromadb.proto import chroma_pb2 as chromadb_dot_proto_dot_chroma__pb2 -class SegmentServerStub(object): - """Segment Server Interface - - TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.LoadSegment = channel.unary_unary( - "/chroma.SegmentServer/LoadSegment", - request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - ) - self.ReleaseSegment = channel.unary_unary( - "/chroma.SegmentServer/ReleaseSegment", - request_serializer=chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, - response_deserializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - ) - - -class SegmentServerServicer(object): - """Segment Server Interface - - TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation - """ - - def LoadSegment(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def ReleaseSegment(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_SegmentServerServicer_to_server(servicer, server): - rpc_method_handlers = { - "LoadSegment": grpc.unary_unary_rpc_method_handler( - servicer.LoadSegment, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, - ), - "ReleaseSegment": grpc.unary_unary_rpc_method_handler( - servicer.ReleaseSegment, - request_deserializer=chromadb_dot_proto_dot_chroma__pb2.Segment.FromString, - response_serializer=chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "chroma.SegmentServer", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class SegmentServer(object): - """Segment Server Interface - - TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation - """ - - @staticmethod - def LoadSegment( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/chroma.SegmentServer/LoadSegment", - chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, - chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def ReleaseSegment( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/chroma.SegmentServer/ReleaseSegment", - chromadb_dot_proto_dot_chroma__pb2.Segment.SerializeToString, - chromadb_dot_proto_dot_chroma__pb2.SegmentServerResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - class VectorReaderStub(object): """Vector Reader Interface""" diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py index 9068f5ce645..7107e156a03 100644 --- a/chromadb/segment/impl/distributed/segment_directory.py +++ b/chromadb/segment/impl/distributed/segment_directory.py @@ -13,7 +13,7 @@ # These could go in config but given that they will rarely change, they are here for now to avoid # polluting the config file further. -WATCH_TIMEOUT_SECONDS = 10 +WATCH_TIMEOUT_SECONDS = 60 KUBERNETES_NAMESPACE = "chroma" KUBERNETES_GROUP = "chroma.cluster" @@ -213,7 +213,7 @@ def stop(self) -> None: @override def get_segment_endpoint(self, segment: Segment) -> str: # TODO: This should rendezvous hash the segment ID to a worker given the current memberlist - return "segment-worker.chroma:50051" + return "segment-server.chroma:50051" @override def register_updated_segment_callback( diff --git a/chromadb/segment/impl/distributed/server.py b/chromadb/segment/impl/distributed/server.py index 9b56ed4d18a..0189641290b 100644 --- a/chromadb/segment/impl/distributed/server.py +++ b/chromadb/segment/impl/distributed/server.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, Type, cast +from typing import Any, Dict, List, Sequence, Set, Type, cast from uuid import UUID from chromadb.config import Settings, System, get_class +from chromadb.ingest import CollectionAssignmentPolicy, Consumer from chromadb.proto.chroma_pb2_grpc import ( - SegmentServerServicer, - add_SegmentServerServicer_to_server, + # SegmentServerServicer, + # add_SegmentServerServicer_to_server, VectorReaderServicer, add_VectorReaderServicer_to_server, ) @@ -22,9 +23,12 @@ OpenTelemetryGranularity, trace_method, ) -from chromadb.types import ScalarEncoding, Segment, SegmentScope +from chromadb.types import EmbeddingRecord, ScalarEncoding, Segment, SegmentScope +from chromadb.segment.distributed import MemberlistProvider, Memberlist +from chromadb.utils.rendezvous_hash import assign, murmur3hasher +from chromadb.ingest.impl.pulsar_admin import PulsarAdmin import logging - +import os # Run this with python -m chromadb.segment.impl.distributed.server @@ -39,97 +43,139 @@ } -class SegmentServer(SegmentServerServicer, VectorReaderServicer): +class SegmentServer(VectorReaderServicer): _segment_cache: Dict[UUID, SegmentImplementation] = {} _system: System _opentelemetry_client: OpenTelemetryClient + _memberlist_provider: MemberlistProvider + _curr_memberlist: Memberlist + _assigned_topics: Set[str] + _topic_to_subscription: Dict[str, UUID] + _consumer: Consumer def __init__(self, system: System) -> None: super().__init__() self._system = system + + # Init dependency services self._opentelemetry_client = system.require(OpenTelemetryClient) + # TODO: add term and epoch to segment server + self._memberlist_provider = system.require(MemberlistProvider) + self._memberlist_provider.set_memberlist_name("worker-memberlist") + self._assignment_policy = system.require(CollectionAssignmentPolicy) + self._create_pulsar_topics() + self._consumer = system.require(Consumer) - @trace_method( - "SegmentServer.LoadSegment", OpenTelemetryGranularity.OPERATION_AND_SEGMENT - ) - def LoadSegment( - self, request: proto.Segment, context: Any - ) -> proto.SegmentServerResponse: - logging.info(f"LoadSegment scope {request.type}") - id = UUID(hex=request.id) - if id in self._segment_cache: - return proto.SegmentServerResponse( - success=True, - ) - else: - if request.scope == proto.SegmentScope.METADATA: - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Metadata segments are not yet implemented") - return proto.SegmentServerResponse(success=False) - elif request.scope == proto.SegmentScope.VECTOR: - logging.info(f"Loading segment {request}") - if request.type == SegmentType.HNSW_DISTRIBUTED.value: - self._create_instance(from_proto_segment(request)) - return proto.SegmentServerResponse(success=True) - else: - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Segment type not implemented yet") - return proto.SegmentServerResponse(success=False) - else: - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Segment scope not implemented") - return proto.SegmentServerResponse(success=False) - - def ReleaseSegment( - self, request: proto.Segment, context: Any - ) -> proto.SegmentServerResponse: - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Release segment not implemented yet") - return proto.SegmentServerResponse(success=False) - - def QueryVectors( - self, request: proto.QueryVectorsRequest, context: Any - ) -> proto.QueryVectorsResponse: - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Query segment not implemented yet") - return proto.QueryVectorsResponse() - - @trace_method( - "SegmentServer.GetVectors", OpenTelemetryGranularity.OPERATION_AND_SEGMENT - ) - def GetVectors( - self, request: proto.GetVectorsRequest, context: Any - ) -> proto.GetVectorsResponse: - segment_id = UUID(hex=request.segment_id) - if segment_id not in self._segment_cache: - context.set_code(grpc.StatusCode.NOT_FOUND) - context.set_details("Segment not found") - return proto.GetVectorsResponse() + # Init data + self._topic_to_subscription = {} + self._assigned_topics = set() + self._curr_memberlist = self._memberlist_provider.get_memberlist() + self._compute_assigned_topics() + + self._memberlist_provider.register_updated_memberlist_callback( + self._on_memberlist_update + ) + + def _compute_assigned_topics(self) -> None: + """Uses rendezvous hashing to compute the topics that this node is responsible for""" + if not self._curr_memberlist: + self._assigned_topics = set() + return + topics = self._assignment_policy.get_topics() + my_ip = os.environ["MY_POD_IP"] + new_assignments: List[str] = [] + for topic in topics: + assigned = assign(topic, self._curr_memberlist, murmur3hasher) + if assigned == my_ip: + new_assignments.append(topic) + new_assignments_set = set(new_assignments) + # TODO: We need to lock around this assignment + net_new_assignments = new_assignments_set - self._assigned_topics + removed_assignments = self._assigned_topics - new_assignments_set + + for topic in removed_assignments: + subscription = self._topic_to_subscription[topic] + self._consumer.unsubscribe(subscription) + del self._topic_to_subscription[topic] + + for topic in net_new_assignments: + subscription = self._consumer.subscribe(topic, self._on_message) + self._topic_to_subscription[topic] = subscription + + self._assigned_topics = new_assignments_set + print( + f"Topic assigment updated and now assigned to {len(self._assigned_topics)} topics" + ) + + def _on_memberlist_update(self, memberlist: Memberlist) -> None: + """Called when the memberlist is updated""" + self._curr_memberlist = memberlist + if len(self._curr_memberlist) > 0: + self._compute_assigned_topics() else: - segment = self._segment_cache[segment_id] - segment = cast(VectorReader, segment) - segment_results = segment.get_vectors(request.ids) - return_records = [] - for record in segment_results: - # TODO: encoding should be based on stored encoding for segment - # For now we just assume float32 - return_record = to_proto_vector_embedding_record( - record, ScalarEncoding.FLOAT32 - ) - return_records.append(return_record) - return proto.GetVectorsResponse(records=return_records) - - def _cls(self, segment: Segment) -> Type[SegmentImplementation]: - classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])] - cls = get_class(classname, SegmentImplementation) - return cls - - def _create_instance(self, segment: Segment) -> None: - if segment["id"] not in self._segment_cache: - cls = self._cls(segment) - instance = cls(self._system, segment) - instance.start() - self._segment_cache[segment["id"]] = instance + # In this case we'd want to warn that there are no members but + # this is not an error, as it could be that the cluster is just starting up + print("Memberlist is empty") + + def _on_message(self, embedding_records: Sequence[EmbeddingRecord]) -> None: + """Called when a message is received from the consumer""" + print(f"Received {len(embedding_records)} records") + print( + f"First record: {embedding_records[0]} is for collection {embedding_records[0]['collection_id']}" + ) + return None + + def _create_pulsar_topics(self) -> None: + """This creates the pulsar topics used by the system. THIS IS COMPLETELY A HACK AND WILL BE REPLACED + BY A PROPER TOPIC MANAGEMENT SYSTEM IN THE COORDINATOR""" + topics = self._assignment_policy.get_topics() + admin = PulsarAdmin(self._system) + for topic in topics: + admin.create_topic(topic) + + # def QueryVectors( + # self, request: proto.QueryVectorsRequest, context: Any + # ) -> proto.QueryVectorsResponse: + # context.set_code(grpc.StatusCode.UNIMPLEMENTED) + # context.set_details("Query segment not implemented yet") + # return proto.QueryVectorsResponse() + + # @trace_method( + # "SegmentServer.GetVectors", OpenTelemetryGranularity.OPERATION_AND_SEGMENT + # ) + # def GetVectors( + # self, request: proto.GetVectorsRequest, context: Any + # ) -> proto.GetVectorsResponse: + # segment_id = UUID(hex=request.segment_id) + # if segment_id not in self._segment_cache: + # context.set_code(grpc.StatusCode.NOT_FOUND) + # context.set_details("Segment not found") + # return proto.GetVectorsResponse() + # else: + # segment = self._segment_cache[segment_id] + # segment = cast(VectorReader, segment) + # segment_results = segment.get_vectors(request.ids) + # return_records = [] + # for record in segment_results: + # # TODO: encoding should be based on stored encoding for segment + # # For now we just assume float32 + # return_record = to_proto_vector_embedding_record( + # record, ScalarEncoding.FLOAT32 + # ) + # return_records.append(return_record) + # return proto.GetVectorsResponse(records=return_records) + + # def _cls(self, segment: Segment) -> Type[SegmentImplementation]: + # classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])] + # cls = get_class(classname, SegmentImplementation) + # return cls + + # def _create_instance(self, segment: Segment) -> None: + # if segment["id"] not in self._segment_cache: + # cls = self._cls(segment) + # instance = cls(self._system, segment) + # instance.start() + # self._segment_cache[segment["id"]] = instance if __name__ == "__main__": @@ -137,7 +183,7 @@ def _create_instance(self, segment: Segment) -> None: system = System(Settings()) server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) segment_server = SegmentServer(system) - add_SegmentServerServicer_to_server(segment_server, server) # type: ignore + # add_SegmentServerServicer_to_server(segment_server, server) # type: ignore add_VectorReaderServicer_to_server(segment_server, server) # type: ignore server.add_insecure_port( f"[::]:{system.settings.require('chroma_server_grpc_port')}" diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index e03b58db224..a82648d41bb 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -1,8 +1,4 @@ from threading import Lock - -import grpc -from chromadb.proto.chroma_pb2_grpc import SegmentServerStub # type: ignore -from chromadb.proto.convert import to_proto_segment from chromadb.segment import ( SegmentImplementation, SegmentManager, @@ -47,7 +43,7 @@ class DistributedSegmentManager(SegmentManager): ] # collection_id -> scope -> segment _segment_directory: SegmentDirectory _lock: Lock - _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub + # _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub def __init__(self, system: System): super().__init__(system) @@ -57,7 +53,6 @@ def __init__(self, system: System): self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} self._segment_cache = defaultdict(dict) - self._segment_server_stubs = {} self._lock = Lock() @trace_method( @@ -131,22 +126,23 @@ def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None ) known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()]) segment = next(filter(lambda s: s["type"] in known_types, segments)) - grpc_url = self._segment_directory.get_segment_endpoint(segment) - - if grpc_url not in self._segment_server_stubs: - channel = grpc.insecure_channel(grpc_url) - self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) # type: ignore - - self._segment_server_stubs[grpc_url].LoadSegment( - to_proto_segment(segment) - ) - if grpc_url not in self._segment_server_stubs: - channel = grpc.insecure_channel(grpc_url) - self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) - - self._segment_server_stubs[grpc_url].LoadSegment( - to_proto_segment(segment) - ) + # grpc_url = self._segment_directory.get_segment_endpoint(segment) + + # if grpc_url not in self._segment_server_stubs: + # channel = grpc.insecure_channel(grpc_url) + # self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) # type: ignore + + # TODO: this load is not necessary + # self._segment_server_stubs[grpc_url].LoadSegment( + # to_proto_segment(segment) + # ) + # if grpc_url not in self._segment_server_stubs: + # channel = grpc.insecure_channel(grpc_url) + # self._segment_server_stubs[grpc_url] = SegmentServerStub(channel) + + # self._segment_server_stubs[grpc_url].LoadSegment( + # to_proto_segment(segment) + # ) # TODO: rethink duplication from local segment manager def _cls(self, segment: Segment) -> Type[SegmentImplementation]: diff --git a/chromadb/test/segment/distributed/test_memberlist_provider.py b/chromadb/test/segment/distributed/test_memberlist_provider.py index 1df0b5005e1..4acb2b224dc 100644 --- a/chromadb/test/segment/distributed/test_memberlist_provider.py +++ b/chromadb/test/segment/distributed/test_memberlist_provider.py @@ -2,8 +2,6 @@ import threading from chromadb.test.conftest import skip_if_not_cluster from kubernetes import client, config -import pytest -import os from chromadb.config import System, Settings from chromadb.segment.distributed import Memberlist from chromadb.segment.impl.distributed.segment_directory import ( @@ -15,11 +13,11 @@ # Used for testing to update the memberlist CRD -def update_memberlist(n: int, memberlist_name: str = "worker-memberlist") -> Memberlist: +def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Memberlist: config.load_config() api_instance = client.CustomObjectsApi() - members = [{"url": f"ip.{i}.com"} for i in range(1, n + 1)] + members = [{"url": f"10.0.0.{i}"} for i in range(1, n + 1)] body = { "kind": "MemberList", @@ -45,10 +43,10 @@ def compare_memberlists(m1: Memberlist, m2: Memberlist) -> bool: @skip_if_not_cluster() def test_can_get_memberlist() -> None: - # This test assumes that the memberlist CRD is already created with the name "worker-memberlist" + # This test assumes that the memberlist CRD is already created with the name "test-memberlist" system = System(Settings(allow_reset=True)) provider = system.instance(CustomResourceMemberlistProvider) - provider.set_memberlist_name("worker-memberlist") + provider.set_memberlist_name("test-memberlist") system.reset_state() system.start() @@ -64,10 +62,10 @@ def test_can_get_memberlist() -> None: @skip_if_not_cluster() def test_can_update_memberlist_multiple_times() -> None: - # This test assumes that the memberlist CRD is already created with the name "worker-memberlist" + # This test assumes that the memberlist CRD is already created with the name "test-memberlist" system = System(Settings(allow_reset=True)) provider = system.instance(CustomResourceMemberlistProvider) - provider.set_memberlist_name("worker-memberlist") + provider.set_memberlist_name("test-memberlist") system.reset_state() system.start() @@ -90,10 +88,10 @@ def test_can_update_memberlist_multiple_times() -> None: @skip_if_not_cluster() def test_stop_memberlist_kills_thread() -> None: - # This test assumes that the memberlist CRD is already created with the name "worker-memberlist" + # This test assumes that the memberlist CRD is already created with the name "test-memberlist" system = System(Settings(allow_reset=True)) provider = system.instance(CustomResourceMemberlistProvider) - provider.set_memberlist_name("worker-memberlist") + provider.set_memberlist_name("test-memberlist") system.reset_state() system.start() diff --git a/go/coordinator/cmd/grpccoordinator/cmd.go b/go/coordinator/cmd/grpccoordinator/cmd.go index 6cc6a1376fe..d4dc50d0dd3 100644 --- a/go/coordinator/cmd/grpccoordinator/cmd.go +++ b/go/coordinator/cmd/grpccoordinator/cmd.go @@ -2,6 +2,7 @@ package grpccoordinator import ( "io" + "time" "github.com/chroma/chroma-coordinator/cmd/flag" "github.com/chroma/chroma-coordinator/internal/grpccoordinator" @@ -29,6 +30,12 @@ func init() { Cmd.Flags().StringVar(&conf.DBName, "db-name", "", "MetaTable db name") Cmd.Flags().IntVar(&conf.MaxIdleConns, "max-idle-conns", 10, "MetaTable max idle connections") Cmd.Flags().IntVar(&conf.MaxOpenConns, "max-open-conns", 10, "MetaTable max open connections") + Cmd.Flags().StringVar(&conf.PulsarTenant, "pulsar-tenant", "default", "Pulsar tenant") + Cmd.Flags().StringVar(&conf.PulsarNamespace, "pulsar-namespace", "default", "Pulsar namespace") + Cmd.Flags().StringVar(&conf.KubernetesNamespace, "kubernetes-namespace", "chroma", "Kubernetes namespace") + Cmd.Flags().StringVar(&conf.WorkerMemberlistName, "worker-memberlist-name", "worker-memberlist", "Worker memberlist name") + Cmd.Flags().StringVar(&conf.AssignmentPolicy, "assignment-policy", "rendezvous", "Assignment policy") + Cmd.Flags().DurationVar(&conf.WatchInterval, "watch-interval", 60*time.Second, "Watch interval") } func exec(*cobra.Command, []string) { diff --git a/go/coordinator/internal/grpccoordinator/server.go b/go/coordinator/internal/grpccoordinator/server.go index e26bdb976ba..ac99a18bdf0 100644 --- a/go/coordinator/internal/grpccoordinator/server.go +++ b/go/coordinator/internal/grpccoordinator/server.go @@ -3,11 +3,15 @@ package grpccoordinator import ( "context" "errors" + "time" "github.com/chroma/chroma-coordinator/internal/coordinator" "github.com/chroma/chroma-coordinator/internal/grpccoordinator/grpcutils" + "github.com/chroma/chroma-coordinator/internal/memberlist_manager" "github.com/chroma/chroma-coordinator/internal/metastore/db/dbcore" "github.com/chroma/chroma-coordinator/internal/proto/coordinatorpb" + "github.com/chroma/chroma-coordinator/internal/utils" + "github.com/pingcap/log" "google.golang.org/grpc" "google.golang.org/grpc/health" "gorm.io/gorm" @@ -28,6 +32,20 @@ type Config struct { MaxIdleConns int MaxOpenConns int + // Pulsar config + PulsarTenant string + PulsarNamespace string + + // Kubernetes config + KubernetesNamespace string + WorkerMemberlistName string + + // Assignment policy config can be "simple" or "rendezvous" + AssignmentPolicy string + + // Watcher config + WatchInterval time.Duration + // Config for testing Testing bool } @@ -71,17 +89,32 @@ func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider, db *gor s := &Server{ healthServer: health.NewServer(), } - // assignmentPolicy := coordinator.NewSimpleAssignmentPolicy("test-tenant", "test-topic") - // TODO: make this configuration, and make the pulsar tenant configuration too - assignmentPolicy := coordinator.NewRendezvousAssignmentPolicy("test-tenant", "test-topic") + + var assignmentPolicy coordinator.CollectionAssignmentPolicy + if config.AssignmentPolicy == "simple" { + log.Info("Using simple assignment policy") + assignmentPolicy = coordinator.NewSimpleAssignmentPolicy(config.PulsarTenant, config.PulsarNamespace) + } else if config.AssignmentPolicy == "rendezvous" { + log.Info("Using rendezvous assignment policy") + assignmentPolicy = coordinator.NewRendezvousAssignmentPolicy(config.PulsarTenant, config.PulsarNamespace) + } else { + return nil, errors.New("invalid assignment policy, only simple and rendezvous are supported") + } coordinator, err := coordinator.NewCoordinator(ctx, assignmentPolicy, db) if err != nil { return nil, err } s.coordinator = coordinator s.coordinator.Start() - if !config.Testing { + memberlist_manager, err := createMemberlistManager(config) + + // Start the memberlist manager + err = memberlist_manager.Start() + if err != nil { + return nil, err + } + s.grpcServer, err = provider.StartGrpcServer("coordinator", config.BindAddress, func(registrar grpc.ServiceRegistrar) { coordinatorpb.RegisterSysDBServer(registrar, s) }) @@ -92,6 +125,25 @@ func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider, db *gor return s, nil } +func createMemberlistManager(config Config) (*memberlist_manager.MemberlistManager, error) { + // TODO: Make this configuration + log.Info("Starting memberlist manager") + memberlist_name := config.WorkerMemberlistName + namespace := config.KubernetesNamespace + clientset, err := utils.GetKubernetesInterface() + if err != nil { + return nil, err + } + dynamicClient, err := utils.GetKubernetesDynamicInterface() + if err != nil { + return nil, err + } + nodeWatcher := memberlist_manager.NewKubernetesWatcher(clientset, namespace, "worker", config.WatchInterval) + memberlistStore := memberlist_manager.NewCRMemberlistStore(dynamicClient, namespace, memberlist_name) + memberlist_manager := memberlist_manager.NewMemberlistManager(nodeWatcher, memberlistStore) + return memberlist_manager, nil +} + func (s *Server) Close() error { s.healthServer.Shutdown() return nil diff --git a/go/coordinator/internal/memberlist_manager/memberlist_manager.go b/go/coordinator/internal/memberlist_manager/memberlist_manager.go index 18c97c43e9d..3da53fbc3b9 100644 --- a/go/coordinator/internal/memberlist_manager/memberlist_manager.go +++ b/go/coordinator/internal/memberlist_manager/memberlist_manager.go @@ -82,7 +82,7 @@ func (m *MemberlistManager) run() { } func (m *MemberlistManager) reconcile(nodeIp string, status Status) error { - memberlist, err := m.memberlistStore.GetMemberlist(context.TODO()) + memberlist, resourceVersion, err := m.memberlistStore.GetMemberlist(context.Background()) if err != nil { return err } @@ -110,7 +110,7 @@ func (m *MemberlistManager) reconcile(nodeIp string, status Status) error { if !exists && status == Ready { newMemberlist = append(newMemberlist, nodeIp) } - return m.memberlistStore.UpdateMemberlist(context.TODO(), &newMemberlist) + return m.memberlistStore.UpdateMemberlist(context.TODO(), &newMemberlist, resourceVersion) } func (m *MemberlistManager) Stop() error { diff --git a/go/coordinator/internal/memberlist_manager/memberlist_manager_test.go b/go/coordinator/internal/memberlist_manager/memberlist_manager_test.go index 5ffc59d6e26..4a26fdd484b 100644 --- a/go/coordinator/internal/memberlist_manager/memberlist_manager_test.go +++ b/go/coordinator/internal/memberlist_manager/memberlist_manager_test.go @@ -88,13 +88,13 @@ func TestMemberlistStore(t *testing.T) { memberlistName := "test-memberlist" namespace := "chroma" memberlist := &Memberlist{} - cr_memberlist := memberlistToCr(memberlist, namespace, memberlistName) + cr_memberlist := memberlistToCr(memberlist, namespace, memberlistName, "0") // Following the assumptions of the real system, we initialize the CR with no members. dynamicClient := fake.NewSimpleDynamicClient(runtime.NewScheme(), cr_memberlist) memberlist_store := NewCRMemberlistStore(dynamicClient, namespace, memberlistName) - memberlist, err := memberlist_store.GetMemberlist(context.TODO()) + memberlist, _, err := memberlist_store.GetMemberlist(context.TODO()) if err != nil { t.Fatalf("Error getting memberlist: %v", err) } @@ -102,8 +102,8 @@ func TestMemberlistStore(t *testing.T) { assert.Equal(t, Memberlist{}, *memberlist) // Add a member to the memberlist - memberlist_store.UpdateMemberlist(context.TODO(), &Memberlist{"10.0.0.1", "10.0.0.2"}) - memberlist, err = memberlist_store.GetMemberlist(context.TODO()) + memberlist_store.UpdateMemberlist(context.TODO(), &Memberlist{"10.0.0.1", "10.0.0.2"}, "0") + memberlist, _, err = memberlist_store.GetMemberlist(context.TODO()) if err != nil { t.Fatalf("Error getting memberlist: %v", err) } @@ -139,7 +139,7 @@ func TestMemberlistManager(t *testing.T) { memberlist_name := "test-memberlist" namespace := "chroma" initialMemberlist := &Memberlist{} - initialCrMemberlist := memberlistToCr(initialMemberlist, namespace, memberlist_name) + initialCrMemberlist := memberlistToCr(initialMemberlist, namespace, memberlist_name, "0") // Create a fake kubernetes client clientset, err := utils.GetTestKubenertesInterface() @@ -201,7 +201,7 @@ func retryUntilCondition(t *testing.T, f func() bool, retry_count int, retry_int } func getMemberlistAndCompare(t *testing.T, memberlistStore IMemberlistStore, expected_memberlist Memberlist) bool { - memberlist, err := memberlistStore.GetMemberlist(context.TODO()) + memberlist, _, err := memberlistStore.GetMemberlist(context.TODO()) if err != nil { t.Fatalf("Error getting memberlist: %v", err) } diff --git a/go/coordinator/internal/memberlist_manager/memberlist_store.go b/go/coordinator/internal/memberlist_manager/memberlist_store.go index f1844b6920f..0567897f46e 100644 --- a/go/coordinator/internal/memberlist_manager/memberlist_store.go +++ b/go/coordinator/internal/memberlist_manager/memberlist_store.go @@ -10,8 +10,8 @@ import ( ) type IMemberlistStore interface { - GetMemberlist(ctx context.Context) (*Memberlist, error) - UpdateMemberlist(ctx context.Context, memberlist *Memberlist) error + GetMemberlist(ctx context.Context) (return_memberlist *Memberlist, resourceVersion string, err error) + UpdateMemberlist(ctx context.Context, memberlist *Memberlist, resourceVersion string) error } type Memberlist []string @@ -30,27 +30,31 @@ func NewCRMemberlistStore(dynamicClient dynamic.Interface, coordinatorNamespace } } -func (s *CRMemberlistStore) GetMemberlist(ctx context.Context) (*Memberlist, error) { +func (s *CRMemberlistStore) GetMemberlist(ctx context.Context) (return_memberlist *Memberlist, resourceVersion string, err error) { gvr := getGvr() unstrucuted, err := s.dynamicClient.Resource(gvr).Namespace(s.coordinatorNamespace).Get(ctx, s.memberlistCustomResource, metav1.GetOptions{}) if err != nil { - return nil, err + return nil, "", err } cr := unstrucuted.UnstructuredContent() members := cr["spec"].(map[string]interface{})["members"] memberlist := Memberlist{} + if members == nil { + // Empty memberlist + return &memberlist, unstrucuted.GetResourceVersion(), nil + } cast_members := members.([]interface{}) for _, member := range cast_members { member_map := member.(map[string]interface{}) memberlist = append(memberlist, member_map["url"].(string)) } - return &memberlist, nil + return &memberlist, unstrucuted.GetResourceVersion(), nil } -func (s *CRMemberlistStore) UpdateMemberlist(ctx context.Context, memberlist *Memberlist) error { +func (s *CRMemberlistStore) UpdateMemberlist(ctx context.Context, memberlist *Memberlist, resourceVersion string) error { gvr := getGvr() - unstructured := memberlistToCr(memberlist, s.coordinatorNamespace, s.memberlistCustomResource) - _, err := s.dynamicClient.Resource(gvr).Namespace(s.coordinatorNamespace).Update(ctx, unstructured, metav1.UpdateOptions{}) + unstructured := memberlistToCr(memberlist, s.coordinatorNamespace, s.memberlistCustomResource, resourceVersion) + _, err := s.dynamicClient.Resource(gvr).Namespace("chroma").Update(context.TODO(), unstructured, metav1.UpdateOptions{}) if err != nil { return err } @@ -62,7 +66,7 @@ func getGvr() schema.GroupVersionResource { return gvr } -func memberlistToCr(memberlist *Memberlist, namespace string, memberlistName string) *unstructured.Unstructured { +func memberlistToCr(memberlist *Memberlist, namespace string, memberlistName string, resourceVersion string) *unstructured.Unstructured { members := []interface{}{} for _, member := range *memberlist { members = append(members, map[string]interface{}{ @@ -75,8 +79,9 @@ func memberlistToCr(memberlist *Memberlist, namespace string, memberlistName str "apiVersion": "chroma.cluster/v1", "kind": "MemberList", "metadata": map[string]interface{}{ - "name": memberlistName, - "namespace": namespace, + "name": memberlistName, + "namespace": namespace, + "resourceVersion": resourceVersion, }, "spec": map[string]interface{}{ "members": members, diff --git a/go/coordinator/internal/utils/kubernetes.go b/go/coordinator/internal/utils/kubernetes.go index 77b07245bf2..b2784cbf462 100644 --- a/go/coordinator/internal/utils/kubernetes.go +++ b/go/coordinator/internal/utils/kubernetes.go @@ -5,7 +5,6 @@ import ( "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" ) func GetTestKubenertesInterface() (kubernetes.Interface, error) { @@ -14,19 +13,12 @@ func GetTestKubenertesInterface() (kubernetes.Interface, error) { } func getKubernetesConfig() (*rest.Config, error) { - // Load the default kubeconfig file - loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() - config, err := loadingRules.Load() + config, err := rest.InClusterConfig() if err != nil { return nil, err } + return config, nil - clientConfig, err := clientcmd.NewDefaultClientConfig(*config, &clientcmd.ConfigOverrides{}).ClientConfig() - if err != nil { - return nil, err - } - - return clientConfig, nil } func GetKubernetesDynamicInterface() (dynamic.Interface, error) { @@ -44,20 +36,12 @@ func GetKubernetesDynamicInterface() (dynamic.Interface, error) { } func GetKubernetesInterface() (kubernetes.Interface, error) { - // Load the default kubeconfig file - loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() - config, err := loadingRules.Load() + config, err := getKubernetesConfig() if err != nil { return nil, err } - - clientConfig, err := clientcmd.NewDefaultClientConfig(*config, &clientcmd.ConfigOverrides{}).ClientConfig() - if err != nil { - return nil, err - } - // Create a clientset for the coordinator - clientset, err := kubernetes.NewForConfig(clientConfig) + clientset, err := kubernetes.NewForConfig(config) if err != nil { return nil, err } diff --git a/idl/chromadb/proto/chroma.proto b/idl/chromadb/proto/chroma.proto index b034d30b374..51579aae921 100644 --- a/idl/chromadb/proto/chroma.proto +++ b/idl/chromadb/proto/chroma.proto @@ -106,19 +106,6 @@ message VectorQueryResults { repeated VectorQueryResult results = 1; } -/* Segment Server Interface */ - -// TODO: figure out subpackaging, ideally this file is colocated with the segment server implementation -service SegmentServer { - rpc LoadSegment (Segment) returns (SegmentServerResponse) {} - rpc ReleaseSegment (Segment) returns (SegmentServerResponse) {} // TODO: this maybe should only take id/type/scope -} - -// TODO: enum of succcess/failure/or already loaded -message SegmentServerResponse { - bool success = 1; -} - /* Vector Reader Interface */ service VectorReader { diff --git a/k8s/crd/memberlist_crd.yaml b/k8s/crd/memberlist_crd.yaml index 96be7388d01..9d31706aad2 100644 --- a/k8s/crd/memberlist_crd.yaml +++ b/k8s/crd/memberlist_crd.yaml @@ -24,9 +24,9 @@ spec: items: type: object properties: - url: + url: # Rename to ip type: string - pattern: '^(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?$' + pattern: '^((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}$' scope: Namespaced names: plural: memberlists diff --git a/k8s/deployment/kubernetes.yaml b/k8s/deployment/kubernetes.yaml index faa0e436ce7..5498c3704a5 100644 --- a/k8s/deployment/kubernetes.yaml +++ b/k8s/deployment/kubernetes.yaml @@ -159,81 +159,6 @@ spec: --- -apiVersion: v1 -kind: Service -metadata: - name: segment-server - namespace: chroma -spec: - ports: - - name: segment-server-port - port: 50051 - targetPort: 50051 - selector: - app: segment-server - type: ClusterIP - ---- - -apiVersion: apps/v1 -kind: Deployment -metadata: - name: segment-server - namespace: chroma -spec: - replicas: 1 - selector: - matchLabels: - app: segment-server - template: - metadata: - labels: - app: segment-server - member-type: worker - spec: - containers: - - name: segment-server - image: server - imagePullPolicy: IfNotPresent - command: ["python", "-m", "chromadb.segment.impl.distributed.server"] - ports: - - containerPort: 50051 - volumeMounts: - - name: chroma - mountPath: /index_data - env: - - name: IS_PERSISTENT - value: "TRUE" - - name: CHROMA_PRODUCER_IMPL - value: "chromadb.ingest.impl.pulsar.PulsarProducer" - - name: CHROMA_CONSUMER_IMPL - value: "chromadb.ingest.impl.pulsar.PulsarConsumer" - - name: PULSAR_BROKER_URL - value: "pulsar.chroma" - - name: PULSAR_BROKER_PORT - value: "6650" - - name: PULSAR_ADMIN_PORT - value: "8080" - - name: CHROMA_SERVER_GRPC_PORT - value: "50051" - # readinessProbe: - # httpGet: - # path: /healthz - # port: 50051 - # initialDelaySeconds: 10 - # periodSeconds: 5 - # livenessProbe: - # httpGet: - # path: /healthz - # port: 50051 - # initialDelaySeconds: 20 - # periodSeconds: 10 - volumes: - - name: chroma - emptyDir: {} - ---- - # apiVersion: v1 # kind: PersistentVolumeClaim # metadata: diff --git a/k8s/deployment/segment-server.yaml b/k8s/deployment/segment-server.yaml new file mode 100644 index 00000000000..1df7cec9ff4 --- /dev/null +++ b/k8s/deployment/segment-server.yaml @@ -0,0 +1,97 @@ +apiVersion: v1 +kind: Service +metadata: + name: segment-server + namespace: chroma +spec: + ports: + - name: segment-server-port + port: 50051 + targetPort: 50051 + selector: + app: segment-server + type: ClusterIP + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: segment-server + namespace: chroma +spec: + replicas: 1 + selector: + matchLabels: + app: segment-server + template: + metadata: + labels: + app: segment-server + member-type: worker + spec: + containers: + - name: segment-server + image: server + imagePullPolicy: IfNotPresent + command: ["python", "-m", "chromadb.segment.impl.distributed.server"] + ports: + - containerPort: 50051 + volumeMounts: + - name: chroma + mountPath: /index_data + env: + - name: IS_PERSISTENT + value: "TRUE" + - name: CHROMA_PRODUCER_IMPL + value: "chromadb.ingest.impl.pulsar.PulsarProducer" + - name: CHROMA_CONSUMER_IMPL + value: "chromadb.ingest.impl.pulsar.PulsarConsumer" + - name: PULSAR_BROKER_URL + value: "pulsar.chroma" + - name: PULSAR_BROKER_PORT + value: "6650" + - name: PULSAR_ADMIN_PORT + value: "8080" + - name: CHROMA_SERVER_GRPC_PORT + value: "50051" + - name: CHROMA_COLLECTION_ASSIGNMENT_POLICY_IMPL + value: "chromadb.ingest.impl.simple_policy.RendezvousHashingAssignmentPolicy" + - name: MY_POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + # livenessProbe: + # grpc: + # port: 50051 + # initialDelaySeconds: 10 + volumes: + - name: chroma + emptyDir: {} + +--- + +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + namespace: chroma + name: pod-watcher +rules: +- apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "watch"] + +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: pod-watcher-binding + namespace: chroma +subjects: +- kind: ServiceAccount + name: default + namespace: chroma +roleRef: + kind: Role + name: pod-watcher + apiGroup: rbac.authorization.k8s.io diff --git a/k8s/test/test_memberlist_cr.yaml b/k8s/test/test_memberlist_cr.yaml new file mode 100644 index 00000000000..174e19ccef5 --- /dev/null +++ b/k8s/test/test_memberlist_cr.yaml @@ -0,0 +1,48 @@ +# These kubernetes manifests are UNDER ACTIVE DEVELOPMENT and are not yet ready for production use. +# They will be used for the upcoming distributed version of chroma. They are not even ready +# for testing yet. Please do not use them unless you are working on the distributed version of chroma. + +# Create a memberlist called worker-memberlist +apiVersion: chroma.cluster/v1 +kind: MemberList +metadata: + name: test-memberlist + namespace: chroma +spec: + members: + +--- + +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: test-memberlist-reader +rules: +- apiGroups: + - chroma.cluster + resources: + - memberlists + verbs: + - get + - list + - watch + # TODO: FIX THIS LEAKY PERMISSION + - create + - update + - patch + - delete + +--- + +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: test-memberlist-reader-binding +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: test-memberlist-reader +subjects: +- kind: ServiceAccount + name: default + namespace: chroma