From 91fc23f0555f8fe4737114e31362d5412bf9a58e Mon Sep 17 00:00:00 2001 From: Lucjan Dudek Date: Mon, 30 Sep 2024 11:11:02 +0200 Subject: [PATCH] Ensure pending tasks are cancelled (#159) * Ensure pending tasks are cancelled * Make execute_after_task attribute public --- golem/node/node.py | 25 ++++++++++++------- golem/resources/activity/activity.py | 3 ++- .../resources/pooling_batch/pooling_batch.py | 6 +++++ golem/utils/asyncio/queue.py | 10 +++++--- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/golem/node/node.py b/golem/node/node.py index 2e6877ca..55cdd40a 100644 --- a/golem/node/node.py +++ b/golem/node/node.py @@ -127,14 +127,15 @@ async def start(self) -> None: await self.event_bus.emit(SessionStarted(self)) async def aclose(self) -> None: - await self.event_bus.emit(ShutdownStarted(self)) - self._set_no_more_children() - self._stop_event_collectors() - await self._close_autoclose_resources() - await self._close_apis() - await self.event_bus.emit(ShutdownFinished(self)) - - await self.event_bus.stop() + try: + await self.event_bus.emit(ShutdownStarted(self)) + self._set_no_more_children() + self._stop_event_collectors() + await self._close_autoclose_resources() + await self._close_apis() + await self.event_bus.emit(ShutdownFinished(self)) + finally: + await self.event_bus.stop() def _stop_event_collectors(self) -> None: demands = self.all_resources(Demand) @@ -159,6 +160,9 @@ async def _close_apis(self) -> None: async def _close_autoclose_resources(self) -> None: agreement_msg = "Work finished" + pooling_batch_tasks = [ + r.cleanup() for r in self._autoclose_resources if isinstance(r, PoolingBatch) + ] activity_tasks = [r.destroy() for r in self._autoclose_resources if isinstance(r, Activity)] agreement_tasks = [ r.terminate(agreement_msg) @@ -170,6 +174,8 @@ async def _close_autoclose_resources(self) -> None: r.release() for r in self._autoclose_resources if isinstance(r, Allocation) ] network_tasks = [r.remove() for r in self._autoclose_resources if isinstance(r, Network)] + if pooling_batch_tasks: + await asyncio.gather(*pooling_batch_tasks) if activity_tasks: await asyncio.gather(*activity_tasks) if agreement_tasks: @@ -396,7 +402,8 @@ async def add_to_network(self, network: Network, ip: Optional[str] = None) -> No await network.add_requestor_ip(ip) def add_autoclose_resource( - self, resource: Union["Allocation", "Demand", "Agreement", "Activity", "Network"] + self, + resource: Union["Allocation", "Demand", "Agreement", "Activity", "Network", "PoolingBatch"], ) -> None: self._autoclose_resources.add(resource) diff --git a/golem/resources/activity/activity.py b/golem/resources/activity/activity.py index bee184ae..a0d68bf9 100644 --- a/golem/resources/activity/activity.py +++ b/golem/resources/activity/activity.py @@ -106,6 +106,7 @@ async def execute( batch_id = await self.api.call_exec(self.id, script, _request_timeout=timeout) batch = PoolingBatch(self.node, batch_id) batch.start_collecting_events() + self._node.add_autoclose_resource(batch) self.add_child(batch) self.running_batch_counter += 1 return batch @@ -133,7 +134,7 @@ async def execute_after() -> None: await batch.wait(ignore_errors=True) await asyncio.gather(*[c.after() for c in commands]) - asyncio.create_task(execute_after()) + batch.execute_after_task = asyncio.create_task(execute_after()) return batch async def execute_script(self, script: "Script") -> PoolingBatch: diff --git a/golem/resources/pooling_batch/pooling_batch.py b/golem/resources/pooling_batch/pooling_batch.py index 3ff94a69..526a7a2a 100644 --- a/golem/resources/pooling_batch/pooling_batch.py +++ b/golem/resources/pooling_batch/pooling_batch.py @@ -12,6 +12,7 @@ CommandCancelled, CommandFailed, ) +from golem.utils.asyncio import ensure_cancelled from golem.utils.low import ActivityApi, YagnaEventCollector if TYPE_CHECKING: @@ -41,6 +42,7 @@ def __init__(self, node: "GolemNode", id_: str): self.finished_event = asyncio.Event() self._futures: Optional[List[asyncio.Future[models.ExeScriptCommandResult]]] = None + self.execute_after_task: Optional[asyncio.Task] = None @property def done(self) -> bool: @@ -89,6 +91,10 @@ async def wait( assert timeout_seconds is not None # mypy raise BatchTimeoutError(self, timeout_seconds) + async def cleanup(self): + if self.execute_after_task: + await ensure_cancelled(self.execute_after_task) + @property def events(self) -> List[models.ExeScriptCommandResult]: """Returns a list of results for this batch. diff --git a/golem/utils/asyncio/queue.py b/golem/utils/asyncio/queue.py index a2df2715..3251cdda 100644 --- a/golem/utils/asyncio/queue.py +++ b/golem/utils/asyncio/queue.py @@ -33,9 +33,13 @@ async def get(self) -> TQueueItem: error_task = asyncio.create_task(self._error_event.wait()) get_task = asyncio.create_task(super().get()) - done, pending = await asyncio.wait( - [error_task, get_task], return_when=asyncio.FIRST_COMPLETED - ) + try: + done, pending = await asyncio.wait( + [error_task, get_task], return_when=asyncio.FIRST_COMPLETED + ) + except asyncio.CancelledError: + await ensure_cancelled_many([error_task, get_task]) + raise await ensure_cancelled_many(pending)