Skip to content

Commit

Permalink
Enable tasks to be submitted and mapped from tasks submitted to Dask …
Browse files Browse the repository at this point in the history
…or Ray (#14829)
  • Loading branch information
desertaxle authored Aug 6, 2024
1 parent 2f4f681 commit 853a01e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ def hydrated_context(
client = client or get_client(sync_client=True)
if flow_run_context := serialized_context.get("flow_run_context"):
flow = flow_run_context["flow"]
task_runner = stack.enter_context(flow.task_runner.duplicate())
flow_run_context = FlowRunContext(
**flow_run_context,
client=client,
result_factory=run_coro_as_sync(ResultFactory.from_flow(flow)),
task_runner=flow.task_runner.duplicate(),
task_runner=task_runner,
detached=True,
)
stack.enter_context(flow_run_context)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import prefect.settings
from prefect import flow, task
from prefect.client.orchestration import PrefectClient
from prefect.context import (
GLOBAL_SETTINGS_CONTEXT,
ContextModel,
Expand Down Expand Up @@ -36,6 +37,7 @@
save_profiles,
temporary_settings,
)
from prefect.states import Running
from prefect.task_runners import ThreadPoolTaskRunner


Expand Down Expand Up @@ -537,6 +539,48 @@ def foo():
assert isinstance(hydrated_flow_run_context.start_time, DateTime)
assert hydrated_flow_run_context.parameters == {"x": "y"}

async def test_task_runner_started_when_hydrating_context(
self, prefect_client: PrefectClient
):
"""
This test ensures the task runner for a flow run context is started when
the context is hydrated. This enables calling .submit and .map on tasks
running in remote environments like Dask and Ray.
Regression test for https://github.com/PrefectHQ/prefect/issues/14788
"""

@flow
def foo():
pass

@task
def bar():
return 42

test_task_runner = ThreadPoolTaskRunner()
flow_run = await prefect_client.create_flow_run(foo, state=Running())
result_factory = await ResultFactory.from_flow(foo)
flow_run_context = FlowRunContext(
flow=foo,
flow_run=flow_run,
client=prefect_client,
task_runner=test_task_runner,
result_factory=result_factory,
parameters={"x": "y"},
)

with hydrated_context(
{
"flow_run_context": flow_run_context.serialize(),
}
):
hydrated_flow_run_context = FlowRunContext.get()
assert hydrated_flow_run_context

future = hydrated_flow_run_context.task_runner.submit(bar, parameters={})
assert future.result() == 42

async def test_with_task_run_context(self, prefect_client, flow_run):
@task
def bar():
Expand Down

0 comments on commit 853a01e

Please sign in to comment.