Skip to content

Commit

Permalink
Ensure pending tasks are cancelled (#159)
Browse files Browse the repository at this point in the history
* Ensure pending tasks are cancelled

* Make execute_after_task attribute public
  • Loading branch information
lucekdudek authored Sep 30, 2024
1 parent 0a40ac5 commit 91fc23f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 13 deletions.
25 changes: 16 additions & 9 deletions golem/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,15 @@ async def start(self) -> None:
await self.event_bus.emit(SessionStarted(self))

async def aclose(self) -> None:
await self.event_bus.emit(ShutdownStarted(self))
self._set_no_more_children()
self._stop_event_collectors()
await self._close_autoclose_resources()
await self._close_apis()
await self.event_bus.emit(ShutdownFinished(self))

await self.event_bus.stop()
try:
await self.event_bus.emit(ShutdownStarted(self))
self._set_no_more_children()
self._stop_event_collectors()
await self._close_autoclose_resources()
await self._close_apis()
await self.event_bus.emit(ShutdownFinished(self))
finally:
await self.event_bus.stop()

def _stop_event_collectors(self) -> None:
demands = self.all_resources(Demand)
Expand All @@ -159,6 +160,9 @@ async def _close_apis(self) -> None:

async def _close_autoclose_resources(self) -> None:
agreement_msg = "Work finished"
pooling_batch_tasks = [
r.cleanup() for r in self._autoclose_resources if isinstance(r, PoolingBatch)
]
activity_tasks = [r.destroy() for r in self._autoclose_resources if isinstance(r, Activity)]
agreement_tasks = [
r.terminate(agreement_msg)
Expand All @@ -170,6 +174,8 @@ async def _close_autoclose_resources(self) -> None:
r.release() for r in self._autoclose_resources if isinstance(r, Allocation)
]
network_tasks = [r.remove() for r in self._autoclose_resources if isinstance(r, Network)]
if pooling_batch_tasks:
await asyncio.gather(*pooling_batch_tasks)
if activity_tasks:
await asyncio.gather(*activity_tasks)
if agreement_tasks:
Expand Down Expand Up @@ -396,7 +402,8 @@ async def add_to_network(self, network: Network, ip: Optional[str] = None) -> No
await network.add_requestor_ip(ip)

def add_autoclose_resource(
self, resource: Union["Allocation", "Demand", "Agreement", "Activity", "Network"]
self,
resource: Union["Allocation", "Demand", "Agreement", "Activity", "Network", "PoolingBatch"],
) -> None:
self._autoclose_resources.add(resource)

Expand Down
3 changes: 2 additions & 1 deletion golem/resources/activity/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ async def execute(
batch_id = await self.api.call_exec(self.id, script, _request_timeout=timeout)
batch = PoolingBatch(self.node, batch_id)
batch.start_collecting_events()
self._node.add_autoclose_resource(batch)
self.add_child(batch)
self.running_batch_counter += 1
return batch
Expand Down Expand Up @@ -133,7 +134,7 @@ async def execute_after() -> None:
await batch.wait(ignore_errors=True)
await asyncio.gather(*[c.after() for c in commands])

asyncio.create_task(execute_after())
batch.execute_after_task = asyncio.create_task(execute_after())
return batch

async def execute_script(self, script: "Script") -> PoolingBatch:
Expand Down
6 changes: 6 additions & 0 deletions golem/resources/pooling_batch/pooling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
CommandCancelled,
CommandFailed,
)
from golem.utils.asyncio import ensure_cancelled
from golem.utils.low import ActivityApi, YagnaEventCollector

if TYPE_CHECKING:
Expand Down Expand Up @@ -41,6 +42,7 @@ def __init__(self, node: "GolemNode", id_: str):

self.finished_event = asyncio.Event()
self._futures: Optional[List[asyncio.Future[models.ExeScriptCommandResult]]] = None
self.execute_after_task: Optional[asyncio.Task] = None

@property
def done(self) -> bool:
Expand Down Expand Up @@ -89,6 +91,10 @@ async def wait(
assert timeout_seconds is not None # mypy
raise BatchTimeoutError(self, timeout_seconds)

async def cleanup(self):
if self.execute_after_task:
await ensure_cancelled(self.execute_after_task)

@property
def events(self) -> List[models.ExeScriptCommandResult]:
"""Returns a list of results for this batch.
Expand Down
10 changes: 7 additions & 3 deletions golem/utils/asyncio/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@ async def get(self) -> TQueueItem:

error_task = asyncio.create_task(self._error_event.wait())
get_task = asyncio.create_task(super().get())
done, pending = await asyncio.wait(
[error_task, get_task], return_when=asyncio.FIRST_COMPLETED
)
try:
done, pending = await asyncio.wait(
[error_task, get_task], return_when=asyncio.FIRST_COMPLETED
)
except asyncio.CancelledError:
await ensure_cancelled_many([error_task, get_task])
raise

await ensure_cancelled_many(pending)

Expand Down

0 comments on commit 91fc23f

Please sign in to comment.