Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predictively updates the run IDs on State.state_details client-side #14679

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,39 +301,36 @@ def set_state(self, state: State, force: bool = False) -> State:
last_state = self.state
if not self.task_run:
raise ValueError("Task run is not set")
try:
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
new_state = state
# Copy over state_details from state to state
new_state.state_details.task_run_id = (
last_state.state_details.task_run_id
)
new_state.state_details.flow_run_id = (
last_state.state_details.flow_run_id
)
else:
new_state = propose_state_sync(
self.client, state, task_run_id=self.task_run.id, force=force
)
except Pause as exc:
# We shouldn't get a pause signal without a state, but if this happens,
# just use a Paused state to assume an in-process pause.
new_state = exc.state if exc.state else Paused()
if new_state.state_details.pause_reschedule:
# If we're being asked to pause and reschedule, we should exit the
# task and expect to be resumed later.
raise

# currently this is a hack to keep a reference to the state object
# that has an in-memory result attached to it; using the API state
# could result in losing that reference
self.task_run.state = new_state

# Predictively update the de-normalized task_run.state_* attributes client-side
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
self.task_run.state = new_state = state

# Ensure that the state_details are populated with the current run IDs
new_state.state_details.task_run_id = self.task_run.id
new_state.state_details.flow_run_id = self.task_run.flow_run_id

# Predictively update the de-normalized task_run.state_* attributes
self.task_run.state_id = new_state.id
self.task_run.state_type = new_state.type
self.task_run.state_name = new_state.name
else:
try:
new_state = propose_state_sync(
self.client, state, task_run_id=self.task_run.id, force=force
)
except Pause as exc:
# We shouldn't get a pause signal without a state, but if this happens,
# just use a Paused state to assume an in-process pause.
new_state = exc.state if exc.state else Paused()
if new_state.state_details.pause_reschedule:
# If we're being asked to pause and reschedule, we should exit the
# task and expect to be resumed later.
raise

# currently this is a hack to keep a reference to the state object
# that has an in-memory result attached to it; using the API state
# could result in losing that reference
self.task_run.state = new_state

# emit a state change event
self._last_event = emit_task_run_state_change_event(
Expand Down
80 changes: 80 additions & 0 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,3 +1913,83 @@ def foo():
assert task_run.state_id == task_run.state.id
assert task_run.state_type == task_run.state.type == StateType.FAILED
assert task_run.state_name == task_run.state.name == "Failed"

async def test_state_details_have_denormalized_task_run_id_async(self):
    """Check the run IDs on state_details for an async task run directly.

    The task is executed through ``run_task_async`` and inspects its own
    task run from inside the task body.
    """
    sentinel = uuid4()

    @task
    async def foo():
        current_run = TaskRunContext.get().task_run

        # Each level must be present before the IDs can be compared.
        assert current_run
        assert current_run.state
        details = current_run.state.state_details
        assert details

        # This test expects no flow run ID, only the task run's own ID.
        assert details.flow_run_id is None
        assert details.task_run_id == current_run.id

        return sentinel

    # The returned sentinel shows the task body (and its asserts) executed.
    outcome = await run_task_async(foo)
    assert outcome == sentinel

async def test_state_details_have_denormalized_flow_run_id_async(self):
    """Check the run IDs on state_details for an async task called in a flow.

    The task inspects its own task run from inside the task body and compares
    the IDs on ``state_details`` with the IDs on the task run itself.
    """
    sentinel = uuid4()

    @task
    async def foo():
        current_run = TaskRunContext.get().task_run

        # Each level must be present before the IDs can be compared.
        assert current_run
        assert current_run.state
        details = current_run.state.state_details
        assert details

        # Both IDs on state_details must mirror the task run's own fields.
        assert details.flow_run_id == current_run.flow_run_id
        assert details.task_run_id == current_run.id

        return sentinel

    @flow
    async def the_flow():
        # NOTE(review): foo() is returned without `await`, exactly as the
        # test was originally written -- confirm this is intentional.
        return foo()

    # The returned sentinel shows the task body (and its asserts) executed.
    outcome = await the_flow()
    assert outcome == sentinel

def test_state_details_have_denormalized_task_run_id_sync(self):
    """Check the run IDs on state_details for a sync task run directly.

    The task is executed through ``run_task_sync`` and inspects its own
    task run from inside the task body.
    """
    sentinel = uuid4()

    @task
    def foo():
        current_run = TaskRunContext.get().task_run

        # Each level must be present before the IDs can be compared.
        assert current_run
        assert current_run.state
        details = current_run.state.state_details
        assert details

        # This test expects no flow run ID, only the task run's own ID.
        assert details.flow_run_id is None
        assert details.task_run_id == current_run.id

        return sentinel

    # The returned sentinel shows the task body (and its asserts) executed.
    outcome = run_task_sync(foo)
    assert outcome == sentinel

def test_state_details_have_denormalized_flow_run_id_sync(self):
    """Check the run IDs on state_details for a sync task called in a flow.

    The task inspects its own task run from inside the task body and compares
    the IDs on ``state_details`` with the IDs on the task run itself.
    """
    sentinel = uuid4()

    @task
    def foo():
        current_run = TaskRunContext.get().task_run

        # Each level must be present before the IDs can be compared.
        assert current_run
        assert current_run.state
        details = current_run.state.state_details
        assert details

        # Both IDs on state_details must mirror the task run's own fields.
        assert details.flow_run_id == current_run.flow_run_id
        assert details.task_run_id == current_run.id

        return sentinel

    @flow
    def the_flow():
        return foo()

    # The returned sentinel shows the task body (and its asserts) executed.
    outcome = the_flow()
    assert outcome == sentinel