From 1c3f9d8e7af068fd94bdeff9e896e8f184f3dbdf Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 30 Dec 2024 04:30:59 +0000 Subject: [PATCH] feat: re-enable safe quote --- litestar/channels/backends/psycopg.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/litestar/channels/backends/psycopg.py b/litestar/channels/backends/psycopg.py index 4f96c29a33..6ee7fe74b7 100644 --- a/litestar/channels/backends/psycopg.py +++ b/litestar/channels/backends/psycopg.py @@ -9,6 +9,10 @@ from .base import ChannelsBackend +def _safe_quote(ident: str) -> str: + return '"{}"'.format(ident.replace('"', '""')) # sourcery skip + + class PsycoPgChannelsBackend(ChannelsBackend): _listener_conn: AsyncConnection[Any] @@ -32,13 +36,13 @@ async def publish(self, data: bytes, channels: Iterable[str]) -> None: async def subscribe(self, channels: Iterable[str]) -> None: for channel in set(channels) - self._subscribed_channels: - await self._listener_conn.execute(SQL("LISTEN {}").format(Identifier(channel))) + await self._listener_conn.execute(SQL("LISTEN {}").format(Identifier(_safe_quote(channel)))) self._subscribed_channels.add(channel) async def unsubscribe(self, channels: Iterable[str]) -> None: for channel in channels: - await self._listener_conn.execute(SQL("UNLISTEN {}").format(Identifier(channel))) + await self._listener_conn.execute(SQL("UNLISTEN {}").format(Identifier(_safe_quote(channel)))) self._subscribed_channels = self._subscribed_channels - set(channels) async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: