Skip to content

Commit

Permalink
Improve consistency of sync_compatible when running flows in remote…
Browse files Browse the repository at this point in the history
… environments (#14660)
  • Loading branch information
desertaxle authored Jul 19, 2024
1 parent 15dee86 commit 96140de
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 7 deletions.
16 changes: 11 additions & 5 deletions src/prefect/utilities/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,17 @@ async def run_sync_in_worker_thread(
Note that cancellation of threads will not result in interrupted computation, the
thread may continue running — the outcome will just be ignored.
"""
call = partial(__fn, *args, **kwargs)
result = await anyio.to_thread.run_sync(
call_with_mark, call, abandon_on_cancel=True, limiter=get_thread_limiter()
)
return result
# When running a sync function in a worker thread, we set this flag so that
# any root sync compatible functions will run as sync functions
token = RUNNING_ASYNC_FLAG.set(False)
try:
call = partial(__fn, *args, **kwargs)
result = await anyio.to_thread.run_sync(
call_with_mark, call, abandon_on_cancel=True, limiter=get_thread_limiter()
)
return result
finally:
RUNNING_ASYNC_FLAG.reset(token)


def call_with_mark(call):
Expand Down
7 changes: 6 additions & 1 deletion tests/deployment/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def test_initialize_project_with_docker_recipe_default_image(self, recipe)
class TestDiscoverFlows:
async def test_find_all_flows_in_dir_tree(self, project_dir):
flows = await _search_for_flow_functions(str(project_dir))
assert len(flows) == 6, f"Expected 6 flows, found {len(flows)}"
assert len(flows) == 7, f"Expected 7 flows, found {len(flows)}"

expected_flows = [
{
Expand Down Expand Up @@ -191,6 +191,11 @@ async def test_find_all_flows_in_dir_tree(self, project_dir):
project_dir / "import-project" / "my_module" / "flow.py"
),
},
{
"flow_name": "uses_block",
"function_name": "uses_block",
"filepath": str(project_dir / "flows" / "uses_block.py"),
},
]

for flow in flows:
Expand Down
14 changes: 14 additions & 0 deletions tests/test-projects/flows/uses_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import uuid

from prefect import flow
from prefect.blocks.system import Secret

block_name = f"foo-{uuid.uuid4()}"
Secret(value="bar").save("foo")

my_secret = Secret.load("foo")


@flow
async def uses_block():
return my_secret.get()
34 changes: 33 additions & 1 deletion tests/test_flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pydantic
import pytest

from prefect import Flow, flow, task
from prefect import Flow, __development_base_path__, flow, task
from prefect._internal.compatibility.experimental import ExperimentalFeature
from prefect.client.orchestration import PrefectClient, SyncPrefectClient
from prefect.client.schemas.filters import FlowFilter, FlowRunFilter
Expand All @@ -37,6 +37,7 @@
from prefect.logging import get_run_logger
from prefect.server.schemas.core import FlowRun as ServerFlowRun
from prefect.utilities.callables import get_call_parameters
from prefect.utilities.filesystem import tmpchdir


@flow
Expand Down Expand Up @@ -1730,3 +1731,34 @@ def g(required: str, model: TheModel = {"x": [1, 2, 3]}): # type: ignore
yield i

assert [i for i in g("hello")] == ["hello", 1, 2, 3]


class TestLoadFlowAndFlowRun:
async def test_load_flow_from_script_with_module_level_sync_compatible_call(
self, prefect_client: PrefectClient, tmp_path
):
"""
This test ensures that when a worker or runner loads a flow from a script, and
that script contains a module-level call to a sync-compatible function, the sync
compatible function is correctly runs as sync and does not prevent the flow from
being loaded.
Regression test for https://github.com/PrefectHQ/prefect/issues/14625
"""
flow_id = await prefect_client.create_flow_from_name(flow_name="uses_block")
deployment_id = await prefect_client.create_deployment(
flow_id=flow_id,
name="test-load-flow-from-script-with-module-level-sync-compatible-call",
path=str(__development_base_path__ / "tests" / "test-projects" / "flows"),
entrypoint="uses_block.py:uses_block",
)
api_flow_run = await prefect_client.create_flow_run_from_deployment(
deployment_id=deployment_id
)

with tmpchdir(tmp_path):
flow_run, flow = load_flow_and_flow_run(api_flow_run.id)

assert flow_run.id == api_flow_run.id

assert await flow() == "bar"

0 comments on commit 96140de

Please sign in to comment.