-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
118 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from fastapi import Depends, HTTPException, Security, status | ||
from fastapi.middleware.trustedhost import TrustedHostMiddleware | ||
from fastapi.security import APIKeyHeader, APIKeyQuery | ||
from starlette.types import Receive, Scope, Send | ||
|
||
from runner_manager.dependencies import get_settings | ||
from runner_manager.models.settings import Settings | ||
|
||
api_key_query = APIKeyQuery(name="api-key", auto_error=False) | ||
api_key_header = APIKeyHeader(name="x-api-key", auto_error=False) | ||
|
||
|
||
class TrustedHostHealthRoutes(TrustedHostMiddleware): | ||
"""A healthcheck endpoint that answers to GET requests on /_health""" | ||
|
||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
"""If the request is made on the path _health then execute the check on hosts""" | ||
if scope["path"] == "/_health/": | ||
print(scope) | ||
await super().__call__(scope, receive, send) | ||
else: | ||
await self.app(scope, receive, send) | ||
|
||
|
||
def get_api_key( | ||
api_key_query: str = Security(api_key_query), | ||
api_key_header: str = Security(api_key_header), | ||
settings: Settings = Depends(get_settings), | ||
) -> str: | ||
"""Get the API key from either the query parameter or the header""" | ||
if not settings.api_key: | ||
return "" | ||
if api_key_query in [settings.api_key.get_secret_value()]: | ||
return api_key_query | ||
if api_key_header in [settings.api_key.get_secret_value()]: | ||
return api_key_header | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Invalid API Key", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from fastapi import APIRouter, Response | ||
|
||
router = APIRouter(prefix="/_health") | ||
|
||
|
||
@router.get("/", status_code=200) | ||
def healthcheck(): | ||
"""Healthcheck endpoint that answers to GET requests on /_health""" | ||
return Response(status_code=200) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from fastapi import APIRouter, Security | ||
|
||
from runner_manager.auth import get_api_key | ||
|
||
router = APIRouter(prefix="/private") | ||
|
||
|
||
@router.get("/") | ||
def private(api_key: str = Security(get_api_key)): | ||
"""A private endpoint that requires a valid API key to be provided.""" | ||
return f"Private Endpoint. API Key: {api_key}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from fastapi import APIRouter | ||
|
||
router = APIRouter(prefix="/public") | ||
|
||
|
||
@router.get("/") | ||
def public(): | ||
"""A public endpoint that does not require any authentication.""" | ||
return "Public Endpoint" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from functools import lru_cache | ||
|
||
from runner_manager.dependencies import get_settings | ||
from runner_manager.models.settings import Settings | ||
|
||
|
||
@lru_cache | ||
def settings_api_key(): | ||
return Settings(api_key="secret") | ||
|
||
|
||
def test_public_endpoint(client): | ||
response = client.get("/public") | ||
|
||
assert response.status_code == 200 | ||
|
||
|
||
def test_private_endpoint_without_api_key(fastapp, client): | ||
response = client.get("/private") | ||
settings_api_key.cache_clear() | ||
fastapp.dependency_overrides = {} | ||
fastapp.dependency_overrides[get_settings] = settings_api_key | ||
response = client.get("/private") | ||
assert response.status_code == 401 | ||
|
||
|
||
def test_private_endpoint_with_valid_api_key(fastapp, client): | ||
settings_api_key.cache_clear() | ||
fastapp.dependency_overrides = {} | ||
fastapp.dependency_overrides[get_settings] = settings_api_key | ||
headers = {"x-api-key": "secret"} | ||
response = client.get("/private", headers=headers) | ||
assert response.status_code == 200 |