Skip to content

Commit

Permalink
Feature/middleware asgi (#50)
Browse files Browse the repository at this point in the history
feat: asgi middleware
  • Loading branch information
livioribeiro authored Jan 11, 2025
1 parent 2a7abf2 commit adfea1d
Show file tree
Hide file tree
Showing 24 changed files with 385 additions and 200 deletions.
98 changes: 64 additions & 34 deletions docs/middleware/overview.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
# Middleware

The middleware pipeline is configured with the `middleware` configuration property.
It must contain a list of functions that have the following signature:
It must contain a list of functions that receive the next app in the pipeline, the
settings object and the dependency injection container and must return a plain asgi
middleware instance:

```python
async def middleware(callnext, request): ...
def middleware_factory(app, settings, di): ...
```

Any asgi middleware can be used in the middleware pipeline. For instance, it is
possible to use the SessionMiddleware from starlette:


=== "application/middleware.py"

```python
from starlette.middleware.sessions import SessionMiddleware


def session_middleware(app, settings, di):
return SessionMiddleware(app, secret_key=settings.session.secret_key)
```

=== "configuration/settings.yaml"

```yaml
middleware:
- application.middleware:session_middleware

session:
secret_key: super_secret_key
```

## Usage

To demonstrate the middleware system, we will create a timing middleware that will
Expand All @@ -32,22 +58,23 @@ output to the console the time spent in the processing of the request:
from datetime import datetime

import structlog

from selva.di import service

logger = structlog.get_logger()


async def timing_middleware(callnext, request):
request_start = datetime.now()
await callnext(request) # (1)
request_end = datetime.now()

delta = request_end - request_start
logger.info("request duration", duration=str(delta))
def timing_middleware(app, settings, di):
async def inner(scope, receive, send):
request_start = datetime.now()
await app(scope, receive, send)
request_end = datetime.now()

delta = request_end - request_start
logger.info("request duration", duration=str(delta))
return inner
```

1. Invoke the middleware chain to process the request

=== "configuration/settings.yaml"

```yaml
Expand All @@ -57,41 +84,44 @@ output to the console the time spent in the processing of the request:

## Middleware dependencies

Middleware functions are called using the same machinery as handlers, and therefore
can have services injected. Our `timing_middleware`, for instance, could persist
Middleware functions can use the provided dependency injection container to get
services the middleware might need. We could rewrite the timing middleware to persist
the timings using a service instead of printing to the console:

=== "application/service.py"
=== "application/middleware.py"

```python
from datetime import datetime

from selva.di import service
from application.service import TimingService

@service
class TimingService:
async def save(start: datetime, end: datetime):
...

class TimingMiddleware:
def __init__(self, app, timing_service: TimingService):
self.timing_service = timing_service

async def __call__(self, scope, receive, send):
request_start = datetime.now()
await app(scope, receive, send)
request_end = datetime.now()

await self.timing_service.save(request_start, request_end)


async def timing_middleware(app, settings, di):
timing_service = await di.get(TimingService)
return TimingMiddleware(app, timing_service)
```

=== "application/middleware.py"
=== "application/service.py"

```python
from datetime import datetime
from typing import Annotated

from selva.di import service, Inject

from application.service import TimingService

from selva.di import service

async def __call__(
call_next, request,
timing_service: Annotated[TimingService, Inject]
):
request_start = datetime.now()
await call_next(request)
request_end = datetime.now()

await timing_service.save(request_start, request_end)
@service
class TimingService:
async def save(start: datetime, end: datetime):
...
```
Empty file.
30 changes: 30 additions & 0 deletions examples/celery/application/celery_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio

from celery import Celery

from selva.configuration.settings import get_settings
from selva.di import Container

from .service import Greeter

app = Celery("hello", broker="redis://localhost:6379/2")


_di_container = None


def di_container() -> Container:
global _di_container
if _di_container is None:
settings = get_settings()
_di_container = Container()
_di_container.scan(settings.application)
return _di_container


@app.task
def hello(name: str):
di = di_container()
greeter = asyncio.run(di.get(Greeter))
result = greeter.greet(f"{name}, from celery")
return result
14 changes: 14 additions & 0 deletions examples/celery/application/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import asyncio
from typing import Annotated as A

from asgikit.responses import respond_text

from selva.web import FromQuery, get

from .celery_tasks import hello


@get("")
async def index(request, name: A[str, FromQuery] = "World"):
await asyncio.to_thread(hello.delay, name)
await respond_text(request.response, "Hello, world!")
7 changes: 7 additions & 0 deletions examples/celery/application/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from selva.di import service


@service
class Greeter:
def greet(self, name: str) -> str:
return f"Hello, {name}!"
2 changes: 2 additions & 0 deletions examples/celery/configuration/settings.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
logging:
format: console
1 change: 1 addition & 0 deletions examples/celery/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
celery[redis]
2 changes: 2 additions & 0 deletions examples/exception_handler/configuration/settings.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
logging:
format: console
2 changes: 1 addition & 1 deletion examples/hello_world/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
FromPath,
FromQuery,
Json,
background,
get,
post,
background,
startup,
)

Expand Down
99 changes: 59 additions & 40 deletions examples/middleware/application/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,74 @@
logger = structlog.get_logger()


async def auth_middleware(callnext, request: Request):
if request.path == "/protected":
response = request.response
authn = request.headers.get("authorization")
if not authn:
response.header(
"WWW-Authenticate", 'Basic realm="localhost:8000/protected"'
)
await respond_status(response, HTTPStatus.UNAUTHORIZED)
return
def auth_middleware(app, settings, di):
auth_user = settings.auth.username
auth_pass = settings.auth.password

async def func(scope, receive, send):
if scope["path"] == "/protected":
request = Request(scope, receive, send)
authn = request.headers.get("authorization")

if not authn:
request.response.header(
"WWW-Authenticate", 'Basic realm="localhost:8000/protected"'
)
await respond_status(request.response, HTTPStatus.UNAUTHORIZED)
return

authn = authn.removeprefix("Basic")
user, password = base64.urlsafe_b64decode(authn).decode().split(":")
logger.info("user logged in", user=user, password=password)
request["user"] = user
authn = authn.removeprefix("Basic").strip()
user, password = base64.urlsafe_b64decode(authn).decode().split(":")
if not (user == auth_user and password == auth_pass):
await respond_status(request.response, HTTPStatus.UNAUTHORIZED)
return

await callnext(request)
logger.info("user logged in", user=user, password=password)
request["user"] = user

await app(scope, receive, send)

async def timing_middleware(callnext, request: Request):
if request.is_websocket:
await callnext(request)
return
return func

request_start = datetime.now()
await callnext(request)
request_end = datetime.now()

delta = request_end - request_start
logger.info("request duration", duration=str(delta))
def timing_middleware(app, settings, di):
async def func(scope, receive, send):
if scope["type"] == "websocket":
await app(scope, receive, send)
return

request_start = datetime.now()
await app(scope, receive, send)
request_end = datetime.now()

delta = request_end - request_start
logger.info("request duration", duration=str(delta))

async def logging_middleware(callnext, request: Request, user: User = None):
logger.info("user", user=user.name if user else "<no user>")
return func


def logging_middleware(app, settings, di):
async def func(scope, receive, send):
if user := scope.get("user"):
logger.info("user", user=user.name)

if scope["type"] == "websocket":
await app(scope, receive, send)
return

if request.is_websocket:
await callnext(request)
return
await app(scope, receive, send)

await callnext(request)
request = Request(scope, receive, send)
client = f"{request.client[0]}:{request.client[1]}"
request_line = f"{request.method} {request.path} HTTP/{request.http_version}"
status = request.response.status

client = f"{request.client[0]}:{request.client[1]}"
request_line = f"{request.method} {request.path} HTTP/{request.http_version}"
status = request.response.status
logger.info(
"request",
client=client,
request_line=request_line,
status=status.value,
status_phrase=status.phrase,
)

logger.info(
"request",
client=client,
request_line=request_line,
status=status.value,
status_phrase=status.phrase,
)
return func
4 changes: 4 additions & 0 deletions examples/middleware/configuration/settings.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
auth:
username: admin
password: "123"

middleware:
- application.middleware.auth_middleware
- application.middleware.timing_middleware
Expand Down
4 changes: 3 additions & 1 deletion examples/middleware_files/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
async def index(request):
uploads = Path("resources/uploads")
files = [item.relative_to(uploads) for item in uploads.iterdir() if item.is_file()]
result = "\n".join(f'<li><a href="/uploads/{item}">{item}</a></li>' for item in files)
result = "\n".join(
f'<li><a href="/uploads/{item}">{item}</a></li>' for item in files
)
result = f"<html><body><ul>{result}</ul></body></html>"
request.response.content_type = "text/html"
await respond_text(request.response, result)
1 change: 0 additions & 1 deletion examples/middleware_files/configuration/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ middleware:
- selva.web.middleware.files.uploaded_files_middleware
staticfiles:
mappings:
# /: index.html
favicon.ico: python.ico

logging:
Expand Down
2 changes: 1 addition & 1 deletion examples/websocket/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def handle_websocket(self, request: Request):

async def broadcast(self, message: str):
if message.lower() == "ping":
message = "Pong"
message = message.replace("i", "o").replace("I", "O")

for client, ws in self.clients.items():
try:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dynamic = ["version"]
requires-python = ">=3.11"

dependencies = [
"asgikit~=0.11.0",
"asgikit~=0.12.1",
"pydantic~=2.10.3",
"python-dotenv~=1.0.1",
"ruamel.yaml~=0.18.6",
Expand Down
2 changes: 1 addition & 1 deletion src/selva/web/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ruff: noqa: F401

from selva.web.converter import Json, Form
from selva.web.converter import Form, Json
from selva.web.converter.param_extractor import (
FromBody,
FromCookie,
Expand Down
Loading

0 comments on commit adfea1d

Please sign in to comment.