diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 33bd00054a7..a5ceacb69f0 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -1,7 +1,8 @@ from chromadb.api import ServerAPI from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.db.system import SysDB -from chromadb.quota import QuotaEnforcer +from chromadb.quota import QuotaEnforcer, Resource +from chromadb.rate_limiting import rate_limit from chromadb.segment import SegmentManager, MetadataReader, VectorReader from chromadb.telemetry.opentelemetry import ( add_attributes_to_current_span, @@ -347,6 +348,7 @@ def delete_collection( raise ValueError(f"Collection {name} does not exist.") @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION) + @rate_limit(subject="collection_id", resource=Resource.ADD_PER_MINUTE) @override def _add( self, @@ -469,6 +471,7 @@ def _upsert( return True @trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION) + @rate_limit(subject="collection_id", resource=Resource.GET_PER_MINUTE) @override def _get( self, @@ -647,6 +650,7 @@ def _count(self, collection_id: UUID) -> int: return metadata_segment.count() @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION) + @rate_limit(subject="collection_id", resource=Resource.QUERY_PER_MINUTE) @override def _query( self, diff --git a/chromadb/quota/__init__.py b/chromadb/quota/__init__.py index eebc8f70b00..d74a369462d 100644 --- a/chromadb/quota/__init__.py +++ b/chromadb/quota/__init__.py @@ -14,6 +14,9 @@ class Resource(Enum): METADATA_KEY_LENGTH = "METADATA_KEY_LENGTH" METADATA_VALUE_LENGTH = "METADATA_VALUE_LENGTH" DOCUMENT_SIZE = "DOCUMENT_SIZE" + ADD_PER_MINUTE = "ADD_PER_MINUTE" + QUERY_PER_MINUTE = "QUERY_PER_MINUTE" + GET_PER_MINUTE = "GET_PER_MINUTE" EMBEDDINGS_DIMENSION = "EMBEDDINGS_DIMENSION" diff --git a/chromadb/rate_limiting/__init__.py b/chromadb/rate_limiting/__init__.py index fb1a955ff12..8bb2b7040e0 100644 --- 
a/chromadb/rate_limiting/__init__.py +++ b/chromadb/rate_limiting/__init__.py @@ -29,17 +29,18 @@ def is_allowed(self, key: str, quota: int, point: Optional[int] = 1) -> bool: def rate_limit( subject: str, - resource: str + resource: Resource ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: args_name = inspect.getfullargspec(f)[0] if subject not in args_name: raise Exception(f'rate_limit decorator have unknown subject "{subject}", available {args_name}') key_index = args_name.index(subject) - @wraps(f) def wrapper(self, *args: Any, **kwargs: Dict[Any, Any]) -> Any: # If not rate limiting provider is present, just run and return the function. + if self._system.settings.chroma_rate_limiting_provider_impl is None: return f(self, *args, **kwargs) @@ -49,14 +50,16 @@ def wrapper(self, *args: Any, **kwargs: Dict[Any, Any]) -> Any: if len(args) < key_index: return f(self, *args, **kwargs) subject_value = args[key_index-1] - key_value = resource + "-" + subject_value + key_value = resource.value + "-" + str(subject_value) self._system.settings.chroma_rate_limiting_provider_impl quota_provider = self._system.require(QuotaProvider) rate_limiter = self._system.require(RateLimitingProvider) - quota = quota_provider.get_for_subject(resource=resource,subject=subject) + quota = quota_provider.get_for_subject(resource=resource,subject=str(subject_value)) + if quota is None: + return f(self, *args, **kwargs) is_allowed = rate_limiter.is_allowed(key_value, quota) if is_allowed is False: - raise RateLimitError(resource=resource, quota=quota) + raise RateLimitError(resource=resource.value, quota=quota) return f(self, *args, **kwargs) return wrapper diff --git a/chromadb/test/rate_limiting/test_rate_limiting.py b/chromadb/test/rate_limiting/test_rate_limiting.py index 9f7e8c677e7..6f8fb7b8c42 100644 --- a/chromadb/test/rate_limiting/test_rate_limiting.py +++ b/chromadb/test/rate_limiting/test_rate_limiting.py @@ 
-13,7 +13,7 @@ def __init__(self, system: System): super().__init__(system) self.system = system - @rate_limit(subject="bar", resource="FAKE_RESOURCE") + @rate_limit(subject="bar", resource=Resource.DOCUMENT_SIZE) def bench(self, foo: str, bar: str) -> str: return foo @@ -37,7 +37,7 @@ def rate_limiting_gym() -> QuotaEnforcer: def test_rate_limiting_should_raise(rate_limiting_gym: RateLimitingGym): with pytest.raises(Exception) as exc_info: rate_limiting_gym.bench("foo", "bar") - assert "FAKE_RESOURCE" in str(exc_info.value.resource) + assert Resource.DOCUMENT_SIZE.value in str(exc_info.value.resource) @patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject) @patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True)