Skip to content

Commit

Permalink
orchestrate task runs without calling the api
Browse files: browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Jul 17, 2024
1 parent afdb04d commit d980dc9
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 33 deletions.
44 changes: 26 additions & 18 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import threading
import time
import uuid
from asyncio import CancelledError
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
Expand Down Expand Up @@ -53,6 +54,7 @@
from prefect.results import BaseResult, ResultFactory, _format_user_supplied_storage_key
from prefect.settings import (
PREFECT_DEBUG_MODE,
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION,
PREFECT_TASKS_REFRESH_CACHE,
)
from prefect.states import (
Expand Down Expand Up @@ -299,9 +301,12 @@ def set_state(self, state: State, force: bool = False) -> State:
if not self.task_run:
raise ValueError("Task run is not set")
try:
new_state = propose_state_sync(
self.client, state, task_run_id=self.task_run.id, force=force
)
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
new_state = state
else:
new_state = propose_state_sync(
self.client, state, task_run_id=self.task_run.id, force=force
)
except Pause as exc:
# We shouldn't get a pause signal without a state, but if this happens,
# just use a Paused state to assume an in-process pause.
Expand Down Expand Up @@ -460,7 +465,6 @@ def handle_crash(self, exc: BaseException) -> None:
@contextmanager
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
from prefect.utilities.engine import (
_resolve_custom_task_run_name,
should_log_prints,
)

Expand All @@ -469,7 +473,6 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
if not self.task_run:
raise ValueError("Task run is not set")

self.task_run = client.read_task_run(self.task_run.id)
with ExitStack() as stack:
if log_prints := should_log_prints(self.task):
stack.enter_context(patch_print())
Expand All @@ -487,19 +490,20 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore

# update the task run name if necessary
if not self._task_name_set and self.task.task_run_name:
task_run_name = _resolve_custom_task_run_name(
task=self.task, parameters=self.parameters
)
self.client.set_task_run_name(
task_run_id=self.task_run.id, name=task_run_name
)
self.logger.extra["task_run_name"] = task_run_name
self.logger.debug(
f"Renamed task run {self.task_run.name!r} to {task_run_name!r}"
)
self.task_run.name = task_run_name
self._task_name_set = True
# TODO: Do we even need to rename the task run here? Can we set the name when we create the run instead?
# if not self._task_name_set and self.task.task_run_name:
# task_run_name = _resolve_custom_task_run_name(
# task=self.task, parameters=self.parameters
# )
# self.client.set_task_run_name(
# task_run_id=self.task_run.id, name=task_run_name
# )
# self.logger.extra["task_run_name"] = task_run_name
# self.logger.debug(
# f"Renamed task run {self.task_run.name!r} to {task_run_name!r}"
# )
# self.task_run.name = task_run_name
# self._task_name_set = True
yield

@contextmanager
Expand Down Expand Up @@ -722,6 +726,8 @@ def run_task_sync(
context=context,
)

task_run_id = task_run_id or uuid.uuid4()

with engine.start(task_run_id=task_run_id, dependencies=dependencies):
while engine.is_running():
run_coro_as_sync(engine.wait_until_ready())
Expand Down Expand Up @@ -749,6 +755,8 @@ async def run_task_async(
context=context,
)

task_run_id = task_run_id or uuid.uuid4()

with engine.start(task_run_id=task_run_id, dependencies=dependencies):
while engine.is_running():
await engine.wait_until_ready()
Expand Down
75 changes: 60 additions & 15 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,19 @@

from typing_extensions import Literal, ParamSpec

import prefect.states
from prefect._internal.compatibility.deprecated import (
deprecated_async_method,
)
from prefect.cache_policies import DEFAULT, NONE, CachePolicy
from prefect.client.orchestration import get_client
from prefect.client.schemas import TaskRun
from prefect.client.schemas.objects import TaskRunInput, TaskRunResult
from prefect.client.schemas.objects import (
StateDetails,
TaskRunInput,
TaskRunPolicy,
TaskRunResult,
)
from prefect.context import (
FlowRunContext,
TagsContext,
Expand All @@ -50,6 +56,7 @@
from prefect.logging.loggers import get_logger
from prefect.results import ResultFactory, ResultSerializer, ResultStorage
from prefect.settings import (
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION,
PREFECT_TASK_DEFAULT_RETRIES,
PREFECT_TASK_DEFAULT_RETRY_DELAY_SECONDS,
)
Expand Down Expand Up @@ -766,23 +773,61 @@ async def create_run(
task_inputs[k] = task_inputs[k].union(extras)

# create the task run
task_run = client.create_task_run(
task=self,
name=task_run_name,
flow_run_id=(
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
flow_run_id = (
getattr(flow_run_context.flow_run, "id", None)
if flow_run_context and flow_run_context.flow_run
else None
),
dynamic_key=str(dynamic_key),
id=id,
state=state,
task_inputs=task_inputs,
extra_tags=TagsContext.get().current_tags,
)
# the new engine uses sync clients but old engines use async clients
if inspect.isawaitable(task_run):
task_run = await task_run
)
state = prefect.states.Pending(
state_details=StateDetails(
task_run_id=id,
flow_run_id=flow_run_id,
)
)
task_run = TaskRun(
id=id,
name=task_run_name,
flow_run_id=flow_run_id,
task_key=self.task_key,
dynamic_key=str(dynamic_key),
task_version=self.version,
empirical_policy=TaskRunPolicy(
retries=self.retries,
retry_delay=self.retry_delay_seconds,
retry_jitter_factor=self.retry_jitter_factor,
),
tags=list(
set(self.tags).union(TagsContext.get().current_tags or [])
),
task_inputs=task_inputs or {},
expected_start_time=state.timestamp,
state_id=state.id,
state_type=state.type,
state_name=state.name,
state=state,
created=state.timestamp,
updated=state.timestamp,
)

else:
task_run = client.create_task_run(
task=self,
name=task_run_name,
flow_run_id=(
getattr(flow_run_context.flow_run, "id", None)
if flow_run_context and flow_run_context.flow_run
else None
),
dynamic_key=str(dynamic_key),
id=id,
state=state,
task_inputs=task_inputs,
extra_tags=TagsContext.get().current_tags,
)
# the new engine uses sync clients but old engines use async clients
if inspect.isawaitable(task_run):
task_run = await task_run

return task_run

Expand Down

0 comments on commit d980dc9

Please sign in to comment.