diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 2836939e3b2c..515d32efa657 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -301,39 +301,36 @@ def set_state(self, state: State, force: bool = False) -> State: last_state = self.state if not self.task_run: raise ValueError("Task run is not set") - try: - if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: - new_state = state - # Copy over state_details from state to state - new_state.state_details.task_run_id = ( - last_state.state_details.task_run_id - ) - new_state.state_details.flow_run_id = ( - last_state.state_details.flow_run_id - ) - else: - new_state = propose_state_sync( - self.client, state, task_run_id=self.task_run.id, force=force - ) - except Pause as exc: - # We shouldn't get a pause signal without a state, but if this happens, - # just use a Paused state to assume an in-process pause. - new_state = exc.state if exc.state else Paused() - if new_state.state_details.pause_reschedule: - # If we're being asked to pause and reschedule, we should exit the - # task and expect to be resumed later. - raise - - # currently this is a hack to keep a reference to the state object - # that has an in-memory result attached to it; using the API state - # could result in losing that reference - self.task_run.state = new_state - # Predictively update the de-normalized task_run.state_* attributes client-side if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION: + self.task_run.state = new_state = state + + # Ensure that the state_details are populated with the current run IDs + new_state.state_details.task_run_id = self.task_run.id + new_state.state_details.flow_run_id = self.task_run.flow_run_id + + # Predictively update the de-normalized task_run.state_* attributes self.task_run.state_id = new_state.id self.task_run.state_type = new_state.type self.task_run.state_name = new_state.name + else: + try: + new_state = propose_state_sync( + self.client, state, task_run_id=self.task_run.id, force=force + ) + except Pause as exc: + # We shouldn't get a pause signal without a state, but if this happens, + # just use a Paused state to assume an in-process pause. + new_state = exc.state if exc.state else Paused() + if new_state.state_details.pause_reschedule: + # If we're being asked to pause and reschedule, we should exit the + # task and expect to be resumed later. + raise + + # currently this is a hack to keep a reference to the state object + # that has an in-memory result attached to it; using the API state + # could result in losing that reference + self.task_run.state = new_state # emit a state change event self._last_event = emit_task_run_state_change_event( diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 23329c44e847..9e0bb1f54052 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -1913,3 +1913,83 @@ def foo(): assert task_run.state_id == task_run.state.id assert task_run.state_type == task_run.state.type == StateType.FAILED assert task_run.state_name == task_run.state.name == "Failed" + + async def test_state_details_have_denormalized_task_run_id_async(self): + proof_that_i_ran = uuid4() + + @task + async def foo(): + task_run = TaskRunContext.get().task_run + + assert task_run + assert task_run.state + assert task_run.state.state_details + + assert task_run.state.state_details.flow_run_id is None + assert task_run.state.state_details.task_run_id == task_run.id + + return proof_that_i_ran + + assert await run_task_async(foo) == proof_that_i_ran + + async def test_state_details_have_denormalized_flow_run_id_async(self): + proof_that_i_ran = uuid4() + + @flow + async def the_flow(): + return foo() + + @task + async def foo(): + task_run = TaskRunContext.get().task_run + + assert task_run + assert task_run.state + assert task_run.state.state_details + + assert task_run.state.state_details.flow_run_id == task_run.flow_run_id + assert task_run.state.state_details.task_run_id == task_run.id + + return proof_that_i_ran + + assert await the_flow() == proof_that_i_ran + + def test_state_details_have_denormalized_task_run_id_sync(self): + proof_that_i_ran = uuid4() + + @task + def foo(): + task_run = TaskRunContext.get().task_run + + assert task_run + assert task_run.state + assert task_run.state.state_details + + assert task_run.state.state_details.flow_run_id is None + assert task_run.state.state_details.task_run_id == task_run.id + + return proof_that_i_ran + + assert run_task_sync(foo) == proof_that_i_ran + + def test_state_details_have_denormalized_flow_run_id_sync(self): + proof_that_i_ran = uuid4() + + @flow + def the_flow(): + return foo() + + @task + def foo(): + task_run = TaskRunContext.get().task_run + + assert task_run + assert task_run.state + assert task_run.state.state_details + + assert task_run.state.state_details.flow_run_id == task_run.flow_run_id + assert task_run.state.state_details.task_run_id == task_run.id + + return proof_that_i_ran + + assert the_flow() == proof_that_i_ran