diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 405c7fef..c4f561e5 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -5360,9 +5360,11 @@ def _unfinished_handler_warning_cls(self) -> Type: @workflow.defn -class UnfinishedHandlersWithCancellationWorkflow: +class UnfinishedHandlersWithCancellationOrFailureWorkflow: @workflow.run - async def run(self) -> NoReturn: + async def run(self, workflow_termination_type: Literal["cancellation", "failure"]) -> NoReturn: + if workflow_termination_type == "failure": + raise ApplicationError("Deliberately failing workflow with an unfinished handler") await workflow.wait_condition(lambda: False) @workflow.update @@ -5376,23 +5378,36 @@ async def my_signal(self): async def test_unfinished_update_handler_with_workflow_cancellation(client: Client): - await _UnfinishedHandlersWithCancellationTest( - client, "update" - ).test_warning_is_issued_when_cancellation_causes_exit_with_unfinished_handler() + await _UnfinishedHandlersWithCancellationOrFailureTest( + client, "update", "cancellation", + ).test_warning_is_issued_when_cancellation_or_failure_causes_exit_with_unfinished_handler() async def test_unfinished_signal_handler_with_workflow_cancellation(client: Client): - await _UnfinishedHandlersWithCancellationTest( - client, "signal" - ).test_warning_is_issued_when_cancellation_causes_exit_with_unfinished_handler() + await _UnfinishedHandlersWithCancellationOrFailureTest( + client, "signal", "cancellation", + ).test_warning_is_issued_when_cancellation_or_failure_causes_exit_with_unfinished_handler() + + +async def test_unfinished_update_handler_with_workflow_failure(client: Client): + await _UnfinishedHandlersWithCancellationOrFailureTest( + client, "update", "failure", + ).test_warning_is_issued_when_cancellation_or_failure_causes_exit_with_unfinished_handler() + + +async def test_unfinished_signal_handler_with_workflow_failure(client: Client): + await _UnfinishedHandlersWithCancellationOrFailureTest( + client, "signal", "failure", + ).test_warning_is_issued_when_cancellation_or_failure_causes_exit_with_unfinished_handler() @dataclass -class _UnfinishedHandlersWithCancellationTest: +class _UnfinishedHandlersWithCancellationOrFailureTest: client: Client handler_type: Literal["update", "signal"] + workflow_termination_type: Literal["cancellation", "failure"] - async def test_warning_is_issued_when_cancellation_causes_exit_with_unfinished_handler( + async def test_warning_is_issued_when_cancellation_or_failure_causes_exit_with_unfinished_handler( self, ): assert await self._run_workflow_and_get_warning() @@ -5402,31 +5417,33 @@ async def _run_workflow_and_get_warning(self) -> bool: update_id = "update-id" task_queue = "tq" - # We require a cancellation request and an update to be delivered in the same WFT. To do - # this we send the start, cancel, and update/signal requests, and then start the worker - # after they've all been accepted by the server. + # We require a startWorkflow, an update, and maybe a cancellation request, to be delivered + # in the same WFT. To do this we start the worker after they've all been accepted by the + # server. handle = await self.client.start_workflow( - UnfinishedHandlersWithCancellationWorkflow.run, + UnfinishedHandlersWithCancellationOrFailureWorkflow.run, + self.workflow_termination_type, id=workflow_id, task_queue=task_queue, ) - await handle.cancel() + if self.workflow_termination_type == "cancellation": + await handle.cancel() if self.handler_type == "update": update_task = asyncio.create_task( handle.execute_update( - UnfinishedHandlersWithCancellationWorkflow.my_update, id=update_id + UnfinishedHandlersWithCancellationOrFailureWorkflow.my_update, id=update_id ) ) await assert_eq_eventually( True, lambda: workflow_update_exists(self.client, workflow_id, update_id) ) else: - await handle.signal(UnfinishedHandlersWithCancellationWorkflow.my_signal) + await handle.signal(UnfinishedHandlersWithCancellationOrFailureWorkflow.my_signal) async with new_worker( self.client, - UnfinishedHandlersWithCancellationWorkflow, + UnfinishedHandlersWithCancellationOrFailureWorkflow, task_queue=task_queue, ): with pytest.WarningsRecorder() as warnings: