Add aiohttp client sharing (#3024)
* Run the app as ASGI

* Fix static file handling on localhost

Note: this does not apply to production, because we serve static files from Nginx there: those static file requests never reach the Django application, which is not configured to serve static files at all.

This change uses the ASGI static file handler that the Django `runserver` management command uses, which correctly handles streaming responses. The only consequence of not doing this is that warnings appear locally and, if for some reason local interactions bypass the browser's static file cache, you could get a memory leak. Again, that only applies to local environments: Python code never interacts with, considers, or is configured for static files in production, so this is not an issue there.

The correct behaviour for production, which you can verify by setting ENVIRONMENT to something other than `local` in `api/.env`, is to 404 on static files.

* Add aiohttp client sharing

* Add aiohttp session manager tests

* Use clearer middleware pattern and consistent app export name

* Fix default environment in api env template

* Switch to django-asgi-lifespan instead of a custom ASGI lifespan implementation
sarayourfriend authored Nov 22, 2023
1 parent b91cdc3 commit 1de2d21
Showing 11 changed files with 725 additions and 621 deletions.
1 change: 1 addition & 0 deletions api/Pipfile
@@ -23,6 +23,7 @@ aiohttp = "~=3.8"
 aws-requests-auth = "~=0.4"
 deepdiff = "~=6.4"
 Django = "~=4.2"
+django-asgi-lifespan = "~=0.1"
 django-cors-headers = "~=4.2"
 django-log-request-id = "~=2.0"
 django-oauth-toolkit = "~=2.3"
1,053 changes: 539 additions & 514 deletions api/Pipfile.lock

Large diffs are not rendered by default.

79 changes: 79 additions & 0 deletions api/api/utils/aiohttp.py
@@ -0,0 +1,79 @@
import asyncio
import logging
import weakref

import aiohttp
import sentry_sdk
from django_asgi_lifespan.signals import asgi_shutdown


logger = logging.getLogger(__name__)


_SESSIONS: weakref.WeakKeyDictionary[
    asyncio.AbstractEventLoop, aiohttp.ClientSession
] = weakref.WeakKeyDictionary()

_LOCKS: weakref.WeakKeyDictionary[
    asyncio.AbstractEventLoop, asyncio.Lock
] = weakref.WeakKeyDictionary()


@asgi_shutdown.connect
async def _close_sessions(sender, **kwargs):
    logger.debug("Closing aiohttp sessions on application shutdown")

    closed_sessions = 0

    while _SESSIONS:
        loop, session = _SESSIONS.popitem()
        try:
            await session.close()
            closed_sessions += 1
        except BaseException as exc:
            logger.error(exc)
            sentry_sdk.capture_exception(exc)

    logger.debug("Successfully closed %s session(s)", closed_sessions)


async def get_aiohttp_session() -> aiohttp.ClientSession:
    """
    Safely retrieve a shared aiohttp session for the current event loop.

    If the loop already has an aiohttp session associated, it will be reused.
    If the loop has not yet had an aiohttp session created for it, a new one
    will be created and returned.

    While the main application will always run in the same loop, and while
    that covers 99% of our use cases, it is still possible for `async_to_sync`
    to cause a new loop to be created if, for example, `force_new_loop` is
    passed. In order to prevent surprises should that ever be the case, this
    function assumes that it's possible for multiple loops to be present in
    the lifetime of the application and therefore we need to verify that each
    loop gets its own session.
    """

    loop = asyncio.get_running_loop()

    if loop not in _LOCKS:
        _LOCKS[loop] = asyncio.Lock()

    async with _LOCKS[loop]:
        if loop not in _SESSIONS:
            create_session = True
            msg = "No session for loop. Creating new session."
        elif _SESSIONS[loop].closed:
            create_session = True
            msg = "Loop's previous session closed. Creating new session."
        else:
            create_session = False
            msg = "Reusing existing session for loop."

        logger.info(msg)

        if create_session:
            session = aiohttp.ClientSession()
            _SESSIONS[loop] = session

        return _SESSIONS[loop]
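For callers, the shared session replaces ad-hoc `aiohttp.ClientSession()` construction. A minimal usage sketch (the `fetch_bytes` helper is hypothetical, not part of this commit):

from api.utils.aiohttp import get_aiohttp_session


async def fetch_bytes(url: str) -> bytes:
    # Reuse the session bound to the running event loop; it is closed
    # centrally by the asgi_shutdown handler above, not by the caller.
    session = await get_aiohttp_session()
    async with session.get(url) as response:
        return await response.read()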
21 changes: 13 additions & 8 deletions api/api/utils/check_dead_links/__init__.py
@@ -10,6 +10,7 @@
 from decouple import config
 from elasticsearch_dsl.response import Hit

+from api.utils.aiohttp import get_aiohttp_session
 from api.utils.check_dead_links.provider_status_mappings import provider_status_mappings
 from api.utils.dead_link_mask import get_query_mask, save_query_mask

@@ -32,10 +33,16 @@ def _get_expiry(status, default):
     return config(f"LINK_VALIDATION_CACHE_EXPIRY__{status}", default=default, cast=int)


-async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
+_timeout = aiohttp.ClientTimeout(total=2)
+
+
+async def _head(url: str) -> tuple[str, int]:
     try:
-        async with session.head(url, allow_redirects=False) as response:
-            return url, response.status
+        session = await get_aiohttp_session()
+        response = await session.head(
+            url, allow_redirects=False, headers=HEADERS, timeout=_timeout
+        )
+        return url, response.status
     except (aiohttp.ClientError, asyncio.TimeoutError) as exception:
         _log_validation_failure(exception)
         return url, -1

@@ -45,11 +52,9 @@ async def _head(url: str, session: aiohttp.ClientSession) -> tuple[str, int]:
 @async_to_sync
 async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]:
     tasks = []
-    timeout = aiohttp.ClientTimeout(total=2)
-    async with aiohttp.ClientSession(headers=HEADERS, timeout=timeout) as session:
-        tasks = [asyncio.ensure_future(_head(url, session)) for url in urls]
-        responses = asyncio.gather(*tasks)
-        await responses
+    tasks = [asyncio.ensure_future(_head(url)) for url in urls]
+    responses = asyncio.gather(*tasks)
+    await responses
     return responses.result()
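For reference, `asyncio.gather` returns a future that aggregates the tasks; awaiting it lets `responses.result()` return each task's value in submission order. A standalone sketch of the same pattern (the `_double` coroutine is a toy stand-in):

import asyncio


async def _double(n: int) -> int:
    return n * 2


async def main() -> None:
    tasks = [asyncio.ensure_future(_double(n)) for n in range(3)]
    responses = asyncio.gather(*tasks)  # future aggregating all tasks
    await responses
    print(responses.result())  # [0, 2, 4], in submission order


asyncio.run(main())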
10 changes: 2 additions & 8 deletions api/conf/asgi.py
@@ -1,25 +1,19 @@
 import os

-import django
 from django.conf import settings
 from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler

-from conf.asgi_handler import OpenverseASGIHandler
+from django_asgi_lifespan.asgi import get_asgi_application


 os.environ.setdefault("DJANGO_SETTINGS_MODULE", "conf.settings")


-def get_asgi_application():
-    django.setup(set_prefix=False)
-    return OpenverseASGIHandler()
-
-
 application = get_asgi_application()


 if settings.ENVIRONMENT == "local":
-    static_files_application = ASGIStaticFilesHandler(application)
+    application = ASGIStaticFilesHandler(application)


 if settings.GC_DEBUG_LOGGING:
75 changes: 0 additions & 75 deletions api/conf/asgi_handler.py

This file was deleted.

2 changes: 1 addition & 1 deletion api/env.template
@@ -5,7 +5,7 @@ DJANGO_SECRET_KEY="ny#b__$$f6ry4wy8oxre97&-68u_0lk3gw(z=d40_dxey3zw0v1"
 DJANGO_DEBUG_ENABLED=True

 BASE_URL=http://localhost:50280/
-ENVIRONMENT=development
+ENVIRONMENT=local
 # List of comma-separated hosts/domain names, e.g., 127.17.0.1,local.app
 ALLOWED_HOSTS=localhost,172.17.0.1,host.docker.internal

4 changes: 1 addition & 3 deletions api/run.py
@@ -6,10 +6,8 @@
 if __name__ == "__main__":
     is_local = os.getenv("ENVIRONMENT") == "local"

-    app = "conf.asgi:static_files_application" if is_local else "conf.asgi:application"
-
     uvicorn.run(
-        app,
+        "conf.asgi:application",
         host="0.0.0.0",
         port=8000,
         workers=1,
30 changes: 30 additions & 0 deletions api/test/conftest.py
@@ -0,0 +1,30 @@
import pytest
from asgiref.sync import async_to_sync

from conf.asgi import application


@pytest.fixture(scope="session", autouse=True)
def ensure_asgi_lifecycle():
    """
    Call the application shutdown lifecycle event.

    This cannot be an async fixture because the scope is session and
    pytest-asyncio's `event_loop` fixture, which is auto-used for async
    tests and fixtures, is function scoped, which is incompatible with
    session scoped fixtures. `async_to_sync` works fine here, so it's
    not a problem.

    This cannot yet call the startup signal due to:
    https://github.com/illagrenan/django-asgi-lifespan/pull/80
    """
    scope = {"type": "lifespan"}

    async def noop(*args, **kwargs):
        ...

    async def shutdown():
        return {"type": "lifespan.shutdown"}

    yield
    async_to_sync(application)(scope, shutdown, noop)
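The fixture drives the same ASGI lifespan handshake that a server like uvicorn performs at shutdown. A condensed sketch of the message flow (`drive_shutdown` is illustrative, not part of the commit):

import asyncio

from conf.asgi import application


async def drive_shutdown() -> list[dict]:
    sent: list[dict] = []

    async def receive() -> dict:
        # A real server sends lifespan.startup first; it is skipped here
        # for the reason noted in the fixture docstring above.
        return {"type": "lifespan.shutdown"}

    async def send(message: dict) -> None:
        # On success this collects {"type": "lifespan.shutdown.complete"}.
        sent.append(message)

    await application({"type": "lifespan"}, receive, send)
    return sent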
52 changes: 52 additions & 0 deletions api/test/unit/utils/test_aiohttp.py
@@ -0,0 +1,52 @@
import asyncio

import pytest

from api.utils.aiohttp import get_aiohttp_session


@pytest.fixture(autouse=True)
def get_new_loop():
    loops: list[asyncio.AbstractEventLoop] = []

    def _get_new_loop():
        loop = asyncio.new_event_loop()
        loops.append(loop)
        return loop

    yield _get_new_loop

    for loop in loops:
        loop.close()


def test_reuses_session_within_same_loop(get_new_loop):
    loop = get_new_loop()

    session_1 = loop.run_until_complete(get_aiohttp_session())
    session_2 = loop.run_until_complete(get_aiohttp_session())

    assert session_1 is session_2


def test_creates_new_session_for_separate_loops(get_new_loop):
    loop_1 = get_new_loop()
    loop_2 = get_new_loop()

    loop_1_session = loop_1.run_until_complete(get_aiohttp_session())
    loop_2_session = loop_2.run_until_complete(get_aiohttp_session())

    assert loop_1_session is not loop_2_session


def test_multiple_loops_reuse_separate_sessions(get_new_loop):
    loop_1 = get_new_loop()
    loop_2 = get_new_loop()

    loop_1_session_1 = loop_1.run_until_complete(get_aiohttp_session())
    loop_1_session_2 = loop_1.run_until_complete(get_aiohttp_session())
    loop_2_session_1 = loop_2.run_until_complete(get_aiohttp_session())
    loop_2_session_2 = loop_2.run_until_complete(get_aiohttp_session())

    assert loop_1_session_1 is loop_1_session_2
    assert loop_2_session_1 is loop_2_session_2
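One further case the suite could cover, sketched as a hypothetical test (not in this commit): `get_aiohttp_session` replaces a session that has been closed rather than reusing it, per the `closed` check in the session manager.

def test_replaces_closed_session(get_new_loop):
    loop = get_new_loop()

    session_1 = loop.run_until_complete(get_aiohttp_session())
    # Closing the session exercises the branch that detects a closed
    # session and creates a fresh one for the same loop.
    loop.run_until_complete(session_1.close())
    session_2 = loop.run_until_complete(get_aiohttp_session())

    assert session_1 is not session_2
    assert not session_2.closed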
19 changes: 7 additions & 12 deletions api/test/unit/utils/test_check_dead_links.py
@@ -1,23 +1,22 @@
 import asyncio
 from unittest import mock

-import aiohttp
 import pook
 import pytest
+from aiohttp.client import ClientSession

 from api.utils.check_dead_links import HEADERS, check_dead_links


-@mock.patch.object(aiohttp, "ClientSession", wraps=aiohttp.ClientSession)
 @pook.on
-def test_sends_user_agent(wrapped_client_session: mock.AsyncMock):
+def test_sends_user_agent():
     query_hash = "test_sends_user_agent"
     results = [{"provider": "best_provider_ever"} for _ in range(40)]
     image_urls = [f"https://example.org/{i}" for i in range(len(results))]
     start_slice = 0

     head_mock = (
         pook.head(pook.regex(r"https://example.org/\d"))
         .headers(HEADERS)
         .times(len(results))
         .reply(200)
         .mock

@@ -30,10 +29,8 @@ def test_sends_user_agent(wrapped_client_session: mock.AsyncMock):
     for url in image_urls:
         assert url in requested_urls

-    wrapped_client_session.assert_called_once_with(headers=HEADERS, timeout=mock.ANY)
-

-def test_handles_timeout():
+def test_handles_timeout(monkeypatch):
     """
     Test the case where a timeout occurs.

@@ -45,13 +42,11 @@ def test_handles_timeout():
     image_urls = [f"https://example.org/{i}" for i in range(len(results))]
     start_slice = 0

-    def raise_timeout_error(*args, **kwargs):
+    async def raise_timeout_error(*args, **kwargs):
         raise asyncio.TimeoutError()

-    with mock.patch(
-        "aiohttp.client.ClientSession._request", side_effect=raise_timeout_error
-    ):
-        check_dead_links(query_hash, start_slice, results, image_urls)
+    monkeypatch.setattr(ClientSession, "_request", raise_timeout_error)
+    check_dead_links(query_hash, start_slice, results, image_urls)

     # `check_dead_links` directly modifies the results list
     # if the results are timing out then they're considered dead and discarded
