Skip to content

Commit

Permalink
PTFE-672: api authentication (#322)
Browse files Browse the repository at this point in the history
  • Loading branch information
Abubakarr99 authored Aug 11, 2023
2 parents 260129b + 088a8b0 commit e3eb51c
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 11 deletions.
1 change: 0 additions & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ jobs:
env:
REDIS_OM_URL: redis://localhost:6379/0
GITHUB_BASE_URL: http://localhost:4010

steps:
- uses: actions/checkout@v3
- name: Boot compose services
Expand Down
40 changes: 40 additions & 0 deletions runner_manager/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from fastapi import Depends, HTTPException, Security, status
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.security import APIKeyHeader, APIKeyQuery
from starlette.types import Receive, Scope, Send

from runner_manager.dependencies import get_settings
from runner_manager.models.settings import Settings

api_key_query = APIKeyQuery(name="api-key", auto_error=False)
api_key_header = APIKeyHeader(name="x-api-key", auto_error=False)


class TrustedHostHealthRoutes(TrustedHostMiddleware):
"""A healthcheck endpoint that answers to GET requests on /_health"""

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""If the request is made on the path _health then execute the check on hosts"""
if scope["path"] == "/_health/":
print(scope)
await super().__call__(scope, receive, send)
else:
await self.app(scope, receive, send)


def get_api_key(
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
settings: Settings = Depends(get_settings),
) -> str:
"""Get the API key from either the query parameter or the header"""
if not settings.api_key:
return ""
if api_key_query in [settings.api_key.get_secret_value()]:
return api_key_query
if api_key_header in [settings.api_key.get_secret_value()]:
return api_key_header
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API Key",
)
19 changes: 10 additions & 9 deletions runner_manager/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import logging

from fastapi import FastAPI, Response
from fastapi import FastAPI

from runner_manager.dependencies import get_queue
from runner_manager.auth import TrustedHostHealthRoutes
from runner_manager.dependencies import get_queue, get_settings
from runner_manager.jobs.startup import startup
from runner_manager.routers import webhook
from runner_manager.routers import _health, private, public, webhook

log = logging.getLogger(__name__)

app = FastAPI()
settings = get_settings()


app.include_router(webhook.router)
app.include_router(_health.router)
app.include_router(private.router)
app.include_router(public.router)
app.add_middleware(TrustedHostHealthRoutes, allowed_hosts=settings.allowed_hosts)


@app.on_event("startup")
Expand All @@ -19,8 +25,3 @@ def startup_event():
job = queue.enqueue(startup)
status = job.get_status()
log.info(f"Startup job is {status}")


@app.get("/_health")
def health():
return Response(status_code=200)
7 changes: 6 additions & 1 deletion runner_manager/models/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence

import yaml
from pydantic import AnyHttpUrl, BaseSettings, RedisDsn, SecretStr
Expand Down Expand Up @@ -33,6 +33,11 @@ class Settings(BaseSettings):
name: Optional[str] = "runner-manager"
redis_om_url: Optional[RedisDsn] = None
github_base_url: Optional[AnyHttpUrl] = None
api_key: Optional[SecretStr] = None
allowed_hosts: Optional[Sequence[str]] = [
"localhost",
"testserver",
]
github_webhook_secret: Optional[SecretStr] = None
log_level: LogLevel = LogLevel.INFO

Expand Down
9 changes: 9 additions & 0 deletions runner_manager/routers/_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from fastapi import APIRouter, Response

router = APIRouter(prefix="/_health")


@router.get("/", status_code=200)
def healthcheck():
"""Healthcheck endpoint that answers to GET requests on /_health"""
return Response(status_code=200)
11 changes: 11 additions & 0 deletions runner_manager/routers/private.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import APIRouter, Security

from runner_manager.auth import get_api_key

router = APIRouter(prefix="/private")


@router.get("/")
def private(api_key: str = Security(get_api_key)):
"""A private endpoint that requires a valid API key to be provided."""
return f"Private Endpoint. API Key: {api_key}"
9 changes: 9 additions & 0 deletions runner_manager/routers/public.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from fastapi import APIRouter

router = APIRouter(prefix="/public")


@router.get("/")
def public():
"""A public endpoint that does not require any authentication."""
return "Public Endpoint"
33 changes: 33 additions & 0 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from functools import lru_cache

from runner_manager.dependencies import get_settings
from runner_manager.models.settings import Settings


@lru_cache
def settings_api_key():
return Settings(api_key="secret")


def test_public_endpoint(client):
response = client.get("/public")

assert response.status_code == 200


def test_private_endpoint_without_api_key(fastapp, client):
response = client.get("/private")
settings_api_key.cache_clear()
fastapp.dependency_overrides = {}
fastapp.dependency_overrides[get_settings] = settings_api_key
response = client.get("/private")
assert response.status_code == 401


def test_private_endpoint_with_valid_api_key(fastapp, client):
settings_api_key.cache_clear()
fastapp.dependency_overrides = {}
fastapp.dependency_overrides[get_settings] = settings_api_key
headers = {"x-api-key": "secret"}
response = client.get("/private", headers=headers)
assert response.status_code == 200

0 comments on commit e3eb51c

Please sign in to comment.