Skip to content

Commit

Permalink
Small bugfix for async prefect functions (#2462)
Browse files Browse the repository at this point in the history
## Summary of Changes

This provides a small bug fix for the recent patch to how prefect flows
are handled. The prior methods only worked for sync flows.

### Requirements

- [x] My PR is focused on a [single feature addition or
bugfix](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/getting-started/best-practices-for-pull-requests#write-small-prs).
- [x] My PR has relevant, comprehensive [unit
tests](https://quantum-accelerators.github.io/quacc/dev/contributing.html#unit-tests).
- [x] My PR is on a [custom
branch](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-and-deleting-branches-within-your-repository)
(i.e. is _not_ named `main`).

Note: If you are an external contributor, you will see a comment from
[@buildbot-princeton](https://github.com/buildbot-princeton). This is
solely for the maintainers.

---------

Co-authored-by: Andrew S. Rosen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 19, 2024
1 parent 85615ef commit c39345b
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 22 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ include = ["quacc"]
exclude = ["**/__pycache__"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
minversion = "6.0"
addopts = ["-p no:warnings", "--import-mode=importlib"]
xfail_strict = true
Expand Down
4 changes: 4 additions & 0 deletions src/quacc/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ class QuaccSettings(BaseSettings):
PREFECT_AUTO_SUBMIT: bool = Field(
True, description="Whether to auto-submit tasks to the task runner."
)
PREFECT_RESOLVE_FLOW_RESULTS: bool = Field(
True,
description="Whether to resolve all futures in flow results to data and fail if not possible",
)

# ---------------------------
# ORCA Settings
Expand Down
54 changes: 36 additions & 18 deletions src/quacc/wflow_tools/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,7 @@ def workflow(a, b, c):

return task(_func, namespace=_func.__module__, **kwargs)
elif settings.WORKFLOW_ENGINE == "prefect":
from prefect import flow as prefect_flow

from quacc.wflow_tools.prefect_utils import resolve_futures_to_results

@wraps(_func)
def wrapper(*f_args, **f_kwargs):
return resolve_futures_to_results(_func(*f_args, **f_kwargs))

return prefect_flow(wrapper, validate_parameters=False, **kwargs)
return _get_prefect_wrapped_flow(_func, settings, **kwargs)
else:
return _func

Expand Down Expand Up @@ -585,15 +577,7 @@ def wrapper(*f_args, **f_kwargs):

return join_app(wrapped_fn, **kwargs)
elif settings.WORKFLOW_ENGINE == "prefect":
from prefect import flow as prefect_flow

from quacc.wflow_tools.prefect_utils import resolve_futures_to_results

@wraps(_func)
def wrapper(*f_args, **f_kwargs):
return resolve_futures_to_results(_func(*f_args, **f_kwargs))

return prefect_flow(wrapper, validate_parameters=False, **kwargs)
return _get_prefect_wrapped_flow(_func, settings, **kwargs)
elif settings.WORKFLOW_ENGINE == "redun":
from redun import task

Expand Down Expand Up @@ -643,6 +627,40 @@ def wrapper(
return wrapper


def _get_prefect_wrapped_flow(_func, settings, **kwargs):
from prefect import flow as prefect_flow
from prefect.utilities.asyncutils import is_async_fn

from quacc.wflow_tools.prefect_utils import (
resolve_futures_to_results,
resolve_futures_to_results_async,
)

if is_async_fn(_func):
if settings.PREFECT_RESOLVE_FLOW_RESULTS:

@wraps(_func)
async def async_wrapper(*f_args, **f_kwargs):
result = await _func(*f_args, **f_kwargs)
return await resolve_futures_to_results_async(result)

return prefect_flow(async_wrapper, validate_parameters=False, **kwargs)

else:
return prefect_flow(_func, validate_parameters=False, **kwargs)
else:
if settings.PREFECT_RESOLVE_FLOW_RESULTS:

@wraps(_func)
def sync_wrapper(*f_args, **f_kwargs):
result = _func(*f_args, **f_kwargs)
return resolve_futures_to_results(result)

return prefect_flow(sync_wrapper, validate_parameters=False, **kwargs)
else:
return prefect_flow(_func, validate_parameters=False, **kwargs)


class Delayed_:
"""A small Dask-compatible, serializable object to wrap delayed functions that we
don't want to execute.
Expand Down
71 changes: 67 additions & 4 deletions src/quacc/wflow_tools/prefect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,73 @@ def _collect_futures(futures, expr, context):
results = []
for future in futures:
future.wait()
result = future.state.result()
if isinstance(result, BaseResult):
result = result.get()
results.append(result)
if future.state.is_completed():
result = future.state.result()
if isinstance(result, BaseResult):
result = result.get()
results.append(result)
else:
raise BaseException("At least one result did not complete successfully")

states_by_future = dict(zip(futures, results, strict=False))

def replace_futures_with_states(expr, context):
# Expressions inside quotes should not be modified
if isinstance(context.get("annotation"), quote):
raise StopVisiting

if isinstance(expr, PrefectFuture):
return states_by_future[expr]
else:
return expr

return visit_collection(
expr, visit_fn=replace_futures_with_states, return_data=True, context={}
)


async def resolve_futures_to_results_async(expr: PrefectFuture | Any) -> State | Any:
"""
Given a Python built-in collection, recursively find `PrefectFutures` and build a
new collection with the same structure with futures resolved to their final result.
Resolving futures to their final result may wait for execution to complete.
Unsupported object types will be returned without modification.
This function is a trivial change from resolve_futures_to_states here:
https://github.com/PrefectHQ/prefect/blob/main/src/prefect/futures.py
"""
futures: set[PrefectFuture] = set()

def _collect_futures(futures, expr, context):
# Expressions inside quotes should not be traversed
if isinstance(context.get("annotation"), quote):
raise StopVisiting

if isinstance(expr, PrefectFuture):
futures.add(expr)

return expr

visit_collection(
expr, visit_fn=partial(_collect_futures, futures), return_data=False, context={}
)

# if no futures were found, return the original expression
if not futures:
return expr

# Get final states for each future
results = []
for future in futures:
future.wait()
if future.state.is_completed():
result = future.state.result()
if isinstance(result, BaseResult):
result = await result.get()
results.append(result)
else:
raise BaseException("At least one result did not complete successfully")

states_by_future = dict(zip(futures, results, strict=False))

Expand Down
55 changes: 55 additions & 0 deletions tests/prefect/test_prefect_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

import pytest
from prefect import flow, task

from quacc.wflow_tools.prefect_utils import (
resolve_futures_to_results,
resolve_futures_to_results_async,
)


def test_resolve_futures_to_results():
@task
def test_task():
return {"test": 5}

@flow
def test_flow():
future = test_task.submit()
nested_future = {"nest": future}
return resolve_futures_to_results(nested_future)

result = test_flow()
assert result["nest"]["test"] == 5


def test_resolve_futures_to_results_taskfail():
@task
def test_task():
raise Exception
return {"test": 5}

@flow
def test_flow():
future = test_task.submit()
nested_future = {"nest": future}
return resolve_futures_to_results(nested_future)

with pytest.raises(BaseException): # noqa: B017, PT011
test_flow()


async def test_resolve_futures_to_results_async():
@task
async def test_task():
return {"test": 5}

@flow
async def test_flow():
future = test_task.submit()
nested_future = {"nest": future}
return await resolve_futures_to_results_async(nested_future)

result = await test_flow()
assert result["nest"]["test"] == 5
38 changes: 38 additions & 0 deletions tests/prefect/test_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,41 @@ def myflow():
t1 = time.time()
dt = t1 - t0
assert dt < 10


async def test_prefect_decorators_async(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)

@job
async def add(a, b):
return a + b

@flow
async def add_flow(a, b):
return add(a, b)

assert (await add_flow(1, 2)) == 3


async def test_prefect_decorators3_async(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)

@job
async def add(a, b):
return a + b

@job
async def make_more(val):
return [val] * 3

@subflow
async def add_distributed(vals, c):
return [add(val, c) for val in vals]

@flow
async def dynamic_workflow(a, b, c):
result1 = add(a, b)
result2 = make_more(result1)
return await add_distributed(result2, c)

assert (await dynamic_workflow(1, 2, 3)) == [6, 6, 6]

0 comments on commit c39345b

Please sign in to comment.