Skip to content

Commit

Permalink
Use ResultStore by default for task transactions
Browse files: browse the repository at this point in the history
  • Loading branch information
desertaxle committed Sep 6, 2024
1 parent a571cf6 commit f9ec197
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 61 deletions.
21 changes: 18 additions & 3 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from uuid import UUID

import pendulum
from pydantic import (
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -293,16 +294,30 @@ async def _exists(self, key: str) -> bool:
# so the entire payload doesn't need to be read
try:
metadata_content = await self.metadata_storage.read_path(key)
return metadata_content is not None
if metadata_content is None:
return False
metadata = ResultRecordMetadata.load_bytes(metadata_content)

except Exception:
return False
else:
try:
content = await self.result_storage.read_path(key)
return content is not None
if content is None:
return False
record = ResultRecord.deserialize(content)
metadata = record.metadata
except Exception:
return False

if metadata.expiration:
# if the result has an expiration,
# check if it is still in the future
exists = metadata.expiration > pendulum.now("utc")
else:
exists = True
return exists

def exists(self, key: str) -> bool:
"""
Check if a result record exists in storage.
Expand Down Expand Up @@ -390,8 +405,8 @@ async def aread(self, key: str, holder: Optional[str] = None) -> "ResultRecord":

def create_result_record(
self,
key: str,
obj: Any,
key: Optional[str] = None,
expiration: Optional[DateTime] = None,
):
"""
Expand Down
32 changes: 18 additions & 14 deletions src/prefect/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
UnfinishedRun,
)
from prefect.logging.loggers import get_logger, get_run_logger
from prefect.results import BaseResult, R, ResultStore
from prefect.results import BaseResult, R, ResultRecord, ResultStore
from prefect.settings import PREFECT_ASYNC_FETCH_STATE_RESULT
from prefect.utilities.annotations import BaseAnnotation
from prefect.utilities.asyncutils import in_async_main_thread, sync_compatible
Expand Down Expand Up @@ -131,6 +131,8 @@ async def _get_state_result(
result = await _get_state_result_data_with_retries(
state, retry_result_failure=retry_result_failure
)
elif isinstance(state.data, ResultRecord):
result = state.data.result

elif state.data is None:
if state.is_failed() or state.is_crashed() or state.is_cancelled():
Expand Down Expand Up @@ -207,7 +209,7 @@ async def exception_to_crashed_state(
)

if result_store:
data = await result_store.create_result(exc)
data = result_store.create_result_record(exc)
else:
# Attach the exception for local usage, will not be available when retrieved
# from the API
Expand Down Expand Up @@ -240,10 +242,10 @@ async def exception_to_failed_state(
pass

if result_store:
data = await result_store.create_result(exc)
data = result_store.create_result_record(exc)
if write_result:
try:
await data.write()
await result_store.apersist_result_record(data)
except Exception as exc:
local_logger.warning(
"Failed to write result: %s Execution will continue, but the result has not been written",
Expand Down Expand Up @@ -310,20 +312,20 @@ async def return_value_to_state(
# Unless the user has already constructed a result explicitly, use the store
# to update the data to the correct type
if not isinstance(state.data, BaseResult):
result = await result_store.create_result(
result_record = result_store.create_result_record(
state.data,
key=key,
expiration=expiration,
)
if write_result:
try:
await result.write()
await result_store.apersist_result_record(result_record)
except Exception as exc:
local_logger.warning(
"Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted",
exc,
)
state.data = result
state.data = result_record
return state

# Determine a new state from the aggregate of contained states
Expand Down Expand Up @@ -359,14 +361,14 @@ async def return_value_to_state(
# TODO: We may actually want to set the data to a `StateGroup` object and just
# allow it to be unpacked into a tuple and such so users can interact with
# it
result = await result_store.create_result(
result_record = result_store.create_result_record(
retval,
key=key,
expiration=expiration,
)
if write_result:
try:
await result.write()
await result_store.apersist_result_record(result_record)
except Exception as exc:
local_logger.warning(
"Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted",
Expand All @@ -375,7 +377,7 @@ async def return_value_to_state(
return State(
type=new_state_type,
message=message,
data=result,
data=result_record,
)

# Generators aren't portable, implicitly convert them to a list.
Expand All @@ -385,23 +387,23 @@ async def return_value_to_state(
data = retval

# Otherwise, they just gave data and this is a completed retval
if isinstance(data, BaseResult):
if isinstance(data, (BaseResult, ResultRecord)):
return Completed(data=data)
else:
result = await result_store.create_result(
result_record = result_store.create_result_record(
data,
key=key,
expiration=expiration,
)
if write_result:
try:
await result.write()
await result_store.apersist_result_record(result_record)
except Exception as exc:
local_logger.warning(
"Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted",
exc,
)
return Completed(data=result)
return Completed(data=result_record)


@sync_compatible
Expand Down Expand Up @@ -442,6 +444,8 @@ async def get_state_exception(state: State) -> BaseException:

if isinstance(state.data, BaseResult):
result = await _get_state_result_data_with_retries(state)
elif isinstance(state.data, ResultRecord):
result = state.data.result
elif state.data is None:
result = None
else:
Expand Down
39 changes: 11 additions & 28 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
)
from prefect.futures import PrefectFuture
from prefect.logging.loggers import get_logger, patch_print, task_run_logger
from prefect.records.result_store import ResultRecordStore
from prefect.results import (
BaseResult,
ResultRecord,
_format_user_supplied_storage_key,
get_current_result_store,
)
Expand Down Expand Up @@ -418,6 +418,8 @@ def set_state(self, state: State, force: bool = False) -> State:
result = state.result(raise_on_failure=False, fetch=True)
if inspect.isawaitable(result):
result = run_coro_as_sync(result)
elif isinstance(state.data, ResultRecord):
result = state.data.result
else:
result = state.data

Expand Down Expand Up @@ -454,10 +456,6 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
return self._raised

def handle_success(self, result: R, transaction: Transaction) -> R:
result_store = getattr(TaskRunContext.get(), "result_store", None)
if result_store is None:
raise ValueError("Result store is not set")

if self.task.cache_expiration is not None:
expiration = pendulum.now("utc") + self.task.cache_expiration
else:
Expand All @@ -466,7 +464,7 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
terminal_state = run_coro_as_sync(
return_value_to_state(
result,
result_store=result_store,
result_store=get_current_result_store(),
key=transaction.key,
expiration=expiration,
)
Expand Down Expand Up @@ -538,12 +536,11 @@ def handle_exception(self, exc: Exception) -> None:
# If the task fails, and we have retries left, set the task to retrying.
if not self.handle_retry(exc):
# If the task has no retries left, or the retry condition is not met, set the task to failed.
context = TaskRunContext.get()
state = run_coro_as_sync(
exception_to_failed_state(
exc,
message="Task run encountered an exception",
result_store=getattr(context, "result_store", None),
result_store=get_current_result_store(),
write_result=True,
)
)
Expand Down Expand Up @@ -723,15 +720,9 @@ def transaction_context(self) -> Generator[Transaction, None, None]:
else PREFECT_TASKS_REFRESH_CACHE.value()
)

result_store = getattr(TaskRunContext.get(), "result_store", None)
if result_store and result_store.persist_result:
store = ResultRecordStore(result_store=result_store)
else:
store = None

with transaction(
key=self.compute_transaction_key(),
store=store,
store=get_current_result_store(),
overwrite=overwrite,
logger=self.logger,
) as txn:
Expand Down Expand Up @@ -933,6 +924,8 @@ async def set_state(self, state: State, force: bool = False) -> State:
# Avoid fetching the result unless it is cached, otherwise we defeat
# the purpose of disabling `cache_result_in_memory`
result = await new_state.result(raise_on_failure=False, fetch=True)
elif isinstance(new_state.data, ResultRecord):
result = new_state.data.result
else:
result = new_state.data

Expand Down Expand Up @@ -966,18 +959,14 @@ async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]"
return self._raised

async def handle_success(self, result: R, transaction: Transaction) -> R:
result_store = getattr(TaskRunContext.get(), "result_store", None)
if result_store is None:
raise ValueError("Result store is not set")

if self.task.cache_expiration is not None:
expiration = pendulum.now("utc") + self.task.cache_expiration
else:
expiration = None

terminal_state = await return_value_to_state(
result,
result_store=result_store,
result_store=get_current_result_store(),
key=transaction.key,
expiration=expiration,
)
Expand Down Expand Up @@ -1048,11 +1037,10 @@ async def handle_exception(self, exc: Exception) -> None:
# If the task fails, and we have retries left, set the task to retrying.
if not await self.handle_retry(exc):
# If the task has no retries left, or the retry condition is not met, set the task to failed.
context = TaskRunContext.get()
state = await exception_to_failed_state(
exc,
message="Task run encountered an exception",
result_store=getattr(context, "result_store", None),
result_store=get_current_result_store(),
)
self.record_terminal_state_timing(state)
await self.set_state(state)
Expand Down Expand Up @@ -1226,15 +1214,10 @@ async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
if self.task.refresh_cache is not None
else PREFECT_TASKS_REFRESH_CACHE.value()
)
result_store = getattr(TaskRunContext.get(), "result_store", None)
if result_store and result_store.persist_result:
store = ResultRecordStore(result_store=result_store)
else:
store = None

with transaction(
key=self.compute_transaction_key(),
store=store,
store=get_current_result_store(),
overwrite=overwrite,
logger=self.logger,
) as txn:
Expand Down
2 changes: 2 additions & 0 deletions src/prefect/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def commit(self) -> bool:
if isinstance(self.store, ResultStore):
if isinstance(self._staged_value, BaseResult):
self.store.write(self.key, self._staged_value.get(_sync=True))
elif isinstance(self._staged_value, ResultRecord):
self.store.persist_result_record(self._staged_value)
else:
self.store.write(self.key, self._staged_value)
else:
Expand Down
13 changes: 8 additions & 5 deletions src/prefect/utilities/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@
)
from prefect.flows import Flow
from prefect.futures import PrefectFuture
from prefect.futures import PrefectFuture as NewPrefectFuture
from prefect.logging.loggers import (
get_logger,
task_run_logger,
)
from prefect.results import BaseResult
from prefect.results import BaseResult, ResultRecord
from prefect.settings import (
PREFECT_LOGGING_LOG_PRINTS,
)
Expand Down Expand Up @@ -122,7 +121,7 @@ def add_futures_and_states_to_inputs(obj):


def collect_task_run_inputs_sync(
expr: Any, future_cls: Any = NewPrefectFuture, max_depth: int = -1
expr: Any, future_cls: Any = PrefectFuture, max_depth: int = -1
) -> Set[TaskRunInput]:
"""
This function recurses through an expression to generate a set of any discernible
Expand All @@ -131,7 +130,7 @@ def collect_task_run_inputs_sync(
Examples:
>>> task_inputs = {
>>> k: collect_task_run_inputs(v) for k, v in parameters.items()
>>> k: collect_task_run_inputs_sync(v) for k, v in parameters.items()
>>> }
"""
# TODO: This function needs to be updated to detect parameters and constants
Expand Down Expand Up @@ -401,6 +400,8 @@ async def propose_state(
# Avoid fetching the result unless it is cached, otherwise we defeat
# the purpose of disabling `cache_result_in_memory`
result = await state.result(raise_on_failure=False, fetch=True)
elif isinstance(state.data, ResultRecord):
result = state.data.result
else:
result = state.data

Expand Down Expand Up @@ -504,6 +505,8 @@ def propose_state_sync(
result = state.result(raise_on_failure=False, fetch=True)
if inspect.isawaitable(result):
result = run_coro_as_sync(result)
elif isinstance(state.data, ResultRecord):
result = state.data.result
else:
result = state.data

Expand Down Expand Up @@ -822,7 +825,7 @@ def resolve_to_final_result(expr, context):
if isinstance(context.get("annotation"), quote):
raise StopVisiting()

if isinstance(expr, NewPrefectFuture):
if isinstance(expr, PrefectFuture):
upstream_task_run = context.get("current_task_run")
upstream_task = context.get("current_task")
if (
Expand Down
Loading

0 comments on commit f9ec197

Please sign in to comment.