Skip to content

Commit

Permalink
Merge branch 'main' into task-run-increment-run-time
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Jul 19, 2024
2 parents 7e69e91 + 033cb35 commit 6365a6e
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 30 deletions.
66 changes: 36 additions & 30 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,18 @@ def begin_run(self):
return

new_state = Running()
state = self.set_state(new_state)

if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
self.task_run.start_time = self.task_run.state.timestamp
self.task_run.start_time = new_state.timestamp
self.task_run.run_count += 1

flow_run_context = FlowRunContext.get()
if flow_run_context:
# Carry forward any task run information from the flow run
flow_run = flow_run_context.flow_run
self.task_run.flow_run_run_count = flow_run.run_count

state = self.set_state(new_state)

# TODO: this is temporary until the API stops rejecting state transitions
# and the client / transaction store becomes the source of truth
Expand Down Expand Up @@ -301,39 +310,36 @@ def set_state(self, state: State, force: bool = False) -> State:
last_state = self.state
if not self.task_run:
raise ValueError("Task run is not set")
try:
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
new_state = state
# Copy over state_details from state to state
new_state.state_details.task_run_id = (
last_state.state_details.task_run_id
)
new_state.state_details.flow_run_id = (
last_state.state_details.flow_run_id
)
else:
new_state = propose_state_sync(
self.client, state, task_run_id=self.task_run.id, force=force
)
except Pause as exc:
# We shouldn't get a pause signal without a state, but if this happens,
# just use a Paused state to assume an in-process pause.
new_state = exc.state if exc.state else Paused()
if new_state.state_details.pause_reschedule:
# If we're being asked to pause and reschedule, we should exit the
# task and expect to be resumed later.
raise

# currently this is a hack to keep a reference to the state object
# that has an in-memory result attached to it; using the API state
# could result in losing that reference
self.task_run.state = new_state

# Predictively update the de-normalized task_run.state_* attributes client-side
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
self.task_run.state = new_state = state

# Ensure that the state_details are populated with the current run IDs
new_state.state_details.task_run_id = self.task_run.id
new_state.state_details.flow_run_id = self.task_run.flow_run_id

# Predictively update the de-normalized task_run.state_* attributes
self.task_run.state_id = new_state.id
self.task_run.state_type = new_state.type
self.task_run.state_name = new_state.name
else:
try:
new_state = propose_state_sync(
self.client, state, task_run_id=self.task_run.id, force=force
)
except Pause as exc:
# We shouldn't get a pause signal without a state, but if this happens,
# just use a Paused state to assume an in-process pause.
new_state = exc.state if exc.state else Paused()
if new_state.state_details.pause_reschedule:
# If we're being asked to pause and reschedule, we should exit the
# task and expect to be resumed later.
raise

# currently this is a hack to keep a reference to the state object
# that has an in-memory result attached to it; using the API state
# could result in losing that reference
self.task_run.state = new_state

# emit a state change event
self._last_event = emit_task_run_state_change_event(
Expand Down
161 changes: 161 additions & 0 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,87 @@ async def my_task():
assert run.end_time is None


class TestRunCountTracking:
@pytest.fixture
async def flow_run_context(self, prefect_client: PrefectClient):
@flow
def f():
pass

test_task_runner = ThreadPoolTaskRunner()
flow_run = await prefect_client.create_flow_run(f)
await propose_state(prefect_client, Running(), flow_run_id=flow_run.id)

flow_run = await prefect_client.read_flow_run(flow_run.id)
assert flow_run.run_count == 1

result_factory = await ResultFactory.from_flow(f)
return EngineContext(
flow=f,
flow_run=flow_run,
client=prefect_client,
task_runner=test_task_runner,
result_factory=result_factory,
parameters={"x": "y"},
)

def test_sync_task_run_counts(self, flow_run_context: EngineContext):
ID = None
proof_that_i_ran = uuid4()

@task
def foo():
task_run = TaskRunContext.get().task_run

nonlocal ID
ID = task_run.id

assert task_run
assert task_run.state
assert task_run.state.type == StateType.RUNNING

assert task_run.run_count == 1
assert task_run.flow_run_run_count == flow_run_context.flow_run.run_count

return proof_that_i_ran

with flow_run_context:
assert run_task_sync(foo) == proof_that_i_ran

task_run = get_task_run_sync(ID)
assert task_run
assert task_run.run_count == 1
assert task_run.flow_run_run_count == flow_run_context.flow_run.run_count

async def test_async_task_run_counts(self, flow_run_context: EngineContext):
ID = None
proof_that_i_ran = uuid4()

@task
async def foo():
task_run = TaskRunContext.get().task_run

nonlocal ID
ID = task_run.id

assert task_run
assert task_run.state
assert task_run.state.type == StateType.RUNNING

assert task_run.run_count == 1
assert task_run.flow_run_run_count == flow_run_context.flow_run.run_count

return proof_that_i_ran

with flow_run_context:
assert await run_task_async(foo) == proof_that_i_ran

task_run = await get_task_run(ID)
assert task_run
assert task_run.run_count == 1
assert task_run.flow_run_run_count == flow_run_context.flow_run.run_count


class TestSyncAsyncTasks:
async def test_sync_task_in_async_task(self):
@task
Expand Down Expand Up @@ -2094,3 +2175,83 @@ def foo():
assert task_run.state_id == task_run.state.id
assert task_run.state_type == task_run.state.type == StateType.FAILED
assert task_run.state_name == task_run.state.name == "Failed"

async def test_state_details_have_denormalized_task_run_id_async(self):
proof_that_i_ran = uuid4()

@task
async def foo():
task_run = TaskRunContext.get().task_run

assert task_run
assert task_run.state
assert task_run.state.state_details

assert task_run.state.state_details.flow_run_id is None
assert task_run.state.state_details.task_run_id == task_run.id

return proof_that_i_ran

assert await run_task_async(foo) == proof_that_i_ran

async def test_state_details_have_denormalized_flow_run_id_async(self):
proof_that_i_ran = uuid4()

@flow
async def the_flow():
return foo()

@task
async def foo():
task_run = TaskRunContext.get().task_run

assert task_run
assert task_run.state
assert task_run.state.state_details

assert task_run.state.state_details.flow_run_id == task_run.flow_run_id
assert task_run.state.state_details.task_run_id == task_run.id

return proof_that_i_ran

assert await the_flow() == proof_that_i_ran

def test_state_details_have_denormalized_task_run_id_sync(self):
proof_that_i_ran = uuid4()

@task
def foo():
task_run = TaskRunContext.get().task_run

assert task_run
assert task_run.state
assert task_run.state.state_details

assert task_run.state.state_details.flow_run_id is None
assert task_run.state.state_details.task_run_id == task_run.id

return proof_that_i_ran

assert run_task_sync(foo) == proof_that_i_ran

def test_state_details_have_denormalized_flow_run_id_sync(self):
proof_that_i_ran = uuid4()

@flow
def the_flow():
return foo()

@task
def foo():
task_run = TaskRunContext.get().task_run

assert task_run
assert task_run.state
assert task_run.state.state_details

assert task_run.state.state_details.flow_run_id == task_run.flow_run_id
assert task_run.state.state_details.task_run_id == task_run.id

return proof_that_i_ran

assert the_flow() == proof_that_i_ran

0 comments on commit 6365a6e

Please sign in to comment.