Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Make memberlist use ips for routing #3405

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions chromadb/execution/executor/distributed.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from typing import Dict, Optional

import grpc
from overrides import overrides

from chromadb.api.types import GetResult, Metadata, QueryResult
from chromadb.config import System
from chromadb.errors import VersionMismatchError
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.proto import convert

from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
Expand Down Expand Up @@ -170,6 +166,6 @@ def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub:
channel = grpc.insecure_channel(grpc_url)
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
channel = grpc.intercept_channel(channel, *interceptors)
self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel) # type: ignore[no-untyped-call]
self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)

return self._grpc_stub_pool[grpc_url]
9 changes: 8 additions & 1 deletion chromadb/segment/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List

from overrides import EnforceOverrides, overrides
Expand All @@ -22,7 +23,13 @@ def register_updated_segment_callback(
pass


Memberlist = List[str]
@dataclass
class Member:
id: str
ip: str


Memberlist = List[Member]


class MemberlistProvider(Component, EnforceOverrides):
Expand Down
38 changes: 32 additions & 6 deletions chromadb/segment/impl/distributed/segment_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from chromadb.config import System
from chromadb.segment.distributed import (
Member,
Memberlist,
MemberlistProvider,
SegmentDirectory,
Expand Down Expand Up @@ -35,7 +36,12 @@ class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):

def __init__(self, system: System):
super().__init__(system)
self._memberlist = ["a", "b", "c"]
# self._memberlist = ["a", "b", "c"]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: clean

self._memberlist = [
Member(id="a", ip="10.0.0.1"),
Member(id="b", ip="10.0.0.2"),
Member(id="c", ip="10.0.0.3"),
]

@override
def get_memberlist(self) -> Memberlist:
Expand Down Expand Up @@ -203,7 +209,12 @@ def _parse_response_memberlist(
) -> Memberlist:
if "members" not in api_response_spec:
return []
return [m["member_id"] for m in api_response_spec["members"]]
parsed = []
for m in api_response_spec["members"]:
id = m["member_id"]
ip = m["member_ip"] if "member_ip" in m else ""
parsed.append(Member(id=id, ip=ip))
return parsed

def _notify(self, memberlist: Memberlist) -> None:
for callback in self.callbacks:
Expand Down Expand Up @@ -245,11 +256,24 @@ def get_segment_endpoint(self, segment: Segment) -> str:
raise ValueError("Memberlist is not initialized")
# Query to the same collection should end up on the same endpoint
assignment = assign(
segment["collection"].hex, self._curr_memberlist, murmur3hasher, 1
segment["collection"].hex,
[m.id for m in self._curr_memberlist],
murmur3hasher,
1,
)[0]
service_name = self.extract_service_name(assignment)
assignment = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable
return assignment

# If the memberlist has an ip, use it, otherwise use the member id with the headless service
# this is for backwards compatibility with the old memberlist which only had ids
for member in self._curr_memberlist:
if member.id == assignment:
if member.ip is not None and member.ip != "":
print(f"[HAMMAD DEBUG] Using member ip: {member.ip}")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: remove

endpoint = f"{member.ip}:50051"
return endpoint

endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable
return endpoint

@override
def register_updated_segment_callback(
Expand All @@ -263,7 +287,9 @@ def register_updated_segment_callback(
)
def _update_memberlist(self, memberlist: Memberlist) -> None:
with self._curr_memberlist_mutex:
add_attributes_to_current_span({"new_memberlist": memberlist})
add_attributes_to_current_span(
{"new_memberlist": [m.id for m in memberlist]}
)
self._curr_memberlist = memberlist

def extract_service_name(self, pod_name: str) -> Optional[str]:
Expand Down
13 changes: 8 additions & 5 deletions chromadb/segment/impl/manager/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,29 @@
from chromadb.segment.distributed import SegmentDirectory
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import Collection, CollectionAndSegments, Operation, Segment, SegmentScope
from chromadb.types import (
Collection,
Operation,
Segment,
SegmentScope,
)


class DistributedSegmentManager(SegmentManager):
_sysdb: SysDB
_system: System
_opentelemetry_client: OpenTelemetryClient
_instances: Dict[UUID, SegmentImplementation]
_segment_directory: SegmentDirectory
_lock: Lock
# _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub

def __init__(self, system: System):
super().__init__(system)
self._sysdb = self.require(SysDB)
self._segment_directory = self.require(SegmentDirectory)
self._system = system
self._opentelemetry_client = system.require(OpenTelemetryClient)
self._instances = {}
self._lock = Lock()

Expand Down Expand Up @@ -77,6 +78,8 @@ def prepare_segments_for_new_collection(

@override
def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
# TODO: this should be a pass, delete_collection is expected to delete segments in
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, i will remove the get_segments() calls.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, i will remove the get_segments() calls.

https://github.com/chroma-core/hosted-chroma/issues/814

# distributed
segments = self._sysdb.get_segments(collection=collection_id)
return [s["id"] for s in segments]

Expand Down
74 changes: 74 additions & 0 deletions chromadb/test/distributed/test_reroute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Sequence
from chromadb.test.conftest import (
reset,
skip_if_not_cluster,
)
from chromadb.api import ClientAPI
from kubernetes import client as k8s_client, config
import time


@skip_if_not_cluster()
def test_reroute(
client: ClientAPI,
) -> None:
reset(client)
collection = client.create_collection(
name="test",
metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128},
)

ids = [str(i) for i in range(10)]
embeddings: list[Sequence[float]] = [
[float(i), float(i), float(i)] for i in range(10)
]
collection.add(ids=ids, embeddings=embeddings)
collection.query(query_embeddings=[embeddings[0]])

# Restart the query service using k8s api, in order to trigger a reroute
# of the query service
config.load_kube_config()
v1 = k8s_client.CoreV1Api()
# Find all pods with the label "app=query"
res = v1.list_namespaced_pod("chroma", label_selector="app=query-service")
assert len(res.items) > 0
items = res.items
seen_ids = set()

# Restart all the pods by deleting them
for item in items:
seen_ids.add(item.metadata.uid)
name = item.metadata.name
namespace = item.metadata.namespace
v1.delete_namespaced_pod(name, namespace)

# Wait until we have len(seen_ids) pods running with new UIDs
timeout_secs = 10
start_time = time.time()
while True:
res = v1.list_namespaced_pod("chroma", label_selector="app=query-service")
items = res.items
new_ids = set([item.metadata.uid for item in items])
if len(new_ids) == len(seen_ids) and len(new_ids.intersection(seen_ids)) == 0:
break
if time.time() - start_time > timeout_secs:
assert False, "Timed out waiting for new pods to start"
time.sleep(1)

# Wait for the query service to be ready, or timeout
while True:
res = v1.list_namespaced_pod("chroma", label_selector="app=query-service")
items = res.items
ready = True
for item in items:
if item.status.phase != "Running":
ready = False
break
if ready:
break
if time.time() - start_time > timeout_secs:
assert False, "Timed out waiting for new pods to be ready"
time.sleep(1)

time.sleep(1)
collection.query(query_embeddings=[embeddings[0]])
13 changes: 8 additions & 5 deletions chromadb/test/segment/distributed/test_memberlist_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Tests the CustomResourceMemberlist provider
from dataclasses import asdict
import threading
from chromadb.test.conftest import skip_if_not_cluster
from kubernetes import client, config
from chromadb.config import System, Settings
from chromadb.segment.distributed import Memberlist
from chromadb.segment.distributed import Memberlist, Member
from chromadb.segment.impl.distributed.segment_directory import (
CustomResourceMemberlistProvider,
KUBERNETES_GROUP,
Expand All @@ -17,12 +18,12 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe
config.load_config()
api_instance = client.CustomObjectsApi()

members = [{"member_id": f"test-{i}"} for i in range(1, n + 1)]
members = [Member(id=f"test-{i}", ip=f"10.0.0.{i}") for i in range(1, n + 1)]

body = {
"kind": "MemberList",
"metadata": {"name": memberlist_name},
"spec": {"members": members},
"spec": {"members": [{"member_id": m.id, "member_ip": m.ip} for m in members]},
}

_ = api_instance.patch_namespaced_custom_object(
Expand All @@ -34,11 +35,13 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe
body=body,
)

return [m["member_id"] for m in members]
return members


def compare_memberlists(m1: Memberlist, m2: Memberlist) -> bool:
return sorted(m1) == sorted(m2)
m1_as_dict = sorted([asdict(m) for m in m1], key=lambda x: x["id"])
m2_as_dict = sorted([asdict(m) for m in m2], key=lambda x: x["id"])
return m1_as_dict == m2_as_dict


@skip_if_not_cluster()
Expand Down
10 changes: 10 additions & 0 deletions go/pkg/memberlist_manager/memberlist_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ func memberlistSame(oldMemberlist Memberlist, newMemberlist Memberlist) bool {
if len(oldMemberlist) != len(newMemberlist) {
return false
}
oldMemberlistIps := make(map[string]string)
for _, member := range oldMemberlist {
oldMemberlistIps[member.id] = member.ip
}
for _, member := range newMemberlist {
if ip, ok := oldMemberlistIps[member.id]; !ok || ip != member.ip {
return false
}
}

// use a map to check if the new memberlist contains all the old members
newMemberlistMap := make(map[string]bool)
for _, member := range newMemberlist {
Expand Down
26 changes: 13 additions & 13 deletions go/pkg/memberlist_manager/memberlist_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestNodeWatcher(t *testing.T) {
t.Fatalf("Error getting node status: %v", err)
}

return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}})
return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}})
}, 10, 1*time.Second)
if !ok {
t.Fatalf("Node status did not update after adding a pod")
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestNodeWatcher(t *testing.T) {
if err != nil {
t.Fatalf("Error getting node status: %v", err)
}
return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}})
return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}})
}, 10, 1*time.Second)
if !ok {
t.Fatalf("Node status did not update after adding a not ready pod")
Expand All @@ -108,13 +108,13 @@ func TestMemberlistStore(t *testing.T) {
assert.Equal(t, Memberlist{}, memberlist)

// Add a member to the memberlist
memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}, "0")
memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}, "0")
memberlist, _, err = memberlist_store.GetMemberlist(context.Background())
if err != nil {
t.Fatalf("Error getting memberlist: %v", err)
}
// assert the memberlist has the correct members
if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}) {
if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}) {
t.Fatalf("Memberlist did not update after adding a member")
}
}
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestMemberlistManager(t *testing.T) {

// Get the memberlist
ok := retryUntilCondition(func() bool {
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}})
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}})
}, 30, 1*time.Second)
if !ok {
t.Fatalf("Memberlist did not update after adding a pod")
Expand All @@ -195,7 +195,7 @@ func TestMemberlistManager(t *testing.T) {

// Get the memberlist
ok = retryUntilCondition(func() bool {
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}})
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}, Member{id: "test-pod-1", ip: "10.0.0.50"}})
}, 30, 1*time.Second)
if !ok {
t.Fatalf("Memberlist did not update after adding a pod")
Expand All @@ -206,7 +206,7 @@ func TestMemberlistManager(t *testing.T) {

// Get the memberlist
ok = retryUntilCondition(func() bool {
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1"}})
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1", ip: "10.0.0.50"}})
}, 30, 1*time.Second)
if !ok {
t.Fatalf("Memberlist did not update after deleting a pod")
Expand All @@ -217,23 +217,23 @@ func TestMemberlistSame(t *testing.T) {
memberlist := Memberlist{}
assert.True(t, memberlistSame(memberlist, memberlist))

newMemberlist := Memberlist{Member{id: "test-pod-0"}}
newMemberlist := Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}}
assert.False(t, memberlistSame(memberlist, newMemberlist))
assert.False(t, memberlistSame(newMemberlist, memberlist))
assert.True(t, memberlistSame(newMemberlist, newMemberlist))

memberlist = Memberlist{Member{id: "test-pod-1"}}
memberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}}
assert.False(t, memberlistSame(newMemberlist, memberlist))
assert.False(t, memberlistSame(memberlist, newMemberlist))
assert.True(t, memberlistSame(memberlist, memberlist))

memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
newMemberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
newMemberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
assert.True(t, memberlistSame(memberlist, newMemberlist))
assert.True(t, memberlistSame(newMemberlist, memberlist))

memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
newMemberlist = Memberlist{Member{id: "test-pod-1"}, Member{id: "test-pod-0"}}
memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
newMemberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}, Member{id: "test-pod-0", ip: "10.0.0.1"}}
assert.True(t, memberlistSame(memberlist, newMemberlist))
assert.True(t, memberlistSame(newMemberlist, memberlist))
}
Expand Down
Loading
Loading