diff --git a/litestar/channels/backends/psycopg.py b/litestar/channels/backends/psycopg.py index 14b53bcd1a..c7651d030c 100644 --- a/litestar/channels/backends/psycopg.py +++ b/litestar/channels/backends/psycopg.py @@ -1,19 +1,16 @@ from __future__ import annotations from contextlib import AsyncExitStack -from typing import AsyncGenerator, Iterable +from typing import Any, AsyncGenerator, Iterable -import psycopg +from psycopg import AsyncConnection +from psycopg.sql import SQL, Identifier from .base import ChannelsBackend -def _safe_quote(ident: str) -> str: - return '"{}"'.format(ident.replace('"', '""')) # sourcery skip - - class PsycoPgChannelsBackend(ChannelsBackend): - _listener_conn: psycopg.AsyncConnection + _listener_conn: AsyncConnection[Any] def __init__(self, pg_dsn: str) -> None: self._pg_dsn = pg_dsn @@ -21,7 +18,7 @@ def __init__(self, pg_dsn: str) -> None: self._exit_stack = AsyncExitStack() async def on_startup(self) -> None: - self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True) + self._listener_conn = await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True) await self._exit_stack.enter_async_context(self._listener_conn) async def on_shutdown(self) -> None: @@ -29,21 +26,23 @@ async def on_shutdown(self) -> None: async def publish(self, data: bytes, channels: Iterable[str]) -> None: dec_data = data.decode("utf-8") - async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn: + async with await AsyncConnection[Any].connect(self._pg_dsn) as conn: for channel in channels: - await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data)) + await conn.execute(SQL("SELECT pg_notify(%s, %s);").format(Identifier(channel), dec_data)) async def subscribe(self, channels: Iterable[str]) -> None: - for channel in set(channels) - self._subscribed_channels: - # can't use placeholders in LISTEN - await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore + channels_to_subscribe = set(channels) - self._subscribed_channels + if not channels_to_subscribe: + return + + for channel in channels_to_subscribe: + await self._listener_conn.execute(SQL("LISTEN {}").format(Identifier(channel))) self._subscribed_channels.add(channel) async def unsubscribe(self, channels: Iterable[str]) -> None: for channel in channels: - # can't use placeholders in UNLISTEN - await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore + await self._listener_conn.execute(SQL("UNLISTEN {}").format(Identifier(channel))) self._subscribed_channels = self._subscribed_channels - set(channels) async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: