From 853a01e21147ca6da4de1aaabbf72b7914f08855 Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Tue, 6 Aug 2024 10:39:06 -0500 Subject: [PATCH] Enable tasks to be submitted and mapped from tasks submitted to Dask or Ray (#14829) --- src/prefect/context.py | 3 ++- tests/test_context.py | 44 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/prefect/context.py b/src/prefect/context.py index e2a8e99ca2d2..66664c6b5d79 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -91,11 +91,12 @@ def hydrated_context( client = client or get_client(sync_client=True) if flow_run_context := serialized_context.get("flow_run_context"): flow = flow_run_context["flow"] + task_runner = stack.enter_context(flow.task_runner.duplicate()) flow_run_context = FlowRunContext( **flow_run_context, client=client, result_factory=run_coro_as_sync(ResultFactory.from_flow(flow)), - task_runner=flow.task_runner.duplicate(), + task_runner=task_runner, detached=True, ) stack.enter_context(flow_run_context) diff --git a/tests/test_context.py b/tests/test_context.py index da88fc170ea7..a10e3d83dda7 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -7,6 +7,7 @@ import prefect.settings from prefect import flow, task +from prefect.client.orchestration import PrefectClient from prefect.context import ( GLOBAL_SETTINGS_CONTEXT, ContextModel, @@ -36,6 +37,7 @@ save_profiles, temporary_settings, ) +from prefect.states import Running from prefect.task_runners import ThreadPoolTaskRunner @@ -537,6 +539,48 @@ def foo(): assert isinstance(hydrated_flow_run_context.start_time, DateTime) assert hydrated_flow_run_context.parameters == {"x": "y"} + async def test_task_runner_started_when_hydrating_context( + self, prefect_client: PrefectClient + ): + """ + This test ensures the task runner for a flow run context is started when + the context is hydrated. This enables calling .submit and .map on tasks + running in remote environments like Dask and Ray. + + Regression test for https://github.com/PrefectHQ/prefect/issues/14788 + """ + + @flow + def foo(): + pass + + @task + def bar(): + return 42 + + test_task_runner = ThreadPoolTaskRunner() + flow_run = await prefect_client.create_flow_run(foo, state=Running()) + result_factory = await ResultFactory.from_flow(foo) + flow_run_context = FlowRunContext( + flow=foo, + flow_run=flow_run, + client=prefect_client, + task_runner=test_task_runner, + result_factory=result_factory, + parameters={"x": "y"}, + ) + + with hydrated_context( + { + "flow_run_context": flow_run_context.serialize(), + } + ): + hydrated_flow_run_context = FlowRunContext.get() + assert hydrated_flow_run_context + + future = hydrated_flow_run_context.task_runner.submit(bar, parameters={}) + assert future.result() == 42 + async def test_with_task_run_context(self, prefect_client, flow_run): @task def bar():