From c39345b3421f4b6a320e279072fd4a12a4009072 Mon Sep 17 00:00:00 2001 From: zulissimeta <122578103+zulissimeta@users.noreply.github.com> Date: Thu, 19 Sep 2024 14:17:48 -0700 Subject: [PATCH] Small bugfix for async prefect functions (#2462) ## Summary of Changes This provides a small bug fix for the recent patch to how prefect flows are handled. The prior methods only worked for sync flows. ### Requirements - [x] My PR is focused on a [single feature addition or bugfix](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/getting-started/best-practices-for-pull-requests#write-small-prs). - [x] My PR has relevant, comprehensive [unit tests](https://quantum-accelerators.github.io/quacc/dev/contributing.html#unit-tests). - [x] My PR is on a [custom branch](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-and-deleting-branches-within-your-repository) (i.e. is _not_ named `main`). Note: If you are an external contributor, you will see a comment from [@buildbot-princeton](https://github.com/buildbot-princeton). This is solely for the maintainers. --------- Co-authored-by: Andrew S. Rosen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 1 + src/quacc/settings.py | 4 ++ src/quacc/wflow_tools/decorators.py | 54 +++++++++++++------- src/quacc/wflow_tools/prefect_utils.py | 71 ++++++++++++++++++++++++-- tests/prefect/test_prefect_utils.py | 55 ++++++++++++++++++++ tests/prefect/test_syntax.py | 38 ++++++++++++++ 6 files changed, 201 insertions(+), 22 deletions(-) create mode 100644 tests/prefect/test_prefect_utils.py diff --git a/pyproject.toml b/pyproject.toml index 6938df143f..7213bfe50c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ include = ["quacc"] exclude = ["**/__pycache__"] [tool.pytest.ini_options] +asyncio_mode = "auto" minversion = "6.0" addopts = ["-p no:warnings", "--import-mode=importlib"] xfail_strict = true diff --git a/src/quacc/settings.py b/src/quacc/settings.py index 327e44376c..8998edb8d1 100644 --- a/src/quacc/settings.py +++ b/src/quacc/settings.py @@ -138,6 +138,10 @@ class QuaccSettings(BaseSettings): PREFECT_AUTO_SUBMIT: bool = Field( True, description="Whether to auto-submit tasks to the task runner." ) + PREFECT_RESOLVE_FLOW_RESULTS: bool = Field( + True, + description="Whether to resolve all futures in flow results to data and fail if not possible", + ) # --------------------------- # ORCA Settings diff --git a/src/quacc/wflow_tools/decorators.py b/src/quacc/wflow_tools/decorators.py index 1290a34977..864d73d8a9 100644 --- a/src/quacc/wflow_tools/decorators.py +++ b/src/quacc/wflow_tools/decorators.py @@ -348,15 +348,7 @@ def workflow(a, b, c): return task(_func, namespace=_func.__module__, **kwargs) elif settings.WORKFLOW_ENGINE == "prefect": - from prefect import flow as prefect_flow - - from quacc.wflow_tools.prefect_utils import resolve_futures_to_results - - @wraps(_func) - def wrapper(*f_args, **f_kwargs): - return resolve_futures_to_results(_func(*f_args, **f_kwargs)) - - return prefect_flow(wrapper, validate_parameters=False, **kwargs) + return _get_prefect_wrapped_flow(_func, settings, **kwargs) else: return _func @@ -585,15 +577,7 @@ def wrapper(*f_args, **f_kwargs): return join_app(wrapped_fn, **kwargs) elif settings.WORKFLOW_ENGINE == "prefect": - from prefect import flow as prefect_flow - - from quacc.wflow_tools.prefect_utils import resolve_futures_to_results - - @wraps(_func) - def wrapper(*f_args, **f_kwargs): - return resolve_futures_to_results(_func(*f_args, **f_kwargs)) - - return prefect_flow(wrapper, validate_parameters=False, **kwargs) + return _get_prefect_wrapped_flow(_func, settings, **kwargs) elif settings.WORKFLOW_ENGINE == "redun": from redun import task @@ -643,6 +627,40 @@ def wrapper( return wrapper +def _get_prefect_wrapped_flow(_func, settings, **kwargs): + from prefect import flow as prefect_flow + from prefect.utilities.asyncutils import is_async_fn + + from quacc.wflow_tools.prefect_utils import ( + resolve_futures_to_results, + resolve_futures_to_results_async, + ) + + if is_async_fn(_func): + if settings.PREFECT_RESOLVE_FLOW_RESULTS: + + @wraps(_func) + async def async_wrapper(*f_args, **f_kwargs): + result = await _func(*f_args, **f_kwargs) + return await resolve_futures_to_results_async(result) + + return prefect_flow(async_wrapper, validate_parameters=False, **kwargs) + + else: + return prefect_flow(_func, validate_parameters=False, **kwargs) + else: + if settings.PREFECT_RESOLVE_FLOW_RESULTS: + + @wraps(_func) + def sync_wrapper(*f_args, **f_kwargs): + result = _func(*f_args, **f_kwargs) + return resolve_futures_to_results(result) + + return prefect_flow(sync_wrapper, validate_parameters=False, **kwargs) + else: + return prefect_flow(_func, validate_parameters=False, **kwargs) + + class Delayed_: """A small Dask-compatible, serializable object to wrap delayed functions that we don't want to execute. diff --git a/src/quacc/wflow_tools/prefect_utils.py b/src/quacc/wflow_tools/prefect_utils.py index 6fbf827222..a632267d1c 100644 --- a/src/quacc/wflow_tools/prefect_utils.py +++ b/src/quacc/wflow_tools/prefect_utils.py @@ -51,10 +51,73 @@ def _collect_futures(futures, expr, context): results = [] for future in futures: future.wait() - result = future.state.result() - if isinstance(result, BaseResult): - result = result.get() - results.append(result) + if future.state.is_completed(): + result = future.state.result() + if isinstance(result, BaseResult): + result = result.get() + results.append(result) + else: + raise BaseException("At least one result did not complete successfully") + + states_by_future = dict(zip(futures, results, strict=False)) + + def replace_futures_with_states(expr, context): + # Expressions inside quotes should not be modified + if isinstance(context.get("annotation"), quote): + raise StopVisiting + + if isinstance(expr, PrefectFuture): + return states_by_future[expr] + else: + return expr + + return visit_collection( + expr, visit_fn=replace_futures_with_states, return_data=True, context={} + ) + + +async def resolve_futures_to_results_async(expr: PrefectFuture | Any) -> State | Any: + """ + Given a Python built-in collection, recursively find `PrefectFutures` and build a + new collection with the same structure with futures resolved to their final result. + Resolving futures to their final result may wait for execution to complete. + + Unsupported object types will be returned without modification. + + This function is a trivial change from resolve_futures_to_states here: + https://github.com/PrefectHQ/prefect/blob/main/src/prefect/futures.py + """ + futures: set[PrefectFuture] = set() + + def _collect_futures(futures, expr, context): + # Expressions inside quotes should not be traversed + if isinstance(context.get("annotation"), quote): + raise StopVisiting + + if isinstance(expr, PrefectFuture): + futures.add(expr) + + return expr + + visit_collection( + expr, visit_fn=partial(_collect_futures, futures), return_data=False, context={} + ) + + # if no futures were found, return the original expression + if not futures: + return expr + + # Get final states for each future + results = [] + for future in futures: + future.wait() + if future.state.is_completed(): + result = future.state.result() + if isinstance(result, BaseResult): + result = await result.get() + results.append(result) + else: + raise BaseException("At least one result did not complete successfully") states_by_future = dict(zip(futures, results, strict=False)) diff --git a/tests/prefect/test_prefect_utils.py b/tests/prefect/test_prefect_utils.py new file mode 100644 index 0000000000..bab5b387c2 --- /dev/null +++ b/tests/prefect/test_prefect_utils.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import pytest +from prefect import flow, task + +from quacc.wflow_tools.prefect_utils import ( + resolve_futures_to_results, + resolve_futures_to_results_async, +) + + +def test_resolve_futures_to_results(): + @task + def test_task(): + return {"test": 5} + + @flow + def test_flow(): + future = test_task.submit() + nested_future = {"nest": future} + return resolve_futures_to_results(nested_future) + + result = test_flow() + assert result["nest"]["test"] == 5 + + +def test_resolve_futures_to_results_taskfail(): + @task + def test_task(): + raise Exception + return {"test": 5} + + @flow + def test_flow(): + future = test_task.submit() + nested_future = {"nest": future} + return resolve_futures_to_results(nested_future) + + with pytest.raises(BaseException): # noqa: B017, PT011 + test_flow() + + +async def test_resolve_futures_to_results_async(): + @task + async def test_task(): + return {"test": 5} + + @flow + async def test_flow(): + future = test_task.submit() + nested_future = {"nest": future} + return await resolve_futures_to_results_async(nested_future) + + result = await test_flow() + assert result["nest"]["test"] == 5 diff --git a/tests/prefect/test_syntax.py b/tests/prefect/test_syntax.py index 90b6de477b..a7909feff2 100644 --- a/tests/prefect/test_syntax.py +++ b/tests/prefect/test_syntax.py @@ -260,3 +260,41 @@ def myflow(): t1 = time.time() dt = t1 - t0 assert dt < 10 + + +async def test_prefect_decorators_async(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + + @job + async def add(a, b): + return a + b + + @flow + async def add_flow(a, b): + return add(a, b) + + assert (await add_flow(1, 2)) == 3 + + +async def test_prefect_decorators3_async(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + + @job + async def add(a, b): + return a + b + + @job + async def make_more(val): + return [val] * 3 + + @subflow + async def add_distributed(vals, c): + return [add(val, c) for val in vals] + + @flow + async def dynamic_workflow(a, b, c): + result1 = add(a, b) + result2 = make_more(result1) + return await add_distributed(result2, c) + + assert (await dynamic_workflow(1, 2, 3)) == [6, 6, 6]