[BUG] Make memberlist use ips for routing #3405
base: main
First changed file:

@@ -8,6 +8,7 @@
 from chromadb.config import System
 from chromadb.segment.distributed import (
+    Member,
     Memberlist,
     MemberlistProvider,
     SegmentDirectory,
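For orientation, `Member` is the new type this hunk starts importing. Below is a minimal sketch of the shape the rest of the diff assumes (construction via `Member(id=..., ip=...)`, attribute access via `.id` and `.ip`); the real definition lives in `chromadb/segment/distributed` and may differ:

```python
# Sketch only: the Member shape implied by this diff's call sites.
from dataclasses import dataclass
from typing import List


@dataclass
class Member:
    id: str  # pod name, e.g. "query-service-0"
    ip: str  # pod IP, e.g. "10.0.0.1"; "" when the source carried no ip


# A Memberlist is then just a list of members.
Memberlist = List[Member]
```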
@@ -35,7 +36,12 @@ class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):
     def __init__(self, system: System):
         super().__init__(system)
-        self._memberlist = ["a", "b", "c"]
+        # self._memberlist = ["a", "b", "c"]
+        self._memberlist = [
+            Member(id="a", ip="10.0.0.1"),
+            Member(id="b", ip="10.0.0.2"),
+            Member(id="c", ip="10.0.0.3"),
+        ]

     @override
     def get_memberlist(self) -> Memberlist:
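With that change, consumers of the mock now receive `Member` objects rather than bare strings. Roughly, assuming the constructor signature shown above and an already-configured `system`:

```python
# `system` is assumed to be a configured chromadb System, as in the diff.
provider = MockMemberlistProvider(system)
members = provider.get_memberlist()
assert [m.id for m in members] == ["a", "b", "c"]
assert members[0].ip == "10.0.0.1"
```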
@@ -203,7 +209,12 @@ def _parse_response_memberlist(
 ) -> Memberlist:
     if "members" not in api_response_spec:
         return []
-    return [m["member_id"] for m in api_response_spec["members"]]
+    parsed = []
+    for m in api_response_spec["members"]:
+        id = m["member_id"]
+        ip = m["member_ip"] if "member_ip" in m else ""
+        parsed.append(Member(id=id, ip=ip))
+    return parsed

 def _notify(self, memberlist: Memberlist) -> None:
     for callback in self.callbacks:
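To make the new parsing concrete: entries without a `member_ip` fall back to an empty string, which is what keeps old-style memberlists (ids only) working. In the sample below, only the `members`, `member_id`, and `member_ip` key names come from the diff; the values are invented:

```python
# Invented sample input; only the key names are taken from the diff.
api_response_spec = {
    "members": [
        {"member_id": "query-service-0", "member_ip": "10.0.0.1"},
        {"member_id": "query-service-1"},  # old-style entry, no ip
    ]
}

# The new loop would yield:
#   [Member(id="query-service-0", ip="10.0.0.1"),
#    Member(id="query-service-1", ip="")]
```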
@@ -245,11 +256,24 @@ def get_segment_endpoint(self, segment: Segment) -> str:
         raise ValueError("Memberlist is not initialized")
         # Query to the same collection should end up on the same endpoint
         assignment = assign(
-            segment["collection"].hex, self._curr_memberlist, murmur3hasher, 1
+            segment["collection"].hex,
+            [m.id for m in self._curr_memberlist],
+            murmur3hasher,
+            1,
         )[0]
         service_name = self.extract_service_name(assignment)
-        assignment = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"  # TODO: make port configurable
-        return assignment
+
+        # If the memberlist has an ip, use it, otherwise use the member id with the headless service
+        # this is for backwards compatibility with the old memberlist which only had ids
+        for member in self._curr_memberlist:
+            if member.id == assignment:
+                if member.ip is not None and member.ip != "":
+                    print(f"[HAMMAD DEBUG] Using member ip: {member.ip}")
+                    endpoint = f"{member.ip}:50051"
+                    return endpoint
+
+        endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"  # TODO: make port configurable
+        return endpoint

     @override
     def register_updated_segment_callback(

Review comment on the debug print line: todo: remove
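This hunk is the core of the PR: rendezvous hashing still runs over member ids, so a given collection keeps hashing to the same member, but the winning id is then resolved to a pod ip when one is known, falling back to the headless-service DNS name otherwise. A standalone sketch of that resolution step, reusing the hypothetical `Member` shape above (`assign` and `murmur3hasher` are Chroma's rendezvous-hash utilities and are not reproduced here):

```python
from typing import List


def resolve_endpoint(
    assignment: str,        # member id chosen by rendezvous hashing
    memberlist: List[Member],
    service_name: str,
    namespace: str,
    headless_service: str,
    port: int = 50051,      # hard-coded in the diff; its TODO: make configurable
) -> str:
    # Prefer the pod ip when the memberlist carries one ("" counts as absent).
    for member in memberlist:
        if member.id == assignment and member.ip:
            return f"{member.ip}:{port}"
    # Backwards-compatible path: pod-name DNS via the headless service.
    return f"{assignment}.{service_name}.{namespace}.{headless_service}:{port}"
```

Routing by ip avoids a DNS lookup per request, but it assumes the memberlist's ips stay fresh as pods churn; the stable-DNS fallback covers memberlists written before ips existed.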
@@ -263,7 +287,9 @@ def register_updated_segment_callback(
         )
     def _update_memberlist(self, memberlist: Memberlist) -> None:
         with self._curr_memberlist_mutex:
-            add_attributes_to_current_span({"new_memberlist": memberlist})
+            add_attributes_to_current_span(
+                {"new_memberlist": [m.id for m in memberlist]}
+            )
             self._curr_memberlist = memberlist

     def extract_service_name(self, pod_name: str) -> Optional[str]:
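The span-attribute change is presumably needed because OpenTelemetry attribute values must be primitives or homogeneous sequences of primitives, so a list of `Member` objects has to be projected down to id strings first. A minimal illustration (assumes an active span and a `memberlist` of `Member` objects as sketched earlier):

```python
from opentelemetry import trace

span = trace.get_current_span()
# Valid: a homogeneous list of strings is an accepted attribute value.
span.set_attribute("new_memberlist", [m.id for m in memberlist])
# A raw list of Member objects would not be a valid attribute type.
```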
Second changed file:

@@ -14,28 +14,29 @@
 from chromadb.segment.distributed import SegmentDirectory
 from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
 from chromadb.telemetry.opentelemetry import (
     OpenTelemetryClient,
     OpenTelemetryGranularity,
     trace_method,
 )
-from chromadb.types import Collection, CollectionAndSegments, Operation, Segment, SegmentScope
+from chromadb.types import (
+    Collection,
+    Operation,
+    Segment,
+    SegmentScope,
+)


 class DistributedSegmentManager(SegmentManager):
     _sysdb: SysDB
     _system: System
     _opentelemetry_client: OpenTelemetryClient
     _instances: Dict[UUID, SegmentImplementation]
     _segment_directory: SegmentDirectory
     _lock: Lock
     # _segment_server_stubs: Dict[str, SegmentServerStub]  # grpc_url -> grpc stub

     def __init__(self, system: System):
         super().__init__(system)
         self._sysdb = self.require(SysDB)
         self._segment_directory = self.require(SegmentDirectory)
         self._system = system
         self._opentelemetry_client = system.require(OpenTelemetryClient)
         self._instances = {}
         self._lock = Lock()
@@ -77,6 +78,8 @@ def prepare_segments_for_new_collection(
     @override
     def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
+        # TODO: this should be a pass, delete_collection is expected to delete segments in
+        # distributed
         segments = self._sysdb.get_segments(collection=collection_id)
         return [s["id"] for s in segments]

Review comment: thanks, i will remove the get_segments() calls.
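Given the TODO and the reply above, the method would presumably stop consulting the sysdb entirely. A speculative sketch of that follow-up, not part of this diff (method body inside `DistributedSegmentManager`):

```python
# Speculative follow-up per the review thread; not part of this PR.
@override
def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
    # In distributed mode, delete_collection is expected to remove the
    # collection's segments itself, so there is nothing to enumerate here.
    return []
```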
Third changed file (new test, @@ -0,0 +1,74 @@):

from typing import Sequence
from chromadb.test.conftest import (
    reset,
    skip_if_not_cluster,
)
from chromadb.api import ClientAPI
from kubernetes import client as k8s_client, config
import time


@skip_if_not_cluster()
def test_reroute(
    client: ClientAPI,
) -> None:
    reset(client)
    collection = client.create_collection(
        name="test",
        metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128},
    )

    ids = [str(i) for i in range(10)]
    embeddings: list[Sequence[float]] = [
        [float(i), float(i), float(i)] for i in range(10)
    ]
    collection.add(ids=ids, embeddings=embeddings)
    collection.query(query_embeddings=[embeddings[0]])

    # Restart the query service using the k8s api, in order to trigger a
    # reroute of the query service
    config.load_kube_config()
    v1 = k8s_client.CoreV1Api()
    # Find all pods with the label "app=query-service"
    res = v1.list_namespaced_pod("chroma", label_selector="app=query-service")
    assert len(res.items) > 0
    items = res.items
    seen_ids = set()

    # Restart all the pods by deleting them
    for item in items:
        seen_ids.add(item.metadata.uid)
        name = item.metadata.name
        namespace = item.metadata.namespace
        v1.delete_namespaced_pod(name, namespace)

    # Wait until we have len(seen_ids) pods running with new UIDs
    timeout_secs = 10
    start_time = time.time()
    while True:
        res = v1.list_namespaced_pod("chroma", label_selector="app=query-service")
        items = res.items
        new_ids = set([item.metadata.uid for item in items])
        if len(new_ids) == len(seen_ids) and len(new_ids.intersection(seen_ids)) == 0:
            break
        if time.time() - start_time > timeout_secs:
            assert False, "Timed out waiting for new pods to start"
        time.sleep(1)

    # Wait for the query service to be ready, or timeout
    while True:
        res = v1.list_namespaced_pod("chroma", label_selector="app=query-service")
        items = res.items
        ready = True
        for item in items:
            if item.status.phase != "Running":
                ready = False
                break
        if ready:
            break
        if time.time() - start_time > timeout_secs:
            assert False, "Timed out waiting for new pods to be ready"
        time.sleep(1)

    time.sleep(1)
    collection.query(query_embeddings=[embeddings[0]])
Review comment on the test: todo: clean
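One way to act on that "todo: clean": both polling loops in the test could collapse into a single helper. The sketch below is hypothetical and not part of the PR; note it also restarts the timeout clock for each condition, which the inline version does not (the second loop reuses the first loop's `start_time`):

```python
import time
from typing import Callable


def wait_for(
    predicate: Callable[[], bool],
    timeout_secs: float = 10,
    interval: float = 1.0,
) -> None:
    # Poll until predicate() holds, failing the test on timeout.
    start = time.time()
    while not predicate():
        if time.time() - start > timeout_secs:
            raise AssertionError("Timed out waiting for condition")
        time.sleep(interval)


# Hypothetical usage in test_reroute; pods_replaced / all_pods_running would
# wrap the list_namespaced_pod checks from the two loops above:
# wait_for(lambda: pods_replaced(v1, seen_ids))
# wait_for(lambda: all_pods_running(v1))
```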