Skip to content

Commit

Permalink
Add rate limiting to /register endpoint and update documented examp…
Browse files Browse the repository at this point in the history
…le config

Refactor rate limiting to consolidate code
Refactor rate limit buckets to be semantically consistent
  • Loading branch information
NeonDaniel committed Nov 8, 2024
1 parent d4efcdf commit bd3c7a2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ hana:
mq_default_timeout: 10
access_token_ttl: 86400 # 1 day
refresh_token_ttl: 604800 # 1 week
requests_per_minute: 60
requests_per_minute: 60 # All other requests (auth, registration, etc) also count towards this limit
auth_requests_per_minute: 6 # This counts valid and invalid requests from an IP address
registration_requests_per_hour: 4 # This is low to prevent malicious user creation that will pollute the database
access_token_secret: a800445648142061fc238d1f84e96200da87f4f9fa7835cac90db8b4391b117b
refresh_token_secret: 833d369ac73d883123743a44b4a7fe21203cffc956f4c8fec712e71aafa8e1aa
jwt_issuer: neon.ai # Used in the `iss` field of generated JWT tokens.
Expand Down
6 changes: 4 additions & 2 deletions neon_hana/app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,7 @@ async def check_refresh(request: RefreshRequest) -> AuthenticationResponse:


@auth_route.post("/register")
async def register_user(request: RegistrationRequest) -> User:
return client_manager.check_registration_request(**dict(request))
async def register_user(register_request: RegistrationRequest,
request: Request) -> User:
return client_manager.check_registration_request(**dict(register_request),
origin_ip=request.client.host)
45 changes: 29 additions & 16 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, config: dict,
self._refresh_secret = config.get("refresh_token_secret")
self._rpm = config.get("requests_per_minute", 60)
self._auth_rpm = config.get("auth_requests_per_minute", 6)
self._register_rph = config.get("registration_requests_per_hour", 4)
self._disable_auth = config.get("disable_auth")
self._max_streaming_clients = config.get("max_streaming_clients")
self._jwt_algo = "HS256"
Expand Down Expand Up @@ -156,11 +157,31 @@ def disconnect_stream(self):
with self._stream_check_lock:
self._connected_streams -= 1

def _consume_rate_limit_token(self, ratelimit_id: str):
if not self.rate_limiter.consume(ratelimit_id):
bucket = list(self.rate_limiter.get_all_buckets(ratelimit_id).
values())[0]
replenish_time = bucket.last_replenished + bucket.replenish_time
wait_time = round(replenish_time - time())
ip_addr, request_cls = ratelimit_id.split('-', 1)
raise HTTPException(status_code=429,
detail=f"Too many {request_cls} requests from: "
f"{ip_addr}. Wait {wait_time}s.")

def check_registration_request(self, username: str, password: str,
user_config: NeonUserConfig) -> User:
user_config: NeonUserConfig,
origin_ip: str = "127.0.0.1") -> User:
"""
Handle a request to register a new user.
"""

ratelimit_id = f"{origin_ip}-register"
if not self.rate_limiter.get_all_buckets(ratelimit_id):
self.rate_limiter.add_bucket(ratelimit_id,
TokenBucket(replenish_time=3600,
max_tokens=self._register_rph))
self._consume_rate_limit_token(ratelimit_id)

new_user = User(username=username, password_hash=password,
neon=user_config, permissions=_DEFAULT_USER_PERMISSIONS)
if self._mq_connector:
Expand Down Expand Up @@ -190,19 +211,12 @@ def check_auth_request(self, client_id: str, username: str,
# print(f"Using cached client: {self.authorized_clients[client_id]}")
# return self.authorized_clients[client_id]

ratelimit_id = f"auth{origin_ip}"
ratelimit_id = f"{origin_ip}-auth"
if not self.rate_limiter.get_all_buckets(ratelimit_id):
self.rate_limiter.add_bucket(ratelimit_id,
TokenBucket(replenish_time=60,
max_tokens=self._auth_rpm))
if not self.rate_limiter.consume(ratelimit_id):
bucket = list(self.rate_limiter.get_all_buckets(ratelimit_id).
values())[0]
replenish_time = bucket.last_replenished + bucket.replenish_time
wait_time = round(replenish_time - time())
raise HTTPException(status_code=429,
detail=f"Too many auth requests from: "
f"{origin_ip}. Wait {wait_time}s.")
self._consume_rate_limit_token(ratelimit_id)

if self._mq_connector is None:
# Auth is disabled; every auth request gets a successful response
Expand Down Expand Up @@ -321,14 +335,13 @@ def get_client_id(self, token: str) -> str:
return auth.client_id

def validate_auth(self, token: str, origin_ip: str) -> bool:
if not self.rate_limiter.get_all_buckets(origin_ip):
self.rate_limiter.add_bucket(origin_ip,
ratelimit_id = f"{origin_ip}-total"
if not self.rate_limiter.get_all_buckets(ratelimit_id):
self.rate_limiter.add_bucket(ratelimit_id,
TokenBucket(replenish_time=60,
max_tokens=self._rpm))
if not self.rate_limiter.consume(origin_ip) and self._rpm > 0:
raise HTTPException(status_code=429,
detail=f"Requests limited to {self._rpm}/min "
f"per client connection")
if self._rpm > 0:
self._consume_rate_limit_token(ratelimit_id)

if self._disable_auth:
return True
Expand Down

0 comments on commit bd3c7a2

Please sign in to comment.