Commit: checkpoint
bunchesofdonald committed Jul 19, 2024
1 parent 2cbabdc commit 10621f8
Showing 3 changed files with 75 additions and 43 deletions.
99 changes: 64 additions & 35 deletions src/prefect/task_engine.py
@@ -1,4 +1,3 @@
import asyncio
import inspect
import logging
import threading
@@ -20,6 +19,7 @@
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
@@ -225,21 +225,25 @@ def _resolve_and_wait_for_dependencies(self) -> Union[State, None]:
except UpstreamTaskError as upstream_exc:
return Pending(name="NotReady", message=str(upstream_exc))

def _calculate_backoff(self, backoff_count: int) -> float:
def _calculate_backoff(self, backoff_count: int) -> Tuple[float, int]:
if backoff_count < BACKOFF_MAX:
backoff_count += 1
return clamped_poisson_interval(

interval = clamped_poisson_interval(
average_interval=backoff_count, clamping_factor=0.3
)

def _handle_state_change(self, new_state: State) -> State:
return interval, backoff_count

def _handle_state_change(self, new_state: State):
last_state = self.state

# currently this is a hack to keep a reference to the state object
# that has an in-memory result attached to it; using the API state
# could result in losing that reference
self.task_run.state = new_state

# emit a state change event
self._last_event = emit_task_run_state_change_event(
task_run=self.task_run,
@@ -307,7 +311,9 @@ def transaction_context(self) -> Generator[Transaction, None, None]:
) as txn:
yield txn

def _get_hooks(self, task: Task, state: State) -> dict[str, Callable]:
def _hooks(
self, task: Task, state: State
) -> Generator[Tuple[str, Callable], None, None]:
task = self.task

if state.is_failed() and task.on_failure_hooks:
@@ -317,7 +323,8 @@ def _get_hooks(self, task: Task, state: State) -> dict[str, Callable]:
else:
hooks = []

return {_get_hook_name(hook): hook for hook in hooks}
for hook in hooks:
yield (_get_hook_name(hook), hook)
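
One visible effect of replacing the dict built by _get_hooks with the _hooks generator: a dict keyed by hook name keeps only one entry per name, while the generator yields every registered hook in order, duplicates included. A self-contained illustration of that difference, using __name__ in place of _get_hook_name:

```python
from typing import Callable, Dict, Generator, List, Tuple


def on_failure(task, task_run, state) -> None:
    print("handle failure")


# the same hook registered twice, as a user might via on_failure=[hook, hook]
hooks: List[Callable] = [on_failure, on_failure]


def hooks_as_dict() -> Dict[str, Callable]:
    # old shape: duplicate names collapse to a single entry
    return {hook.__name__: hook for hook in hooks}


def hooks_as_generator() -> Generator[Tuple[str, Callable], None, None]:
    # new shape: every hook is yielded, so every hook gets called
    for hook in hooks:
        yield hook.__name__, hook


print(len(hooks_as_dict()))             # 1 -- one hook silently dropped
print(len(list(hooks_as_generator())))  # 2 -- both hooks run
```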

def _result_from_non_base_result(
self, raise_on_failure: bool
@@ -434,7 +441,7 @@ def call_hooks(self, state: Optional[State] = None):

self._ensure_task_run()

for hook_name, hook in self._get_hooks(task=self.task, state=state).items():
for hook_name, hook in self._hooks(task=self.task, state=state):
try:
self.logger.info(
f"Running hook {hook_name!r} in response to entering state"
@@ -456,6 +463,8 @@ def begin_run(self):
self.set_state(not_ready_state, force=self.state.is_pending())
return

self.task_run = self.client.read_task_run(self.task_run.id)

new_state = Running()
state = self.set_state(new_state)

@@ -473,7 +482,7 @@ def begin_run(self):

# TODO: Could this listen for state change events instead of polling?
while state.is_pending() or state.is_paused():
interval = self._calculate_backoff(backoff_count)
interval, backoff_count = self._calculate_backoff(backoff_count)
time.sleep(interval)
state = self.set_state(new_state)
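
Before this change the incremented backoff_count stayed local to _calculate_backoff and the caller never grew it, so the poll interval barely increased between retries; returning the count lets the loop above thread it back in. A standalone sketch of the new shape, using random.uniform as a stand-in for clamped_poisson_interval and an illustrative BACKOFF_MAX:

```python
import random
from typing import Tuple

BACKOFF_MAX = 10  # illustrative cap; the real constant lives in the engine module


def calculate_backoff(backoff_count: int) -> Tuple[float, int]:
    """Return (sleep interval, updated backoff count), like the new _calculate_backoff."""
    if backoff_count < BACKOFF_MAX:
        backoff_count += 1
    # stand-in for clamped_poisson_interval(average_interval=..., clamping_factor=0.3)
    interval = random.uniform(backoff_count * 0.7, backoff_count * 1.3)
    return interval, backoff_count


# The polling loop now updates its own count, so the interval actually grows.
backoff_count = 0
for _ in range(4):
    interval, backoff_count = calculate_backoff(backoff_count)
    print(f"would sleep {interval:.2f}s (backoff_count={backoff_count})")
```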

@@ -615,16 +624,18 @@ def initialize_run(
extra_task_inputs=dependencies,
)
)
self.logger = self._task_run_logger()

# Emit an event to capture that the task run was in the `PENDING` state.
self.logger = self._task_run_logger()

# Capture initial state of the task run, usually `PENDING` or `SCHEDULED`.
self._last_event = emit_task_run_state_change_event(
task_run=self.task_run,
initial_state=None,
validated_state=self.task_run.state,
)

self.set_task_run_name()
with self.configured_run_context():
self.set_task_run_name()

self.logger.info(
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
@@ -714,6 +725,20 @@ def handle_crash(self, exc: BaseException) -> None:
#
# --------------------------

@contextmanager
def configured_run_context(self) -> Generator[None, None, None]:
result_factory = run_coro_as_sync(ResultFactory.from_task(self.task))
self.task_run = self.client.read_task_run(self.task_run.id)

try:
with self.setup_run_context(
result_factory=result_factory,
client=self.client,
):
yield
finally:
pass
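
The new configured_run_context factors the shared setup (refreshing the task run and building a result factory) out of run_context, so the same context can also wrap set_task_run_name during initialize_run, as seen above. A toy sketch of that factoring pattern, with standalone names rather than the Prefect API:

```python
from contextlib import contextmanager
from typing import Generator


@contextmanager
def configured_context(run_id: str) -> Generator[None, None, None]:
    # shared setup: refresh state and build whatever the run needs, exactly once
    print(f"refresh task run {run_id} and build result factory")
    try:
        yield
    finally:
        pass  # nothing extra to tear down; inner contexts handle their own cleanup


def initialize(run_id: str) -> None:
    with configured_context(run_id):
        print("set task run name inside the configured context")


def run(run_id: str) -> None:
    with configured_context(run_id):
        print("execute the task body inside the configured context")


initialize("demo-run")
run("demo-run")
```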

@contextmanager
def start(
self,
@@ -729,15 +754,7 @@ def start(

@contextmanager
def run_context(self):
result_factory = run_coro_as_sync(ResultFactory.from_task(self.task))

self.task_run = self.client.read_task_run(self.task_run.id)
self.set_task_run_name()

with self.setup_run_context(
result_factory=result_factory,
client=self.client,
):
with self.configured_run_context():
try:
with timeout(
seconds=self.task.timeout_seconds,
@@ -784,7 +801,7 @@ async def call_hooks(self, state: Optional[State] = None):

self._ensure_task_run()

for hook_name, hook in self._get_hooks(task=self.task, state=state).items():
for hook_name, hook in self._hooks(task=self.task, state=state):
try:
self.logger.info(
f"Running hook {hook_name!r} in response to entering state"
@@ -806,6 +823,8 @@ async def begin_run(self):
await self.set_state(not_ready_state, force=self.state.is_pending())
return

self.task_run = await self.client.read_task_run(self.task_run.id)

new_state = Running()
state = await self.set_state(new_state)

@@ -823,8 +842,8 @@ async def begin_run(self):

# TODO: Could this listen for state change events instead of polling?
while state.is_pending() or state.is_paused():
interval = self._calculate_backoff(backoff_count)
await asyncio.sleep(interval)
interval, backoff_count = self._calculate_backoff(backoff_count)
await anyio.sleep(interval)
state = await self.set_state(new_state)
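
The engine also drops its direct import of asyncio and sleeps via anyio instead; anyio.sleep behaves the same under asyncio but keeps the code usable under other anyio backends such as trio, which is presumably the point of the swap. A minimal, standalone example (assumes the anyio package is installed):

```python
import anyio


async def poll_with_backoff() -> None:
    backoff_count = 0
    for _ in range(3):
        backoff_count += 1
        # anyio.sleep works under any anyio backend, not just asyncio
        await anyio.sleep(backoff_count * 0.01)
        print(f"polled after backoff_count={backoff_count}")


anyio.run(poll_with_backoff)
```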

async def set_state(self, state: State, force: bool = False) -> State:
@@ -969,16 +988,20 @@ async def initialize_run(
wait_for=self.wait_for,
extra_task_inputs=dependencies,
)
self.logger = self._task_run_logger()

# Emit an event to capture that the task run was in the `PENDING` state.
self.logger = self._task_run_logger()

# Capture initial state of the task run, usually `PENDING` or `SCHEDULED`.
self._last_event = emit_task_run_state_change_event(
task_run=self.task_run,
initial_state=None,
validated_state=self.task_run.state,
)

await self.set_task_run_name()
async with self.configured_run_context():
# Some tasks may rely on the run context to set the task
# run name.
await self.set_task_run_name()

self.logger.info(
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
@@ -1062,6 +1085,20 @@ async def wait_until_ready(self):
#
# --------------------------

@asynccontextmanager
async def configured_run_context(self) -> AsyncGenerator[None, None]:
result_factory = await ResultFactory.from_task(self.task)
self.task_run = await self.client.read_task_run(self.task_run.id)

try:
with self.setup_run_context(
result_factory=result_factory,
client=self.client,
):
yield
finally:
pass
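
The async configured_run_context mirrors the sync one above; the only substantive difference is that the coroutine ResultFactory.from_task is awaited directly here, where the sync engine drives it through run_coro_as_sync. A standalone sketch of that sync/async split, using asyncio.run where Prefect's helper would sit:

```python
import asyncio
from typing import Dict


async def build_result_factory(task_name: str) -> Dict[str, str]:
    # stand-in for the coroutine ResultFactory.from_task(task)
    return {"task": task_name}


def sync_engine_setup(task_name: str) -> Dict[str, str]:
    # sync engine: drive the coroutine to completion (Prefect uses run_coro_as_sync)
    return asyncio.run(build_result_factory(task_name))


async def async_engine_setup(task_name: str) -> Dict[str, str]:
    # async engine: simply await it
    return await build_result_factory(task_name)


print(sync_engine_setup("demo"))
print(asyncio.run(async_engine_setup("demo")))
```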

@asynccontextmanager
async def start(
self,
@@ -1087,15 +1124,7 @@ async def start(

@asynccontextmanager
async def run_context(self):
result_factory = await ResultFactory.from_task(self.task)

self.task_run = await self.client.read_task_run(self.task_run.id)
await self.set_task_run_name()

with self.setup_run_context(
result_factory=result_factory,
client=self.client,
):
async with self.configured_run_context():
try:
with timeout_async(
seconds=self.task.timeout_seconds,
7 changes: 3 additions & 4 deletions tests/test_flow_engine.py
@@ -1148,11 +1148,10 @@ async def pausing_flow():
await asyncio.sleep(0.1)
flow_run = await prefect_client.read_flow_run(flow_run_id)

# execution isn't blocked, so this task should enter the engine, but not begin
# execution
with pytest.raises(RuntimeError):
# the sleeper mock will exhaust its side effects after 6 calls
try:
await doesnt_run()
except RuntimeError:
pass

await pausing_flow()
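
The rewritten test no longer requires the error: pytest.raises fails the test if RuntimeError is not raised, while the try/except form merely tolerates it when the sleeper mock happens to exhaust its side effects. A minimal illustration with a hypothetical maybe_fail helper:

```python
import pytest


def maybe_fail(should_fail: bool) -> None:
    if should_fail:
        raise RuntimeError("side effects exhausted")


def test_requires_the_error() -> None:
    # old style: the test fails unless RuntimeError is actually raised
    with pytest.raises(RuntimeError):
        maybe_fail(True)


def test_tolerates_the_error() -> None:
    # new style: passes whether or not the error occurs
    try:
        maybe_fail(False)
    except RuntimeError:
        pass
```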

12 changes: 8 additions & 4 deletions tests/test_task_worker.py
@@ -178,7 +178,7 @@ def empty_task():
def always_error(*args, **kwargs):
raise ValueError("oops")

monkeypatch.setattr("prefect.task_engine.TaskRunEngine.start", always_error)
monkeypatch.setattr("prefect.task_engine.SyncTaskRunEngine.start", always_error)

task_worker = TaskWorker(empty_task)

@@ -794,17 +794,21 @@ async def mock_iter():
mock_subscription.return_value = mock_iter()

server_task = asyncio.create_task(task_worker.start())
await event.wait()
await event.wait() # Wait for the first task to set the event

# Give some additional time for the first task to complete
await asyncio.sleep(0.5)

updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id)
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_1.state.is_completed()
assert not updated_task_run_2.state.is_completed()

# clear the event to allow the second task to complete
# Clear the event to allow the second task to complete
event.clear()

await event.wait()
await event.wait() # Wait for the second task to set the event
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_2.state.is_completed()
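
The added comments spell out the event choreography: wait for the first task to signal, clear the event, then wait again for the second. A generic sketch of that wait → clear → wait pattern with asyncio.Event, using hypothetical worker and gate names rather than the test's fixtures:

```python
import asyncio


async def worker(name: str, gate: asyncio.Event, done: asyncio.Event) -> None:
    await gate.wait()       # block until the test releases this worker
    print(f"{name} completed")
    done.set()              # signal the test that this worker finished


async def main() -> None:
    done = asyncio.Event()
    gate_1, gate_2 = asyncio.Event(), asyncio.Event()
    t1 = asyncio.create_task(worker("task 1", gate_1, done))
    t2 = asyncio.create_task(worker("task 2", gate_2, done))

    gate_1.set()
    await done.wait()       # wait for the first task to set the event
    done.clear()            # clear the event so the second signal is observable

    gate_2.set()
    await done.wait()       # wait for the second task to set the event
    await asyncio.gather(t1, t2)


asyncio.run(main())
```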
