diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index cf6b7247..2626ce77 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -27,7 +27,6 @@ jobs: env: REDIS_OM_URL: redis://localhost:6379/0 GITHUB_BASE_URL: http://localhost:4010 - steps: - uses: actions/checkout@v3 - name: Boot compose services diff --git a/runner_manager/auth.py b/runner_manager/auth.py new file mode 100644 index 00000000..26e4ed7f --- /dev/null +++ b/runner_manager/auth.py @@ -0,0 +1,40 @@ +from fastapi import Depends, HTTPException, Security, status +from fastapi.middleware.trustedhost import TrustedHostMiddleware +from fastapi.security import APIKeyHeader, APIKeyQuery +from starlette.types import Receive, Scope, Send + +from runner_manager.dependencies import get_settings +from runner_manager.models.settings import Settings + +api_key_query = APIKeyQuery(name="api-key", auto_error=False) +api_key_header = APIKeyHeader(name="x-api-key", auto_error=False) + + +class TrustedHostHealthRoutes(TrustedHostMiddleware): + """A healthcheck endpoint that answers to GET requests on /_health""" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """If the request is made on the path _health then execute the check on hosts""" + if scope["path"] == "/_health/": + print(scope) + await super().__call__(scope, receive, send) + else: + await self.app(scope, receive, send) + + +def get_api_key( + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), + settings: Settings = Depends(get_settings), +) -> str: + """Get the API key from either the query parameter or the header""" + if not settings.api_key: + return "" + if api_key_query in [settings.api_key.get_secret_value()]: + return api_key_query + if api_key_header in [settings.api_key.get_secret_value()]: + return api_key_header + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API Key", + ) diff --git a/runner_manager/main.py b/runner_manager/main.py index 0a7d95cd..cd85159b 100644 --- a/runner_manager/main.py +++ b/runner_manager/main.py @@ -1,16 +1,22 @@ import logging -from fastapi import FastAPI, Response +from fastapi import FastAPI -from runner_manager.dependencies import get_queue +from runner_manager.auth import TrustedHostHealthRoutes +from runner_manager.dependencies import get_queue, get_settings from runner_manager.jobs.startup import startup -from runner_manager.routers import webhook +from runner_manager.routers import _health, private, public, webhook log = logging.getLogger(__name__) - app = FastAPI() +settings = get_settings() + app.include_router(webhook.router) +app.include_router(_health.router) +app.include_router(private.router) +app.include_router(public.router) +app.add_middleware(TrustedHostHealthRoutes, allowed_hosts=settings.allowed_hosts) @app.on_event("startup") @@ -19,8 +25,3 @@ def startup_event(): job = queue.enqueue(startup) status = job.get_status() log.info(f"Startup job is {status}") - - -@app.get("/_health") -def health(): - return Response(status_code=200) diff --git a/runner_manager/models/settings.py b/runner_manager/models/settings.py index 23118add..a63503cf 100644 --- a/runner_manager/models/settings.py +++ b/runner_manager/models/settings.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Sequence import yaml from pydantic import AnyHttpUrl, BaseSettings, RedisDsn, SecretStr @@ -33,6 +33,11 @@ class Settings(BaseSettings): name: Optional[str] = "runner-manager" redis_om_url: Optional[RedisDsn] = None github_base_url: Optional[AnyHttpUrl] = None + api_key: Optional[SecretStr] = None + allowed_hosts: Optional[Sequence[str]] = [ + "localhost", + "testserver", + ] github_webhook_secret: Optional[SecretStr] = None log_level: LogLevel = LogLevel.INFO diff --git a/runner_manager/routers/_health.py b/runner_manager/routers/_health.py new file mode 100644 index 00000000..5927db75 --- /dev/null +++ b/runner_manager/routers/_health.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter, Response + +router = APIRouter(prefix="/_health") + + +@router.get("/", status_code=200) +def healthcheck(): + """Healthcheck endpoint that answers to GET requests on /_health""" + return Response(status_code=200) diff --git a/runner_manager/routers/private.py b/runner_manager/routers/private.py new file mode 100644 index 00000000..4d62c3cf --- /dev/null +++ b/runner_manager/routers/private.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter, Security + +from runner_manager.auth import get_api_key + +router = APIRouter(prefix="/private") + + +@router.get("/") +def private(api_key: str = Security(get_api_key)): + """A private endpoint that requires a valid API key to be provided.""" + return f"Private Endpoint. API Key: {api_key}" diff --git a/runner_manager/routers/public.py b/runner_manager/routers/public.py new file mode 100644 index 00000000..e110d9c7 --- /dev/null +++ b/runner_manager/routers/public.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter + +router = APIRouter(prefix="/public") + + +@router.get("/") +def public(): + """A public endpoint that does not require any authentication.""" + return "Public Endpoint" diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py new file mode 100644 index 00000000..d472896a --- /dev/null +++ b/tests/api/test_auth.py @@ -0,0 +1,33 @@ +from functools import lru_cache + +from runner_manager.dependencies import get_settings +from runner_manager.models.settings import Settings + + +@lru_cache +def settings_api_key(): + return Settings(api_key="secret") + + +def test_public_endpoint(client): + response = client.get("/public") + + assert response.status_code == 200 + + +def test_private_endpoint_without_api_key(fastapp, client): + response = client.get("/private") + settings_api_key.cache_clear() + fastapp.dependency_overrides = {} + fastapp.dependency_overrides[get_settings] = settings_api_key + response = client.get("/private") + assert response.status_code == 401 + + +def test_private_endpoint_with_valid_api_key(fastapp, client): + settings_api_key.cache_clear() + fastapp.dependency_overrides = {} + fastapp.dependency_overrides[get_settings] = settings_api_key + headers = {"x-api-key": "secret"} + response = client.get("/private", headers=headers) + assert response.status_code == 200