Get results tests passing
desertaxle committed Sep 7, 2024
1 parent 7ba4051 commit 80b36df
Showing 14 changed files with 379 additions and 217 deletions.
10 changes: 8 additions & 2 deletions src/prefect/client/orchestration.py
@@ -2158,7 +2158,10 @@ async def set_flow_run_state(
try:
response = await self._client.post(
f"/flow_runs/{flow_run_id}/set_state",
json=dict(state=state_create.model_dump(mode="json"), force=force),
json=dict(
state=state_create.model_dump(mode="json", serialize_as_any=True),
force=force,
),
)
except httpx.HTTPStatusError as e:
if e.response.status_code == status.HTTP_404_NOT_FOUND:
@@ -3934,7 +3937,10 @@ def set_flow_run_state(
try:
response = self._client.post(
f"/flow_runs/{flow_run_id}/set_state",
json=dict(state=state_create.model_dump(mode="json"), force=force),
json=dict(
state=state_create.model_dump(mode="json", serialize_as_any=True),
force=force,
),
)
except httpx.HTTPStatusError as e:
if e.response.status_code == status.HTTP_404_NOT_FOUND:
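The substantive change in both the async and sync clients is passing serialize_as_any=True to model_dump. The sketch below (not Prefect code; BaseData, MetadataData, and StateCreate here are simplified stand-ins) illustrates the Pydantic v2 behavior this works around: by default a field annotated with a base model is serialized against the base schema, so fields that only exist on the runtime subclass are silently dropped from the payload.

from typing import Optional

from pydantic import BaseModel


class BaseData(BaseModel):
    kind: str = "base"


class MetadataData(BaseData):
    kind: str = "metadata"
    storage_key: str = "results/abc123"  # exists only on the subclass


class StateCreate(BaseModel):
    data: Optional[BaseData] = None


state = StateCreate(data=MetadataData())

# Default dump uses the annotated BaseData schema, dropping storage_key.
print(state.model_dump(mode="json"))
# {'data': {'kind': 'metadata'}}

# serialize_as_any=True duck-types the value and keeps subclass-only fields.
print(state.model_dump(mode="json", serialize_as_any=True))
# {'data': {'kind': 'metadata', 'storage_key': 'results/abc123'}}

This is presumably relevant here because the state's data can hold a more specific type than its annotation, and the API needs those extra fields to locate the persisted result.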
6 changes: 4 additions & 2 deletions src/prefect/client/schemas/objects.py
@@ -292,10 +292,12 @@ def to_state_create(self):
results should be sent to the API. Other data is only available locally.
"""
from prefect.client.schemas.actions import StateCreate
from prefect.results import BaseResult
from prefect.results import BaseResult, ResultRecord, should_persist_result

if isinstance(self.data, BaseResult) and self.data.serialize_to_none is False:
if isinstance(self.data, BaseResult):
data = self.data
elif isinstance(self.data, ResultRecord) and should_persist_result():
data = self.data.metadata
else:
data = None

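With this change, to_state_create ships a BaseResult as-is, reduces a ResultRecord to just its metadata when results are being persisted, and otherwise sends nothing, so the actual value stays local. A rough sketch of that selection, using deliberately simplified stand-ins rather than the real ResultRecord / ResultRecordMetadata classes:

from typing import Any, Optional

from pydantic import BaseModel


class FakeResultMetadata(BaseModel):
    storage_key: str           # enough to find the persisted value later
    serializer: str = "pickle"


class FakeResultRecord(BaseModel):
    metadata: FakeResultMetadata
    result: Any                # the in-memory value; never sent to the API


def data_for_api(data: Any, persist: bool) -> Optional[BaseModel]:
    # Mirrors the selection logic: only a reference crosses the wire, and only when persisting.
    if isinstance(data, FakeResultRecord) and persist:
        return data.metadata
    return None


record = FakeResultRecord(metadata=FakeResultMetadata(storage_key="results/abc123"), result=[1, 2, 3])
print(data_for_api(record, persist=True))    # storage_key='results/abc123' serializer='pickle'
print(data_for_api(record, persist=False))   # None -- the value stays local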
6 changes: 5 additions & 1 deletion src/prefect/context.py
@@ -40,7 +40,7 @@
from prefect.client.schemas import FlowRun, TaskRun
from prefect.events.worker import EventsWorker
from prefect.exceptions import MissingContextError
from prefect.results import ResultStore
from prefect.results import ResultStore, get_default_persist_setting
from prefect.settings import PREFECT_HOME, Profile, Settings
from prefect.states import State
from prefect.task_runners import TaskRunner
@@ -343,6 +343,7 @@ class EngineContext(RunContext):

# Result handling
result_store: ResultStore
persist_result: bool = Field(default_factory=get_default_persist_setting)

# Counter for task calls allowing unique
task_run_dynamic_keys: Dict[str, int] = Field(default_factory=dict)
@@ -372,6 +373,7 @@ def serialize(self):
"start_time",
"input_keyset",
"result_store",
"persist_result",
},
exclude_unset=True,
)
@@ -397,6 +399,7 @@ class TaskRunContext(RunContext):

# Result handling
result_store: ResultStore
persist_result: bool = Field(default_factory=get_default_persist_setting)

__var__ = ContextVar("task_run")

@@ -410,6 +413,7 @@ def serialize(self):
"start_time",
"input_keyset",
"result_store",
"persist_result",
},
exclude_unset=True,
)
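Both EngineContext and TaskRunContext now carry a persist_result flag whose default is read from get_default_persist_setting() at the moment the context object is built, and the field is added to the field sets used by serialize() so it travels with the context. A minimal sketch of the default_factory pattern, with a module-level stand-in for the PREFECT_RESULTS_PERSIST_BY_DEFAULT lookup:

from pydantic import BaseModel, Field

PERSIST_BY_DEFAULT = False  # stand-in for PREFECT_RESULTS_PERSIST_BY_DEFAULT.value()


def get_default_persist_setting() -> bool:
    return PERSIST_BY_DEFAULT


class FakeRunContext(BaseModel):
    # default_factory runs per instance, at construction time,
    # so the setting is read when the run context is created.
    persist_result: bool = Field(default_factory=get_default_persist_setting)


print(FakeRunContext().persist_result)                      # False
print(FakeRunContext(persist_result=True).persist_result)   # an explicit value wins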
12 changes: 10 additions & 2 deletions src/prefect/flow_engine.py
@@ -47,7 +47,12 @@
get_run_logger,
patch_print,
)
from prefect.results import BaseResult, ResultStore, get_current_result_store
from prefect.results import (
BaseResult,
ResultStore,
get_current_result_store,
should_persist_result,
)
from prefect.settings import PREFECT_DEBUG_MODE
from prefect.states import (
Failed,
@@ -271,7 +276,7 @@ def handle_success(self, result: R) -> R:
return_value_to_state(
resolved_result,
result_store=result_store,
write_result=True,
write_result=should_persist_result(),
)
)
self.set_state(terminal_state)
@@ -511,6 +516,9 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
self.flow, _sync=True
),
task_runner=task_runner,
persist_result=self.flow.persist_result
if self.flow.persist_result is not None
else should_persist_result(),
)
)
stack.enter_context(ConcurrencyContextV1())
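The persist_result passed into the flow run context follows a simple precedence: an explicit persist_result on the flow definition wins, otherwise should_persist_result() is consulted, which (per the results.py change below) checks any enclosing task or flow run context before falling back to the global setting. A small sketch of that resolution order, using hypothetical argument names:

from typing import Optional


def resolve_persist_result(
    flow_setting: Optional[bool],
    enclosing_context_setting: Optional[bool],
    global_default: bool,
) -> bool:
    # Hypothetical helper; mirrors the ternary used when building the run context.
    if flow_setting is not None:               # @flow(persist_result=...) wins
        return flow_setting
    if enclosing_context_setting is not None:  # inherited from a parent run context
        return enclosing_context_setting
    return global_default                      # PREFECT_RESULTS_PERSIST_BY_DEFAULT


print(resolve_persist_result(None, None, False))    # False
print(resolve_persist_result(None, True, False))    # True -- inherited from the parent run
print(resolve_persist_result(False, True, True))    # False -- the explicit flow setting wins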
85 changes: 62 additions & 23 deletions src/prefect/results.py
@@ -5,6 +5,7 @@
import threading
import uuid
from functools import partial
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
@@ -20,6 +21,7 @@
from uuid import UUID

import pendulum
from cachetools import LRUCache
from pydantic import (
BaseModel,
ConfigDict,
@@ -34,6 +36,7 @@
from typing_extensions import ParamSpec, Self

import prefect
from prefect._internal.compatibility.deprecated import deprecated_field
from prefect.blocks.core import Block
from prefect.client.utilities import inject_client
from prefect.exceptions import (
@@ -98,7 +101,7 @@ async def get_default_result_storage() -> WritableFileSystem:

@sync_compatible
async def resolve_result_storage(
result_storage: ResultStorage,
result_storage: Union[ResultStorage, UUID],
) -> WritableFileSystem:
"""
Resolve one of the valid `ResultStorage` input types into a saved block
@@ -119,6 +122,9 @@ def resolve_result_storage(
storage_block = await Block.load(result_storage, client=client)
storage_block_id = storage_block._block_document_id
assert storage_block_id is not None, "Loaded storage blocks must have ids"
elif isinstance(result_storage, UUID):
block_document = await client.read_block_document(result_storage)
storage_block = Block._from_block_document(block_document)
else:
raise TypeError(
"Result storage must be one of the following types: 'UUID', 'Block', "
@@ -172,16 +178,38 @@ def get_default_persist_setting() -> bool:
return PREFECT_RESULTS_PERSIST_BY_DEFAULT.value()


def should_persist_result() -> bool:
"""
Return whether results should be persisted, based on the active task or flow run context, falling back to the PREFECT_RESULTS_PERSIST_BY_DEFAULT setting.
"""
from prefect.context import FlowRunContext, TaskRunContext

task_run_context = TaskRunContext.get()
if task_run_context is not None:
return task_run_context.persist_result
flow_run_context = FlowRunContext.get()
if flow_run_context is not None:
return flow_run_context.persist_result

return PREFECT_RESULTS_PERSIST_BY_DEFAULT.value()


def _format_user_supplied_storage_key(key: str) -> str:
# Note here we are pinning to task runs since flow runs do not support storage keys
# yet; we'll need to split logic in the future or have two separate functions
runtime_vars = {key: getattr(prefect.runtime, key) for key in dir(prefect.runtime)}
return key.format(**runtime_vars, parameters=prefect.runtime.task_run.parameters)


@deprecated_field(
"persist_result",
when=lambda x: x is not None,
when_message="use the `should_persist_result` utility function instead",
start_date="Sep 2024",
)
class ResultStore(BaseModel):
"""
Manages the storage and retrieval of results.
Manages the storage and retrieval of results.

Attributes:
result_storage: The storage for result records. If not provided, the default
@@ -201,10 +229,13 @@ class ResultStore(BaseModel):
result_storage: Optional[WritableFileSystem] = Field(default=None)
metadata_storage: Optional[WritableFileSystem] = Field(default=None)
lock_manager: Optional[LockManager] = Field(default=None)
persist_result: bool = Field(default_factory=get_default_persist_setting)
cache_result_in_memory: bool = Field(default=True)
serializer: Serializer = Field(default_factory=get_default_result_serializer)
storage_key_fn: Callable[[], str] = Field(default=DEFAULT_STORAGE_KEY_FN)
cache: LRUCache = Field(default_factory=lambda: LRUCache(maxsize=1000))

# Deprecated fields
persist_result: Optional[bool] = Field(default=None)

@property
def result_storage_block_id(self) -> Optional[UUID]:
@@ -228,8 +259,6 @@ async def update_for_flow(self, flow: "Flow") -> Self:
update["result_storage"] = await resolve_result_storage(flow.result_storage)
if flow.result_serializer is not None:
update["serializer"] = resolve_serializer(flow.result_serializer)
if flow.persist_result is not None:
update["persist_result"] = flow.persist_result
if flow.cache_result_in_memory is not None:
update["cache_result_in_memory"] = flow.cache_result_in_memory
if self.result_storage is None and update.get("result_storage") is None:
@@ -252,8 +281,6 @@ async def update_for_task(self: Self, task: "Task") -> Self:
update["result_storage"] = await resolve_result_storage(task.result_storage)
if task.result_serializer is not None:
update["serializer"] = resolve_serializer(task.result_serializer)
if task.persist_result is not None:
update["persist_result"] = task.persist_result
if task.cache_result_in_memory is not None:
update["cache_result_in_memory"] = task.cache_result_in_memory
if task.result_storage_key is not None:
@@ -357,9 +384,13 @@ async def _read(self, key: str, holder: str) -> "ResultRecord":
Returns:
A result record.
"""

if self.lock_manager is not None and not self.is_lock_holder(key, holder):
await self.await_for_lock(key)

if key in self.cache:
return self.cache[key]

if self.result_storage is None:
self.result_storage = await get_default_result_storage()

@@ -370,12 +401,18 @@ async def _read(self, key: str, holder: str) -> "ResultRecord":
metadata.storage_key is not None
), "Did not find storage key in metadata"
result_content = await self.result_storage.read_path(metadata.storage_key)
return ResultRecord.deserialize_from_result_and_metadata(
result_record = ResultRecord.deserialize_from_result_and_metadata(
result=result_content, metadata=metadata_content
)
if self.cache_result_in_memory:
self.cache[key] = result_record
return result_record
else:
content = await self.result_storage.read_path(key)
return ResultRecord.deserialize(content)
result_record = ResultRecord.deserialize(content)
if self.cache_result_in_memory:
self.cache[key] = result_record
return result_record

def read(self, key: str, holder: Optional[str] = None) -> "ResultRecord":
"""
@@ -433,7 +470,6 @@ def create_result_record(
expiration=expiration,
storage_key=key,
storage_block_id=self.result_storage_block_id,
serialize_to_none=not self.persist_result,
),
)

@@ -500,9 +536,6 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st
result_record.metadata.storage_key is not None
), "Storage key is required on result record"

if not self.persist_result:
return

key = result_record.metadata.storage_key
if (
self.lock_manager is not None
@@ -522,8 +555,19 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st
result_record.metadata.storage_key,
content=result_record.serialize_result(),
)
if result_record.metadata.storage_block_id is None:
basepath = (
Path(self.result_storage.basepath)
if hasattr(self.result_storage, "basepath")
else Path(".")
)
metadata_key = str(
Path(result_record.metadata.storage_key).relative_to(basepath)
)
else:
metadata_key = result_record.metadata.storage_key
await self.metadata_storage.write_path(
result_record.metadata.storage_key,
metadata_key,
content=result_record.serialize_metadata(),
)
# Otherwise, write the result metadata and result together
@@ -532,6 +576,9 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st
result_record.metadata.storage_key, content=result_record.serialize()
)

if self.cache_result_in_memory:
self.cache[key] = result_record

def persist_result_record(
self, result_record: "ResultRecord", holder: Optional[str] = None
):
@@ -730,7 +777,7 @@ def key_fn():
serializer=self.serializer,
cache_object=should_cache_object,
expiration=expiration,
serialize_to_none=not self.persist_result,
serialize_to_none=not should_persist_result(),
)

# TODO: These two methods need to find a new home
Expand Down Expand Up @@ -782,13 +829,6 @@ class ResultRecordMetadata(BaseModel):
serializer: Serializer = Field(default_factory=PickleSerializer)
prefect_version: str = Field(default=prefect.__version__)
storage_block_id: Optional[uuid.UUID] = Field(default=None)
serialize_to_none: bool = Field(default=False)

@model_serializer(mode="wrap")
def serialize_model(self, handler, info):
if self.serialize_to_none:
return None
return handler(self, info)

def dump_bytes(self) -> bytes:
"""
@@ -1127,7 +1167,6 @@ async def write(self, obj: R = NotSet, client: "PrefectClient" = None) -> None:
result_store = ResultStore(
result_storage=storage_block,
serializer=serializer,
persist_result=not self.serialize_to_none,
)
await result_store.awrite(
obj=obj, key=self.storage_key, expiration=self.expiration
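Two patterns stand out in the ResultStore changes: the store-level persist_result field is deprecated in favor of the context-driven should_persist_result() above, and reads and writes now populate an in-process cachetools.LRUCache so repeated reads of the same key can skip storage. A stripped-down sketch of that read-through caching (illustrative only; a dict stands in for the storage block, and this is not the real ResultStore API):

from typing import Dict

from cachetools import LRUCache


class TinyResultStore:
    def __init__(self, storage: Dict[str, bytes], cache_result_in_memory: bool = True):
        self.storage = storage
        self.cache_result_in_memory = cache_result_in_memory
        self.cache: LRUCache = LRUCache(maxsize=1000)

    def read(self, key: str) -> bytes:
        if key in self.cache:                  # 1. serve from memory when possible
            return self.cache[key]
        record = self.storage[key]             # 2. fall back to the storage backend
        if self.cache_result_in_memory:
            self.cache[key] = record           # 3. remember it for subsequent reads
        return record

    def write(self, key: str, record: bytes) -> None:
        self.storage[key] = record
        if self.cache_result_in_memory:
            self.cache[key] = record


store = TinyResultStore(storage={})
store.write("results/abc123", b"payload")
assert store.read("results/abc123") == b"payload"   # served from the in-memory cache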
3 changes: 3 additions & 0 deletions src/prefect/states.py
@@ -455,6 +455,9 @@ async def get_state_exception(state: State) -> BaseException:
result = await _get_state_result_data_with_retries(state)
elif isinstance(state.data, ResultRecord):
result = state.data.result
elif isinstance(state.data, ResultRecordMetadata):
record = await ResultRecord.from_metadata(state.data)
result = record.result
elif state.data is None:
result = None
else:
9 changes: 9 additions & 0 deletions src/prefect/task_engine.py
@@ -60,6 +60,7 @@
ResultRecord,
_format_user_supplied_storage_key,
get_current_result_store,
should_persist_result,
)
from prefect.settings import (
PREFECT_DEBUG_MODE,
@@ -597,6 +598,9 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
self.task, _sync=True
),
client=client,
persist_result=self.task.persist_result
if self.task.persist_result is not None
else should_persist_result(),
)
)
stack.enter_context(ConcurrencyContextV1())
@@ -726,6 +730,7 @@ def transaction_context(self) -> Generator[Transaction, None, None]:
store=get_current_result_store(),
overwrite=overwrite,
logger=self.logger,
write_on_commit=should_persist_result(),
) as txn:
yield txn

Expand Down Expand Up @@ -1096,6 +1101,9 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None):
self.task, _sync=False
),
client=client,
persist_result=self.task.persist_result
if self.task.persist_result is not None
else should_persist_result(),
)
)
stack.enter_context(ConcurrencyContext())
@@ -1222,6 +1230,7 @@ async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
store=get_current_result_store(),
overwrite=overwrite,
logger=self.logger,
write_on_commit=should_persist_result(),
) as txn:
yield txn

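Both the sync and async task engines now open their transaction with write_on_commit=should_persist_result(), so whether the result record is actually written at commit time follows the same context-aware persistence check as everything else. A hedged sketch of what such a flag does, using a made-up transaction helper rather than Prefect's real Transaction API:

from contextlib import contextmanager
from typing import Dict, Iterator, Optional


class FakeTransaction:
    # Illustrative only: stages a value, and writes it at commit time if allowed.

    def __init__(self, store: Dict[str, str], key: str, write_on_commit: bool):
        self.store = store
        self.key = key
        self.write_on_commit = write_on_commit
        self.staged: Optional[str] = None

    def stage(self, value: str) -> None:
        self.staged = value

    def commit(self) -> None:
        if self.write_on_commit and self.staged is not None:
            self.store[self.key] = self.staged   # persist only when enabled


@contextmanager
def transaction(store: Dict[str, str], key: str, write_on_commit: bool) -> Iterator[FakeTransaction]:
    txn = FakeTransaction(store, key, write_on_commit)
    yield txn
    txn.commit()


backend: Dict[str, str] = {}
with transaction(backend, "task-result", write_on_commit=True) as txn:
    txn.stage("42")
print(backend)                   # {'task-result': '42'}

with transaction(backend, "ephemeral", write_on_commit=False) as txn:
    txn.stage("not persisted")
print("ephemeral" in backend)    # False -- the staged value was never written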
(Diffs for the remaining changed files are not shown.)
