Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Jul 19, 2024
1 parent bd4c66c commit 8c2096c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def begin_run(self):
new_state = Running()
state = self.set_state(new_state)
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
self.task_run.start_time = state.timestamp
self.task_run.start_time = self.task_run.state.timestamp

# TODO: this is temporary until the API stops rejecting state transitions
# and the client / transaction store becomes the source of truth
Expand Down
32 changes: 26 additions & 6 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def enable_client_side_task_run_orchestration(
yield enabled


def state_from_event(event) -> State:
return State(
id=event.id,
timestamp=event.occurred,
**event.payload["validated_state"],
)


async def get_task_run(task_run_id: Optional[UUID]) -> TaskRun:
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
task_run = get_task_run_sync(task_run_id)
Expand Down Expand Up @@ -83,7 +91,7 @@ def get_task_run_sync(task_run_id: Optional[UUID]) -> TaskRun:
if e.resource.prefect_object_id("prefect.task-run") == task_run_id
]
last_event = events[-1]
state = State(**last_event.payload["validated_state"])
state = state_from_event(last_event)
task_run = TaskRun(
id=last_event.resource.prefect_object_id("prefect.task-run"),
state=state,
Expand Down Expand Up @@ -119,7 +127,7 @@ async def get_task_run_states(
for e in events
if e.resource.prefect_object_id("prefect.task-run") == task_run_id
]
states = [State(**e.payload["validated_state"]) for e in events]
states = [state_from_event(e) for e in events]
else:
client = get_client()
states = await client.read_task_run_states(task_run_id)
Expand Down Expand Up @@ -1100,7 +1108,20 @@ async def my_task():


class TestTaskTimeTracking:
async def test_start_time_set_on_running(self):
async def test_sync_task_start_time_set_on_running(self):
@task
def foo():
return TaskRunContext.get().task_run.id

task_run_id = run_task_sync(foo)
run = await get_task_run(task_run_id)

assert run.start_time is not None
states = await get_task_run_states(task_run_id, StateType.RUNNING)
assert len(states) == 1 and states[0].type == StateType.RUNNING
assert states[0].timestamp == run.start_time

async def test_async_task_start_time_set_on_running(self):
ID = None

@task
Expand All @@ -1113,9 +1134,8 @@ async def foo():

assert run.start_time is not None
states = await get_task_run_states(ID, StateType.RUNNING)
assert len(states) == 1
running = states[0]
assert running.timestamp == run.start_time
assert len(states) == 1 and states[0].type == StateType.RUNNING
assert states[0].timestamp == run.start_time


class TestSyncAsyncTasks:
Expand Down

0 comments on commit 8c2096c

Please sign in to comment.