Merge pull request #1124 from lsst/tickets/DM-47889
DM-47889: Prevent DB connection pool exhaustion in Butler server
dhirving authored Dec 5, 2024
2 parents f8303a5 + 8ea635f commit d632886
Showing 6 changed files with 233 additions and 62 deletions.
47 changes: 22 additions & 25 deletions python/lsst/daf/butler/registry/databases/postgresql.py
@@ -163,37 +163,34 @@ def makeEngine(
             # multiple threads simultaneously. So we need to configure
             # SQLAlchemy to pool connections for multi-threaded usage.
             #
-            # This is not the maximum number of active connections --
-            # SQLAlchemy allows some additional overflow configured via the
-            # max_overflow parameter. pool_size is only the maximum number
-            # saved in the pool during periods of lower concurrency.
+            # This pool size was chosen to work well for services using
+            # FastAPI. FastAPI uses a thread pool of 40 by default, so this
+            # gives us a connection for each thread in the pool. Because Butler
+            # is currently sync-only, we won't ever be executing more queries
+            # than we have threads.
             #
-            # This specific value for pool size was chosen somewhat arbitrarily
-            # -- there has not been any formal testing done to profile database
-            # concurrency. The value chosen may be somewhat lower than is
-            # optimal for service use cases. Some considerations:
+            # Connections are only created as they are needed, so in typical
+            # single-threaded Butler use only one connection will ever be
+            # created. Services with low peak concurrency may never create this
+            # many connections.
             #
-            # 1. Connections are only created as they are needed, so in typical
-            #    single-threaded Butler use only one connection will ever be
-            #    created. Services with low peak concurrency may never create
-            #    this many connections.
-            # 2. Most services using the Butler (including Butler
-            #    server) are using FastAPI, which uses a thread pool of 40 by
-            #    default. So when running at max concurrency we may have:
-            #    * 10 connections checked out from the pool
-            #    * 10 "overflow" connections re-created each time they are
-            #      used.
-            #    * 20 threads queued up, waiting for a connection, and
-            #      potentially timing out if the other threads don't release
-            #      their connections in a timely manner.
-            # 3. The main Butler databases at SLAC are run behind pgbouncer,
-            #    so we can support a larger number of simultaneous connections
-            #    than if we were connecting directly to Postgres.
+            # The main Butler databases at SLAC are run behind pgbouncer, so we
+            # can support a larger number of simultaneous connections than if
+            # we were connecting directly to Postgres.
             #
             # See
             # https://docs.sqlalchemy.org/en/20/core/pooling.html#sqlalchemy.pool.QueuePool.__init__
             # for more information on the behavior of this parameter.
-            pool_size=10,
+            pool_size=40,
+            # If we are experiencing heavy enough load that we overflow the
+            # connection pool, it will be harmful to start creating extra
+            # connections that we disconnect immediately after use.
+            # Connecting from scratch is fairly expensive, which is why we have
+            # a pool in the first place.
+            max_overflow=0,
+            # If the pool is full, this is the maximum number of seconds we
+            # will wait for a connection to become available before giving up.
+            pool_timeout=60,
             # In combination with pool_pre_ping, prevent SQLAlchemy from
             # unnecessarily reviving pooled connections that have gone stale.
             # Setting this to true makes it always re-use the most recent
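
For reference, a minimal sketch (not the Butler's actual makeEngine implementation) of how the pool settings discussed above fit together in a SQLAlchemy create_engine call; the DSN is a placeholder:

import sqlalchemy

engine = sqlalchemy.create_engine(
    "postgresql://user:password@db.example.org/butler",  # placeholder DSN
    pool_size=40,        # one pooled connection per FastAPI worker thread
    max_overflow=0,      # never create throwaway connections beyond the pool
    pool_timeout=60,     # seconds to wait for a free connection before giving up
    pool_pre_ping=True,  # check pooled connections for liveness before reuse
)
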
46 changes: 44 additions & 2 deletions python/lsst/daf/butler/remote_butler/_http_connection.py
@@ -29,6 +29,7 @@

 __all__ = ("RemoteButlerHttpConnection", "parse_model")
 
+import time
 import urllib.parse
 from collections.abc import Iterator, Mapping
 from contextlib import contextmanager
@@ -210,7 +211,7 @@ def _send_request(self, request: _Request) -> httpx.Response:
         with the message as a subclass of ButlerUserError.
         """
         try:
-            response = self._client.send(request.request)
+            response = self._send_with_retries(request, stream=False)
             self._handle_http_status(response, request.request_id)
             return response
         except httpx.HTTPError as e:
@@ -219,7 +220,7 @@ def _send_request(self, request: _Request) -> httpx.Response:
     @contextmanager
     def _send_request_with_stream_response(self, request: _Request) -> Iterator[httpx.Response]:
         try:
-            response = self._client.send(request.request, stream=True)
+            response = self._send_with_retries(request, stream=True)
             try:
                 self._handle_http_status(response, request.request_id)
                 yield response
@@ -228,6 +229,21 @@ def _send_request_with_stream_response(self, request: _Request) -> Iterator[httpx.Response]:
         except httpx.HTTPError as e:
             raise ButlerServerError(request.request_id) from e
 
+    def _send_with_retries(self, request: _Request, stream: bool) -> httpx.Response:
+        max_retry_time_seconds = 120
+        start_time = time.time()
+        while True:
+            response = self._client.send(request.request, stream=stream)
+            retry = _needs_retry(response)
+            time_remaining = max_retry_time_seconds - (time.time() - start_time)
+            if retry.retry and time_remaining > 0:
+                if stream:
+                    response.close()
+                sleep_time = min(time_remaining, retry.delay_seconds)
+                time.sleep(sleep_time)
+            else:
+                return response
+
     def _handle_http_status(self, response: httpx.Response, request_id: str) -> None:
         if response.status_code == ERROR_STATUS_CODE:
             # Raise an exception that the server has forwarded to the
@@ -245,6 +261,32 @@ def _handle_http_status(self, response: httpx.Response, request_id: str) -> None:
         response.raise_for_status()
 
 
+@dataclass(frozen=True)
+class _Retry:
+    retry: bool
+    delay_seconds: int
+
+
+def _needs_retry(response: httpx.Response) -> _Retry:
+    # Handle a 503 Service Unavailable, sent by the server if it is
+    # overloaded, or a 429, sent by the server if the client
+    # triggers a rate limit.
+    if response.status_code == 503 or response.status_code == 429:
+        # Only retry if the server has instructed us to do so by sending a
+        # Retry-After header.
+        retry_after = response.headers.get("retry-after")
+        if retry_after is not None:
+            try:
+                # The HTTP standard also allows a date string here, but the
+                # Butler server only sends integer seconds.
+                delay_seconds = int(retry_after)
+                return _Retry(True, delay_seconds)
+            except ValueError:
+                pass
+
+    return _Retry(False, 0)
+
+
 def parse_model(response: httpx.Response, model: type[_AnyPydanticModel]) -> _AnyPydanticModel:
     """Deserialize a Pydantic model from the body of an HTTP response.
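
The comment inside _needs_retry notes that the HTTP standard also allows an HTTP-date in Retry-After, while the Butler server only sends integer seconds. For illustration, a hedged sketch of a more general parser (parse_retry_after is a hypothetical helper, not part of this commit):

from datetime import datetime, timezone
from email.utils import parsedate_to_datetime


def parse_retry_after(value: str) -> float | None:
    # Integer-seconds form, e.g. "5" -- the only form the Butler server sends.
    try:
        return float(int(value))
    except ValueError:
        pass
    # HTTP-date form, e.g. "Wed, 21 Oct 2015 07:28:00 GMT".
    try:
        target = parsedate_to_datetime(value)
        return max(0.0, (target - datetime.now(timezone.utc)).total_seconds())
    except (TypeError, ValueError):
        return None
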
@@ -91,7 +91,7 @@ async def query_execute(
     request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
 ) -> StreamingResponse:
     query = _StreamQueryDriverExecute(request, factory)
-    return execute_streaming_query(query)
+    return await execute_streaming_query(query)
 
 
 class _QueryAllDatasetsContext(NamedTuple):
@@ -136,7 +136,7 @@ async def query_all_datasets_execute(
     request: QueryAllDatasetsRequestModel, factory: Factory = Depends(factory_dependency)
 ) -> StreamingResponse:
     query = _StreamQueryAllDatasets(request, factory)
-    return execute_streaming_query(query)
+    return await execute_streaming_query(query)
 
 
 @query_router.post(
@@ -32,6 +32,7 @@
 from contextlib import AbstractContextManager
 from typing import Protocol, TypeVar
 
+from fastapi import HTTPException
 from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool
 from fastapi.responses import StreamingResponse
 from lsst.daf.butler.remote_butler.server_models import (
@@ -43,11 +44,26 @@
 from ...._exceptions import ButlerUserError
 from ..._errors import serialize_butler_user_error
 
+# Restrict the maximum number of streaming queries that can be running
+# simultaneously, to prevent the database connection pool and the thread pool
+# from being tied up indefinitely. Beyond this number, the server will return
+# an HTTP 503 Service Unavailable with a Retry-After header. We are currently
+# using the default FastAPI thread pool size of 40 (total) and have 40 maximum
+# database connections (per Butler repository).
+_MAXIMUM_CONCURRENT_STREAMING_QUERIES = 25
+# How long we ask callers to wait before trying their query again.
+# The hope is that they will bounce to a less busy replica, so we don't want
+# them to wait too long.
+_QUERY_RETRY_SECONDS = 5
+
 # Alias this function so we can mock it during unit tests.
 _timeout = asyncio.timeout
 
 _TContext = TypeVar("_TContext")
 
+# Count of active streaming queries.
+_current_streaming_queries = 0
+
 
 class StreamingQuery(Protocol[_TContext]):
     """Interface for queries that can return streaming results."""
@@ -67,7 +83,7 @@ def execute(self, context: _TContext) -> Iterator[QueryExecuteResultData]:
"""


def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
async def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
"""Run a query, streaming the response incrementally, one page at a time,
as newline-separated chunks of JSON.
@@ -95,6 +111,22 @@ def execute_streaming_query(query: StreamingQuery) -> StreamingResponse:
     read -- ``StreamingQuery.execute()`` cannot be interrupted while it is
     in the middle of reading a page.
     """
+    # Prevent an excessive number of streaming queries from jamming up the
+    # thread pool and database connection pool. We can't change the response
+    # code after starting the StreamingResponse, so we enforce this here.
+    #
+    # This creates a small chance that more than the expected number of
+    # streaming queries will be started, but there is no guarantee that the
+    # StreamingResponse generator function will ever be called, so we can't
+    # guarantee that we release the slot if we reserve one here.
+    if _current_streaming_queries >= _MAXIMUM_CONCURRENT_STREAMING_QUERIES:
+        await _block_retry_for_unit_test()
+        raise HTTPException(
+            status_code=503,  # service temporarily unavailable
+            detail="The Butler Server is currently overloaded with requests.",
+            headers={"retry-after": str(_QUERY_RETRY_SECONDS)},
+        )
+
     output_generator = _stream_query_pages(query)
     return StreamingResponse(
         output_generator,
@@ -115,17 +147,24 @@ async def _stream_query_pages(query: StreamingQuery) -> AsyncIterator[str]:
     When it takes longer than 15 seconds to get a response from the DB,
     sends a keep-alive message to prevent clients from timing out.
     """
-    # `None` signals that there is no more data to send.
-    queue = asyncio.Queue[QueryExecuteResultData | None](1)
-    async with asyncio.TaskGroup() as tg:
-        # Run a background task to read from the DB and insert the result pages
-        # into a queue.
-        tg.create_task(_enqueue_query_pages(queue, query))
-        # Read the result pages from the queue and send them to the client,
-        # inserting a keep-alive message every 15 seconds if we are waiting a
-        # long time for the database.
-        async for message in _dequeue_query_pages_with_keepalive(queue):
-            yield message.model_dump_json() + "\n"
+    global _current_streaming_queries
+    try:
+        _current_streaming_queries += 1
+        await _block_query_for_unit_test()
+
+        # `None` signals that there is no more data to send.
+        queue = asyncio.Queue[QueryExecuteResultData | None](1)
+        async with asyncio.TaskGroup() as tg:
+            # Run a background task to read from the DB and insert the result
+            # pages into a queue.
+            tg.create_task(_enqueue_query_pages(queue, query))
+            # Read the result pages from the queue and send them to the client,
+            # inserting a keep-alive message every 15 seconds if we are waiting
+            # a long time for the database.
+            async for message in _dequeue_query_pages_with_keepalive(queue):
+                yield message.model_dump_json() + "\n"
+    finally:
+        _current_streaming_queries -= 1
 
 
 async def _enqueue_query_pages(
@@ -163,3 +202,17 @@ async def _dequeue_query_pages_with_keepalive(
             yield message
         except TimeoutError:
             yield QueryKeepAliveModel()
+
+
+async def _block_retry_for_unit_test() -> None:
+    """Will be overridden during unit tests to block the server,
+    in order to verify retry logic.
+    """
+    pass
+
+
+async def _block_query_for_unit_test() -> None:
+    """Will be overridden during unit tests to block the server,
+    in order to verify maximum concurrency logic.
+    """
+    pass
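
Together with the client retry loop above, the changes in this file implement a simple admission-control pattern: when a counter of in-flight streaming queries exceeds a cap, reject new work with a 503 plus a Retry-After header, and release the slot when the response generator finishes. A condensed, self-contained sketch of that pattern under the same assumptions (the /query route and the fake page source are illustrative, not the server's real endpoints):

import asyncio
from collections.abc import AsyncIterator

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse

app = FastAPI()
_MAX_CONCURRENT = 25  # mirrors _MAXIMUM_CONCURRENT_STREAMING_QUERIES
_active = 0  # count of in-flight streaming responses


@app.get("/query")
async def query() -> StreamingResponse:
    if _active >= _MAX_CONCURRENT:
        # Ask the client to back off briefly; it may land on a less busy replica.
        raise HTTPException(status_code=503, headers={"retry-after": "5"})
    return StreamingResponse(_pages(), media_type="application/jsonlines")


async def _pages() -> AsyncIterator[str]:
    global _active
    _active += 1
    try:
        for i in range(3):  # stand-in for reading result pages from the database
            await asyncio.sleep(0)
            yield f'{{"page": {i}}}\n'
    finally:
        _active -= 1

As the diff comment notes, the slot is only reserved once the generator actually starts running, so the cap is approximate rather than exact.
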
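The keep-alive machinery whose tail is visible in the last hunk pairs a producer task that fills a bounded queue with a consumer that emits a heartbeat whenever the database stays quiet too long. A minimal sketch of that consumer loop, with illustrative names (requires Python 3.11+ for asyncio.timeout):

import asyncio
from collections.abc import AsyncIterator


async def dequeue_with_keepalive(queue: asyncio.Queue[str | None]) -> AsyncIterator[str]:
    # `None` on the queue signals that there is no more data to send.
    while True:
        try:
            async with asyncio.timeout(15):
                message = await queue.get()
            if message is None:
                return
            yield message
        except TimeoutError:
            yield "keep-alive"
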
43 changes: 23 additions & 20 deletions python/lsst/daf/butler/tests/server.py
@@ -122,26 +122,29 @@ def create_test_server(
         server_butler_factory._preload_direct_butler_cache = False
         app.dependency_overrides[butler_factory_dependency] = lambda: server_butler_factory
 
-        client = TestClient(app)
-        client_without_error_propagation = TestClient(app, raise_server_exceptions=False)
-
-        remote_butler = _make_remote_butler(client)
-        remote_butler_without_error_propagation = _make_remote_butler(
-            client_without_error_propagation
-        )
-
-        direct_butler = Butler.from_config(config_file_path, writeable=True)
-        assert isinstance(direct_butler, DirectButler)
-        hybrid_butler = HybridButler(remote_butler, direct_butler)
-
-        yield TestServerInstance(
-            config_file_path=config_file_path,
-            client=client,
-            direct_butler=direct_butler,
-            remote_butler=remote_butler,
-            remote_butler_without_error_propagation=remote_butler_without_error_propagation,
-            hybrid_butler=hybrid_butler,
-        )
+        # Using TestClient in a context manager ensures that it uses
+        # the same async event loop for all requests -- otherwise it
+        # starts a new one on each request.
+        with TestClient(app) as client:
+            remote_butler = _make_remote_butler(client)
+
+            direct_butler = Butler.from_config(config_file_path, writeable=True)
+            assert isinstance(direct_butler, DirectButler)
+            hybrid_butler = HybridButler(remote_butler, direct_butler)
+
+            client_without_error_propagation = TestClient(app, raise_server_exceptions=False)
+            remote_butler_without_error_propagation = _make_remote_butler(
+                client_without_error_propagation
+            )
+
+            yield TestServerInstance(
+                config_file_path=config_file_path,
+                client=client,
+                direct_butler=direct_butler,
+                remote_butler=remote_butler,
+                remote_butler_without_error_propagation=remote_butler_without_error_propagation,
+                hybrid_butler=hybrid_butler,
+            )
 
 
 def _make_remote_butler(client: TestClient) -> RemoteButler:
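
As the comment in the new test code says, constructing TestClient bare can start a fresh event loop per request, so loop-bound server state (such as the streaming-query counter above) would not persist across requests. A minimal illustration of the context-manager usage; the /ping route is invented for the example:

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/ping")
async def ping() -> dict[str, str]:
    return {"status": "ok"}


# The context manager runs the app's lifespan once and keeps a single event
# loop alive for every request made through `client`.
with TestClient(app) as client:
    assert client.get("/ping").json() == {"status": "ok"}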