From 2d5fb022c39bb2bae49de6f54501422b3ad114c9 Mon Sep 17 00:00:00 2001 From: Ryan Mulligan Date: Thu, 10 Oct 2024 18:13:55 -0700 Subject: [PATCH] [replit_river] return cleanup task from client.disconnect() (#99) Why === * The task created by the websocket wrapper was orphaned. Tasks need to be awaited somewhere, or you get errors like ``` RuntimeError: no running event loop Task was destroyed but it is pending! ``` * Since it takes a while to finish, we don't want to wait for it in certain cases, so instead we'll return it as a cleanup task that the caller can await as appropriate. What changed === * When the websocket close task is made, return it. * At every level, return the task and combine it with other cleanup tasks as appropriate Test plan === * The behavior shouldn't be different unless you await the cleanup function. If you do await it, you won't get a pending task exception when closing the event loop. --- replit_river/client.py | 6 ++++-- replit_river/client_transport.py | 4 ++-- replit_river/session.py | 19 +++++++++++++------ replit_river/transport.py | 13 +++++++++++-- replit_river/websocket_wrapper.py | 4 +++- 5 files changed, 33 insertions(+), 13 deletions(-) diff --git a/replit_river/client.py b/replit_river/client.py index fab2dc6..2d8fabc 100644 --- a/replit_river/client.py +++ b/replit_river/client.py @@ -1,3 +1,4 @@ +import asyncio import logging from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from typing import Any, Generic, Optional, Union @@ -38,10 +39,11 @@ def __init__( transport_options=transport_options, ) - async def close(self) -> None: + async def close(self) -> asyncio.Task | None: logger.info(f"river client {self._client_id} start closing") - await self._transport.close() + cleanup_task = await self._transport.close() logger.info(f"river client {self._client_id} closed") + return cleanup_task async def ensure_connected(self) -> None: await self._transport.get_or_create_session() diff --git a/replit_river/client_transport.py b/replit_river/client_transport.py index 7ae62b2..13abc7a 100644 --- a/replit_river/client_transport.py +++ b/replit_river/client_transport.py @@ -68,9 +68,9 @@ def __init__( # We want to make sure there's only one session creation at a time self._create_session_lock = asyncio.Lock() - async def close(self) -> None: + async def close(self) -> asyncio.Task: self._rate_limiter.close() - await self._close_all_sessions() + return await self._close_all_sessions() async def get_or_create_session(self) -> ClientSession: async with self._create_session_lock: diff --git a/replit_river/session.py b/replit_river/session.py index 38d0972..5ed83f6 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -434,15 +434,18 @@ async def _send_responses_from_output_stream( async def close_websocket( self, ws_wrapper: WebsocketWrapper, should_retry: bool - ) -> None: + ) -> asyncio.Task | None: """Mark the websocket as closed, close the websocket, and retry if needed.""" + cleanup_websocket_task: asyncio.Task | None = None async with self._ws_lock: # Already closed. if not await ws_wrapper.is_open(): - return - await ws_wrapper.close() + logger.info("websocket wrapper already closed") + return cleanup_websocket_task + cleanup_websocket_task = await ws_wrapper.close() if should_retry and self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) + return cleanup_websocket_task async def _open_stream_and_call_handler( self, @@ -523,8 +526,9 @@ async def _remove_acked_messages_in_buffer(self) -> None: async def start_serve_responses(self) -> None: self._task_manager.create_task(self.serve()) - async def close(self) -> None: + async def close(self) -> asyncio.Task | None: """Close the session and all associated streams.""" + cleanup_websocket_task: asyncio.Task | None = None logger.info( f"{self._transport_id} closing session " f"to {self._to_id}, ws: {self._ws_wrapper.id}, " @@ -533,12 +537,14 @@ async def close(self) -> None: async with self._state_lock: if self._state != SessionState.ACTIVE: # already closing - return + return cleanup_websocket_task self._state = SessionState.CLOSING self._reset_session_close_countdown() await self._task_manager.cancel_all_tasks() - await self.close_websocket(self._ws_wrapper, should_retry=False) + cleanup_websocket_task = await self.close_websocket( + self._ws_wrapper, should_retry=False + ) await self._buffer.close() @@ -553,3 +559,4 @@ async def close(self) -> None: self._streams.clear() self._state = SessionState.CLOSED + return cleanup_websocket_task diff --git a/replit_river/transport.py b/replit_river/transport.py index d4395be..7eeb14f 100644 --- a/replit_river/transport.py +++ b/replit_river/transport.py @@ -27,7 +27,8 @@ def __init__( self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {} self._session_lock = asyncio.Lock() - async def _close_all_sessions(self) -> None: + async def _close_all_sessions(self) -> asyncio.Task: + cleanup_tasks: list[asyncio.Task] = [] sessions = self._sessions.values() logger.info( f"start closing sessions {self._transport_id}, number sessions : " @@ -38,10 +39,18 @@ async def _close_all_sessions(self) -> None: # closing sessions requires access to the session lock, so we need to close # them one by one to be safe for session in sessions_to_close: - await session.close() + cleanup_task = await session.close() + if cleanup_task: + cleanup_tasks.append(cleanup_task) logger.info(f"Transport closed {self._transport_id}") + async def cleanup() -> None: + for cleanup_task in cleanup_tasks: + await cleanup_task + + return asyncio.create_task(cleanup()) + async def _delete_session(self, session: Session) -> None: async with self._session_lock: if session._to_id in self._sessions: diff --git a/replit_river/websocket_wrapper.py b/replit_river/websocket_wrapper.py index e929fb8..9a1958c 100644 --- a/replit_river/websocket_wrapper.py +++ b/replit_river/websocket_wrapper.py @@ -24,7 +24,7 @@ async def is_open(self) -> bool: async with self.ws_lock: return self.ws_state == WsState.OPEN - async def close(self) -> None: + async def close(self) -> asyncio.Task | None: async with self.ws_lock: if self.ws_state == WsState.OPEN: self.ws_state = WsState.CLOSING @@ -33,3 +33,5 @@ async def close(self) -> None: lambda _: logger.debug("old websocket %s closed.", self.ws.id) ) self.ws_state = WsState.CLOSED + return task + return None