From cf5f4748886f19858888ee02c0fefeef112737d3 Mon Sep 17 00:00:00 2001 From: Bogdan Markov Date: Fri, 20 Dec 2024 10:45:04 +0100 Subject: [PATCH] Cancel queue get in case of cancel of parent. --- tests/aio/test_session_pool.py | 18 ++++++++++++++---- ydb/aio/table.py | 9 ++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/aio/test_session_pool.py b/tests/aio/test_session_pool.py index c2875ba3..50b003e3 100644 --- a/tests/aio/test_session_pool.py +++ b/tests/aio/test_session_pool.py @@ -33,11 +33,21 @@ async def test_waiter_is_notified(driver): @pytest.mark.asyncio async def test_no_race_after_future_cancel(driver): + async def first_session(pool: ydb.aio.SessionPool): + s = await pool.acquire() + await asyncio.sleep(0.003) + await pool.release(s) + + async def second_session(pool: ydb.aio.SessionPool): + await asyncio.sleep(0.001) + waiter = asyncio.ensure_future(pool.acquire()) + await asyncio.sleep(0.001) + waiter.cancel() + pool = ydb.aio.SessionPool(driver, 1) - s = await pool.acquire() - waiter = asyncio.ensure_future(pool.acquire()) - waiter.cancel() - await pool.release(s) + await asyncio.gather(first_session(pool), second_session(pool)) + + assert pool._active_queue.qsize() == 1 s = await pool.acquire() assert s.initialized() await pool.stop() diff --git a/ydb/aio/table.py b/ydb/aio/table.py index aec32e1a..a5240847 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -563,7 +563,14 @@ async def _prepare_session(self, timeout, retry_num) -> ydb.ISession: async def _get_session_from_queue(self, timeout: float): task_wait = asyncio.ensure_future(asyncio.wait_for(self._active_queue.get(), timeout=timeout)) task_should_stop = asyncio.ensure_future(self._should_stop.wait()) - done, _ = await asyncio.wait((task_wait, task_should_stop), return_when=asyncio.FIRST_COMPLETED) + try: + done, _ = await asyncio.wait((task_wait, task_should_stop), return_when=asyncio.FIRST_COMPLETED) + except asyncio.CancelledError as exc: + cancelled = task_wait.cancel() + if not cancelled: + priority, session = task_wait.result() + self._active_queue.put_nowait((priority, session)) + raise exc if task_should_stop in done: task_wait.cancel() return self._create()