diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index bb2c0523..bc64fe92 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -429,15 +429,20 @@ def activate( f"Failed converting activation exception: {inner_err}" ) - def is_completion(command): + def is_non_cancellation_completion(command): return ( command.HasField("complete_workflow_execution") or command.HasField("continue_as_new_workflow_execution") or command.HasField("fail_workflow_execution") - or command.HasField("cancel_workflow_execution") ) - if any(map(is_completion, self._current_completion.successful.commands)): + # We do also warn in the case of workflow cancellation, but this is done + # when handling the workflow cancellation, since we also cancel update + # handlers at that time. + if any( + is_non_cancellation_completion(c) + for c in self._current_completion.successful.commands + ): self._warn_if_unfinished_handlers() return self._current_completion @@ -1851,6 +1856,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None: err ): self._add_command().cancel_workflow_execution.SetInParent() + self._warn_if_unfinished_handlers() # Cancel update tasks, so that the update caller receives an # update failed error. We do not currently cancel signal tasks # since (a) doing so would require a workflow flag and (b) the diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 052a28bf..4030e03f 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -5584,9 +5584,16 @@ class _UnfinishedHandlersOnWorkflowTerminationTest: async def test_warning_is_issued_on_exit_with_unfinished_handler( self, ): - assert await self._run_workflow_and_get_warning() == ( - self.handler_waiting == "-no-wait-all-handlers-finish-" - ) + warning_emitted = await self._run_workflow_and_get_warning() + if self.workflow_termination_type == "-cancellation-": + # All paths through this test for which the workflow is cencalled result + # in the warning being emitted. + assert warning_emitted + else: + # Otherwise, the warning is emitted iff the workflow does not wait for handlers to finish. + assert warning_emitted == ( + self.handler_waiting == "-no-wait-all-handlers-finish-" + ) async def _run_workflow_and_get_warning(self) -> bool: workflow_id = f"wf-{uuid.uuid4()}"