Skip to content

Commit

Permalink
[2.x] ensure on_failure hook runs upon parameter validation error (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Aug 29, 2024
1 parent 137ff56 commit 3b951c3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,18 @@ def enter_flow_run_engine_from_subprocess(flow_run_id: UUID) -> State:
return state


async def _make_flow_run(
flow: Flow, parameters: Dict[str, Any], state: State, client: PrefectClient
) -> FlowRun:
return await client.create_flow_run(
flow,
# Send serialized parameters to the backend
parameters=flow.serialize_parameters(parameters),
state=state,
tags=TagsContext.get().current_tags,
)


@inject_client
async def create_then_begin_flow_run(
flow: Flow,
Expand All @@ -351,6 +363,7 @@ async def create_then_begin_flow_run(

await check_api_reachable(client, "Cannot create flow run")

flow_run = None
state = Pending()
if flow.should_validate_parameters:
try:
Expand All @@ -359,14 +372,10 @@ async def create_then_begin_flow_run(
state = await exception_to_failed_state(
message="Validation of flow parameters failed with error:"
)
flow_run = await _make_flow_run(flow, parameters, state, client)
await _run_flow_hooks(flow, flow_run, state)

flow_run = await client.create_flow_run(
flow,
# Send serialized parameters to the backend
parameters=flow.serialize_parameters(parameters),
state=state,
tags=TagsContext.get().current_tags,
)
flow_run = flow_run or await _make_flow_run(flow, parameters, state, client)

engine_logger.info(f"Created flow run {flow_run.name!r} for flow {flow.name!r}")

Expand Down
14 changes: 14 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2830,6 +2830,20 @@ def my_flow():
assert state.type == StateType.FAILED
assert my_mock.call_args_list == [call(), call()]

def test_on_failure_hooks_run_on_bad_parameters(self):
my_mock = MagicMock()

def failure_hook(flow, flow_run, state):
my_mock("failure_hook")

@flow(on_failure=[failure_hook])
def my_flow(x: int):
pass

state = my_flow(x="x", return_state=True)
assert state.type == StateType.FAILED
assert my_mock.call_args_list == [call("failure_hook")]


class TestFlowHooksOnCancellation:
def test_noniterable_hook_raises(self):
Expand Down

0 comments on commit 3b951c3

Please sign in to comment.