Skip to content

Commit

Permalink
set end_time during client-side task orchestration (#14681)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan authored Jul 19, 2024
1 parent ba7904b commit 033cb35
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 2 deletions.
11 changes: 11 additions & 0 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,11 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
)
if transaction.is_committed():
terminal_state.name = "Cached"

if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
if self.task_run.start_time and not self.task_run.end_time:
self.task_run.end_time = terminal_state.timestamp

self.set_state(terminal_state)
self._return_value = result
return result
Expand Down Expand Up @@ -458,6 +463,9 @@ def handle_exception(self, exc: Exception) -> None:
result_factory=getattr(context, "result_factory", None),
)
)
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
if self.task_run.start_time and not self.task_run.end_time:
self.task_run.end_time = state.timestamp
self.set_state(state)
self._raised = exc

Expand All @@ -480,6 +488,9 @@ def handle_crash(self, exc: BaseException) -> None:
state = run_coro_as_sync(exception_to_crashed_state(exc))
self.logger.error(f"Crash detected! {state.message}")
self.logger.debug("Crash details:", exc_info=exc)
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
if self.task_run.start_time and not self.task_run.end_time:
self.task_run.end_time = state.timestamp
self.set_state(state, force=True)
self._raised = exc

Expand Down
142 changes: 140 additions & 2 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ async def my_task():


class TestTaskTimeTracking:
async def test_sync_task_start_time_set_on_running(self):
async def test_sync_task_sets_start_time_on_running(self):
@task
def foo():
return TaskRunContext.get().task_run.id
Expand All @@ -1121,7 +1121,7 @@ def foo():
assert len(states) == 1 and states[0].type == StateType.RUNNING
assert states[0].timestamp == run.start_time

async def test_async_task_start_time_set_on_running(self):
async def test_async_task_sets_start_time_on_running(self):
ID = None

@task
Expand All @@ -1137,6 +1137,144 @@ async def foo():
assert len(states) == 1 and states[0].type == StateType.RUNNING
assert states[0].timestamp == run.start_time

async def test_sync_task_sets_end_time_on_completed(self):
@task
def foo():
return TaskRunContext.get().task_run.id

task_run_id = run_task_sync(foo)
run = await get_task_run(task_run_id)

assert run.end_time is not None
states = await get_task_run_states(task_run_id, StateType.COMPLETED)
assert len(states) == 1 and states[0].type == StateType.COMPLETED
assert states[0].timestamp == run.end_time

async def test_async_task_sets_end_time_on_completed(self):
@task
async def foo():
return TaskRunContext.get().task_run.id

task_run_id = await run_task_async(foo)
run = await get_task_run(task_run_id)

assert run.end_time is not None
states = await get_task_run_states(task_run_id, StateType.COMPLETED)
assert len(states) == 1 and states[0].type == StateType.COMPLETED
assert states[0].timestamp == run.end_time

async def test_sync_task_sets_end_time_on_failed(self):
ID = None

@task
def foo():
nonlocal ID
ID = TaskRunContext.get().task_run.id
raise ValueError("failure!!!")

with pytest.raises(ValueError):
run_task_sync(foo)

run = await get_task_run(ID)

assert run.end_time is not None
states = await get_task_run_states(ID, StateType.FAILED)
assert len(states) == 1 and states[0].type == StateType.FAILED
assert states[0].timestamp == run.end_time

async def test_async_task_sets_end_time_on_failed(self):
ID = None

@task
async def foo():
nonlocal ID
ID = TaskRunContext.get().task_run.id
raise ValueError("failure!!!")

with pytest.raises(ValueError):
await run_task_async(foo)

run = await get_task_run(ID)

assert run.end_time is not None
states = await get_task_run_states(ID, StateType.FAILED)
assert len(states) == 1 and states[0].type == StateType.FAILED
assert states[0].timestamp == run.end_time

async def test_sync_task_sets_end_time_on_crashed(self):
ID = None

@task
def foo():
nonlocal ID
ID = TaskRunContext.get().task_run.id
raise SystemExit

with pytest.raises(SystemExit):
run_task_sync(foo)

run = await get_task_run(ID)

assert run.end_time is not None
states = await get_task_run_states(ID, StateType.CRASHED)
assert len(states) == 1 and states[0].type == StateType.CRASHED
assert states[0].timestamp == run.end_time

async def test_async_task_sets_end_time_on_crashed(self):
ID = None

@task
async def foo():
nonlocal ID
ID = TaskRunContext.get().task_run.id
raise SystemExit

with pytest.raises(SystemExit):
await run_task_async(foo)

run = await get_task_run(ID)

assert run.end_time is not None
states = await get_task_run_states(ID, StateType.CRASHED)
assert len(states) == 1 and states[0].type == StateType.CRASHED
assert states[0].timestamp == run.end_time

async def test_sync_task_does_not_set_end_time_on_crash_pre_runnning(
self, monkeypatch
):
monkeypatch.setattr(
TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit)
)

@task
def my_task():
pass

with pytest.raises(SystemExit):
my_task()

run = await get_task_run(task_run_id=None)

assert run.end_time is None

async def test_async_task_does_not_set_end_time_on_crash_pre_running(
self, monkeypatch
):
monkeypatch.setattr(
TaskRunEngine, "begin_run", MagicMock(side_effect=SystemExit)
)

@task
async def my_task():
pass

with pytest.raises(SystemExit):
await my_task()

run = await get_task_run(task_run_id=None)

assert run.end_time is None


class TestRunCountTracking:
@pytest.fixture
Expand Down

0 comments on commit 033cb35

Please sign in to comment.