diff --git a/runner_manager/auth.py b/runner_manager/auth.py index eceb6622..083b0279 100644 --- a/runner_manager/auth.py +++ b/runner_manager/auth.py @@ -1,12 +1,15 @@ - +from fastapi import Depends, HTTPException, Security, status from fastapi.middleware.trustedhost import TrustedHostMiddleware +from fastapi.security import APIKeyHeader, APIKeyQuery from starlette.types import Receive, Scope, Send +from runner_manager.dependencies import get_settings +from runner_manager.models.settings import Settings + class TrustedHostHealthRoutes(TrustedHostMiddleware): """A healthcheck endpoint that answers to GET requests on /_health""" - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """If the request is made on the path _health then execute the check on hosts""" if scope["path"] == "/_health/": @@ -14,3 +17,26 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await super().__call__(scope, receive, send) else: await self.app(scope, receive, send) + + +api_key_query = APIKeyQuery(name="api-key", auto_error=False) +api_key_header = APIKeyHeader(name="x-api-key", auto_error=False) +settings = get_settings() + + +def get_api_key( + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), + settings: Settings = Depends(get_settings), +) -> str: + """Get the API key from either the query parameter or the header""" + if not settings.api_key: + return "" + if api_key_query in [settings.api_key.get_secret_value()]: + return api_key_query + if api_key_header in [settings.api_key.get_secret_value()]: + return api_key_header + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API Key", + ) diff --git a/runner_manager/main.py b/runner_manager/main.py index 8f9c512b..419126f2 100644 --- a/runner_manager/main.py +++ b/runner_manager/main.py @@ -1,46 +1,20 @@ import logging -from fastapi import Depends, FastAPI, HTTPException, Security, status -from fastapi.security import APIKeyHeader, APIKeyQuery -from runner_manager.auth import TrustedHostHealthRoutes +from fastapi import FastAPI, Security +from runner_manager.auth import TrustedHostHealthRoutes, get_api_key from runner_manager.dependencies import get_queue, get_settings from runner_manager.jobs.startup import startup -from runner_manager.models.settings import Settings -from runner_manager.routers import webhook, _health +from runner_manager.routers import _health, webhook log = logging.getLogger(__name__) - app = FastAPI() -api_key_query = APIKeyQuery(name="api-key", auto_error=False) -api_key_header = APIKeyHeader(name="x-api-key", auto_error=False) settings = get_settings() -def get_api_key( - api_key_query: str = Security(api_key_query), - api_key_header: str = Security(api_key_header), - settings: Settings = Depends(get_settings), -) -> str: - """Get the API key from either the query parameter or the header""" - if not settings.api_key: - return "" - if api_key_query in [settings.api_key.get_secret_value()]: - return api_key_query - if api_key_header in [settings.api_key.get_secret_value()]: - return api_key_header - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API Key", - ) - - app.include_router(webhook.router) app.include_router(_health.router) -app.add_middleware( - TrustedHostHealthRoutes, - allowed_hosts=settings.allowed_hosts -) +app.add_middleware(TrustedHostHealthRoutes, allowed_hosts=settings.allowed_hosts) @app.on_event("startup") @@ -51,8 +25,6 @@ def startup_event(): log.info(f"Startup job is {status}") - - @app.get("/public") def public(): """A public endpoint that does not require any authentication.""" diff --git a/runner_manager/routers/_health.py b/runner_manager/routers/_health.py index 52a50bfb..4add2c54 100644 --- a/runner_manager/routers/_health.py +++ b/runner_manager/routers/_health.py @@ -1,4 +1,3 @@ - from fastapi import APIRouter, Response router = APIRouter(prefix="/_health") diff --git a/tests/api/tests_auth.py b/tests/api/tests_auth.py index b68d19ee..d472896a 100644 --- a/tests/api/tests_auth.py +++ b/tests/api/tests_auth.py @@ -4,21 +4,24 @@ from runner_manager.models.settings import Settings +@lru_cache +def settings_api_key(): + return Settings(api_key="secret") + + def test_public_endpoint(client): response = client.get("/public") assert response.status_code == 200 -def test_private_endpoint_without_api_key(client): +def test_private_endpoint_without_api_key(fastapp, client): response = client.get("/private") - # for now we are not using api key - assert response.status_code == 200 - - -@lru_cache -def settings_api_key(): - return Settings(api_key="secret") + settings_api_key.cache_clear() + fastapp.dependency_overrides = {} + fastapp.dependency_overrides[get_settings] = settings_api_key + response = client.get("/private") + assert response.status_code == 401 def test_private_endpoint_with_valid_api_key(fastapp, client):