Skip to content

Commit

Permalink
Showing 13 changed files with 816 additions and 697 deletions.
3 changes: 2 additions & 1 deletion src/prefect/_internal/schemas/validators.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
from copy import copy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
from uuid import UUID

import jsonschema
import pendulum
@@ -653,7 +654,7 @@ def validate_message_template_variables(v: Optional[str]) -> Optional[str]:
return v


def validate_default_queue_id_not_none(v: Optional[str]) -> Optional[str]:
def validate_default_queue_id_not_none(v: Optional[UUID]) -> UUID:
if v is None:
raise ValueError(
"`default_queue_id` is a required field. If you are "
4 changes: 3 additions & 1 deletion src/prefect/client/__init__.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,8 @@
</div>
"""

from collections.abc import Callable
from typing import Any
from prefect._internal.compatibility.migration import getattr_migration

__getattr__ = getattr_migration(__name__)
__getattr__: Callable[[str], Any] = getattr_migration(__name__)
55 changes: 27 additions & 28 deletions src/prefect/client/base.py
Original file line number Diff line number Diff line change
@@ -4,22 +4,11 @@
import time
import uuid
from collections import defaultdict
from collections.abc import AsyncGenerator, Awaitable, MutableMapping
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
MutableMapping,
Optional,
Protocol,
Set,
Tuple,
Type,
runtime_checkable,
)
from logging import Logger
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, runtime_checkable

import anyio
import httpx
@@ -46,14 +35,14 @@

# Datastores for lifespan management, keys should be a tuple of thread and app
# identities.
APP_LIFESPANS: Dict[Tuple[int, int], LifespanManager] = {}
APP_LIFESPANS_REF_COUNTS: Dict[Tuple[int, int], int] = {}
APP_LIFESPANS: dict[tuple[int, int], LifespanManager] = {}
APP_LIFESPANS_REF_COUNTS: dict[tuple[int, int], int] = {}
# Blocks concurrent access to the above dicts per thread. The index should be the thread
# identity.
APP_LIFESPANS_LOCKS: Dict[int, anyio.Lock] = defaultdict(anyio.Lock)
APP_LIFESPANS_LOCKS: dict[int, anyio.Lock] = defaultdict(anyio.Lock)


logger = get_logger("client")
logger: Logger = get_logger("client")


# Define ASGI application types for type checking
@@ -174,9 +163,9 @@ def raise_for_status(self) -> Response:
raise PrefectHTTPStatusError.from_httpx_error(exc) from exc.__cause__

@classmethod
def from_httpx_response(cls: Type[Self], response: httpx.Response) -> Response:
def from_httpx_response(cls: type[Self], response: httpx.Response) -> Response:
"""
Create a `PrefectReponse` from an `httpx.Response`.
Create a `PrefectResponse` from an `httpx.Response`.
By changing the `__class__` attribute of the Response, we change the method
resolution order to look for methods defined in PrefectResponse, while leaving
@@ -222,10 +211,10 @@ async def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Awaitable[Response]],
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
send_args: tuple[Any, ...],
send_kwargs: dict[str, Any],
retry_codes: set[int] = set(),
retry_exceptions: tuple[type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
@@ -240,6 +229,11 @@ async def _send_with_retry(
try_count = 0
response = None

if TYPE_CHECKING:
# older httpx versions type method as str | bytes | Unknown
# but in reality it is always a string.
assert isinstance(request.method, str) # type: ignore

is_change_request = request.method.lower() in {"post", "put", "patch", "delete"}

if self.enable_csrf_support and is_change_request:
@@ -436,10 +430,10 @@ def _send_with_retry(
self,
request: Request,
send: Callable[[Request], Response],
send_args: Tuple[Any, ...],
send_kwargs: Dict[str, Any],
retry_codes: Set[int] = set(),
retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
send_args: tuple[Any, ...],
send_kwargs: dict[str, Any],
retry_codes: set[int] = set(),
retry_exceptions: tuple[type[Exception], ...] = tuple(),
):
"""
Send a request and retry it if it fails.
@@ -454,6 +448,11 @@ def _send_with_retry(
try_count = 0
response = None

if TYPE_CHECKING:
# older httpx versions type method as str | bytes | Unknown
# but in reality it is always a string.
assert isinstance(request.method, str) # type: ignore

is_change_request = request.method.lower() in {"post", "put", "patch", "delete"}

if self.enable_csrf_support and is_change_request:
36 changes: 20 additions & 16 deletions src/prefect/client/cloud.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import re
from typing import Any, Dict, List, Optional, cast
from typing import Any, NoReturn, Optional, cast
from uuid import UUID

import anyio
import httpx
import pydantic
from starlette import status
from typing_extensions import Self

import prefect.context
import prefect.settings
@@ -30,7 +31,7 @@
def get_cloud_client(
host: Optional[str] = None,
api_key: Optional[str] = None,
httpx_settings: Optional[Dict[str, Any]] = None,
httpx_settings: Optional[dict[str, Any]] = None,
infer_cloud_url: bool = False,
) -> "CloudClient":
"""
@@ -62,11 +63,14 @@ class CloudUnauthorizedError(PrefectException):


class CloudClient:
account_id: Optional[str] = None
workspace_id: Optional[str] = None

def __init__(
self,
host: str,
api_key: str,
httpx_settings: Optional[Dict[str, Any]] = None,
httpx_settings: Optional[dict[str, Any]] = None,
) -> None:
httpx_settings = httpx_settings or dict()
httpx_settings.setdefault("headers", dict())
@@ -79,7 +83,7 @@ def __init__(
**httpx_settings, enable_csrf_support=False
)

api_url = prefect.settings.PREFECT_API_URL.value() or ""
api_url: str = prefect.settings.PREFECT_API_URL.value() or ""
if match := (
re.search(PARSE_API_URL_REGEX, host)
or re.search(PARSE_API_URL_REGEX, api_url)
@@ -100,7 +104,7 @@ def workspace_base_url(self) -> str:

return f"{self.account_base_url}/workspaces/{self.workspace_id}"

async def api_healthcheck(self):
async def api_healthcheck(self) -> None:
"""
Attempts to connect to the Cloud API and raises the encountered exception if not
successful.
@@ -110,8 +114,8 @@ async def api_healthcheck(self):
with anyio.fail_after(10):
await self.read_workspaces()

async def read_workspaces(self) -> List[Workspace]:
workspaces = pydantic.TypeAdapter(List[Workspace]).validate_python(
async def read_workspaces(self) -> list[Workspace]:
workspaces = pydantic.TypeAdapter(list[Workspace]).validate_python(
await self.get("/me/workspaces")
)
return workspaces
@@ -124,17 +128,17 @@ async def read_current_workspace(self) -> Workspace:
return workspace
raise ValueError("Current workspace not found")

async def read_worker_metadata(self) -> Dict[str, Any]:
async def read_worker_metadata(self) -> dict[str, Any]:
response = await self.get(
f"{self.workspace_base_url}/collections/work_pool_types"
)
return cast(Dict[str, Any], response)
return cast(dict[str, Any], response)

async def read_account_settings(self) -> Dict[str, Any]:
async def read_account_settings(self) -> dict[str, Any]:
response = await self.get(f"{self.account_base_url}/settings")
return cast(Dict[str, Any], response)
return cast(dict[str, Any], response)

async def update_account_settings(self, settings: Dict[str, Any]):
async def update_account_settings(self, settings: dict[str, Any]) -> None:
await self.request(
"PATCH",
f"{self.account_base_url}/settings",
@@ -145,7 +149,7 @@ async def read_account_ip_allowlist(self) -> IPAllowlist:
response = await self.get(f"{self.account_base_url}/ip_allowlist")
return IPAllowlist.model_validate(response)

async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist):
async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist) -> None:
await self.request(
"PUT",
f"{self.account_base_url}/ip_allowlist",
@@ -175,20 +179,20 @@ async def update_flow_run_labels(
json=labels,
)

async def __aenter__(self):
async def __aenter__(self) -> Self:
await self._client.__aenter__()
return self

async def __aexit__(self, *exc_info: Any) -> None:
return await self._client.__aexit__(*exc_info)

def __enter__(self):
def __enter__(self) -> NoReturn:
raise RuntimeError(
"The `CloudClient` must be entered with an async context. Use 'async "
"with CloudClient(...)' not 'with CloudClient(...)'"
)

def __exit__(self, *_):
def __exit__(self, *_: object) -> NoReturn:
assert False, "This should never be called but must be defined for __enter__"

async def get(self, route: str, **kwargs: Any) -> Any:
572 changes: 299 additions & 273 deletions src/prefect/client/orchestration.py

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions src/prefect/client/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -25,3 +25,27 @@
StateAcceptDetails,
StateRejectDetails,
)

__all__ = (
"BlockDocument",
"BlockSchema",
"BlockType",
"BlockTypeUpdate",
"DEFAULT_BLOCK_SCHEMA_VERSION",
"FlowRun",
"FlowRunPolicy",
"OrchestrationResult",
"SetStateStatus",
"State",
"StateAbortDetails",
"StateAcceptDetails",
"StateCreate",
"StateDetails",
"StateRejectDetails",
"StateType",
"TaskRun",
"TaskRunInput",
"TaskRunPolicy",
"TaskRunResult",
"Workspace",
)
246 changes: 126 additions & 120 deletions src/prefect/client/schemas/actions.py

Large diffs are not rendered by default.

187 changes: 108 additions & 79 deletions src/prefect/client/schemas/objects.py

Large diffs are not rendered by default.

36 changes: 18 additions & 18 deletions src/prefect/client/schemas/responses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import Any, Dict, List, Optional, TypeVar, Union
from typing import Any, ClassVar, Generic, Optional, TypeVar, Union
from uuid import UUID

from pydantic import ConfigDict, Field
@@ -13,7 +13,7 @@
from prefect.utilities.collections import AutoEnum
from prefect.utilities.names import generate_slug

R = TypeVar("R")
T = TypeVar("T")


class SetStateStatus(AutoEnum):
@@ -120,7 +120,7 @@ class HistoryResponse(PrefectBaseModel):
interval_end: DateTime = Field(
default=..., description="The end date of the interval."
)
states: List[HistoryResponseState] = Field(
states: list[HistoryResponseState] = Field(
default=..., description="A list of state histories during the interval."
)

@@ -130,18 +130,18 @@ class HistoryResponse(PrefectBaseModel):
]


class OrchestrationResult(PrefectBaseModel):
class OrchestrationResult(PrefectBaseModel, Generic[T]):
"""
A container for the output of state orchestration.
"""

state: Optional[objects.State]
state: Optional[objects.State[T]]
status: SetStateStatus
details: StateResponseDetails


class WorkerFlowRunResponse(PrefectBaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

work_pool_id: UUID
work_queue_id: UUID
@@ -179,7 +179,7 @@ class FlowRunResponse(ObjectBaseModel):
description="The version of the flow executed in this flow run.",
examples=["1.0"],
)
parameters: Dict[str, Any] = Field(
parameters: dict[str, Any] = Field(
default_factory=dict, description="Parameters for the flow run."
)
idempotency_key: Optional[str] = Field(
@@ -189,15 +189,15 @@ class FlowRunResponse(ObjectBaseModel):
" run is not created multiple times."
),
)
context: Dict[str, Any] = Field(
context: dict[str, Any] = Field(
default_factory=dict,
description="Additional context for the flow run.",
examples=[{"my_var": "my_val"}],
)
empirical_policy: objects.FlowRunPolicy = Field(
default_factory=objects.FlowRunPolicy,
)
tags: List[str] = Field(
tags: list[str] = Field(
default_factory=list,
description="A list of tags on the flow run",
examples=[["tag-1", "tag-2"]],
@@ -275,7 +275,7 @@ class FlowRunResponse(ObjectBaseModel):
description="The state of the flow run.",
examples=["objects.State(type=objects.StateType.COMPLETED)"],
)
job_variables: Optional[dict] = Field(
job_variables: Optional[dict[str, Any]] = Field(
default=None, description="Job variables for the flow run."
)

@@ -335,22 +335,22 @@ class DeploymentResponse(ObjectBaseModel):
default=None,
description="The concurrency options for the deployment.",
)
schedules: List[objects.DeploymentSchedule] = Field(
schedules: list[objects.DeploymentSchedule] = Field(
default_factory=list, description="A list of schedules for the deployment."
)
job_variables: Dict[str, Any] = Field(
job_variables: dict[str, Any] = Field(
default_factory=dict,
description="Overrides to apply to flow run infrastructure at runtime.",
)
parameters: Dict[str, Any] = Field(
parameters: dict[str, Any] = Field(
default_factory=dict,
description="Parameters for flow runs scheduled by the deployment.",
)
pull_steps: Optional[List[dict]] = Field(
pull_steps: Optional[list[dict[str, Any]]] = Field(
default=None,
description="Pull steps for cloning and running this deployment.",
)
tags: List[str] = Field(
tags: list[str] = Field(
default_factory=list,
description="A list of tags for the deployment",
examples=[["tag-1", "tag-2"]],
@@ -367,7 +367,7 @@ class DeploymentResponse(ObjectBaseModel):
default=None,
description="The last time the deployment was polled for status updates.",
)
parameter_openapi_schema: Optional[Dict[str, Any]] = Field(
parameter_openapi_schema: Optional[dict[str, Any]] = Field(
default=None,
description="The parameter schema of the flow, including defaults.",
)
@@ -400,7 +400,7 @@ class DeploymentResponse(ObjectBaseModel):
default=None,
description="Optional information about the updater of this deployment.",
)
work_queue_id: UUID = Field(
work_queue_id: Optional[UUID] = Field(
default=None,
description=(
"The id of the work pool queue to which this deployment is assigned."
@@ -423,7 +423,7 @@ class DeploymentResponse(ObjectBaseModel):


class MinimalConcurrencyLimitResponse(PrefectBaseModel):
model_config = ConfigDict(extra="ignore")
model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore")

id: UUID
name: str
229 changes: 136 additions & 93 deletions src/prefect/client/schemas/schedules.py
Original file line number Diff line number Diff line change
@@ -3,13 +3,13 @@
"""

import datetime
from typing import Annotated, Any, Optional, Union
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Optional, Union

import dateutil
import dateutil.rrule
import dateutil.tz
import pendulum
from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator
from pydantic_extra_types.pendulum_dt import DateTime
from typing_extensions import TypeAlias, TypeGuard

from prefect._internal.schemas.bases import PrefectBaseModel
@@ -20,6 +20,14 @@
validate_rrule_string,
)

if TYPE_CHECKING:
# type checkers have difficulty accepting that
# pydantic_extra_types.pendulum_dt and pendulum.DateTime can be used
# together.
DateTime = pendulum.DateTime
else:
from pydantic_extra_types.pendulum_dt import DateTime

MAX_ITERATIONS = 1000
# approx. 1 years worth of RDATEs + buffer
MAX_RRULE_LENGTH = 6500
@@ -54,7 +62,7 @@ class IntervalSchedule(PrefectBaseModel):
timezone (str, optional): a valid timezone string
"""

model_config = ConfigDict(extra="forbid", exclude_none=True)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

interval: datetime.timedelta = Field(gt=datetime.timedelta(0))
anchor_date: Annotated[DateTime, AfterValidator(default_anchor_date)] = Field(
@@ -68,6 +76,19 @@ def validate_timezone(self):
self.timezone = default_timezone(self.timezone, self.model_dump())
return self

if TYPE_CHECKING:
# The model accepts str or datetime values for `anchor_date`
def __init__(
self,
/,
interval: datetime.timedelta,
anchor_date: Optional[
Union[pendulum.DateTime, datetime.datetime, str]
] = None,
timezone: Optional[str] = None,
) -> None:
...


class CronSchedule(PrefectBaseModel):
"""
@@ -94,7 +115,7 @@ class CronSchedule(PrefectBaseModel):
"""

model_config = ConfigDict(extra="forbid")
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

cron: str = Field(default=..., examples=["0 0 * * *"])
timezone: Optional[str] = Field(default=None, examples=["America/New_York"])
@@ -107,18 +128,36 @@ class CronSchedule(PrefectBaseModel):

@field_validator("timezone")
@classmethod
def valid_timezone(cls, v):
def valid_timezone(cls, v: Optional[str]) -> str:
return default_timezone(v)

@field_validator("cron")
@classmethod
def valid_cron_string(cls, v):
def valid_cron_string(cls, v: str) -> str:
return validate_cron_string(v)


DEFAULT_ANCHOR_DATE = pendulum.date(2020, 1, 1)


def _rrule_dt(
rrule: dateutil.rrule.rrule, name: str = "_dtstart"
) -> Optional[datetime.datetime]:
return getattr(rrule, name, None)


def _rrule(
rruleset: dateutil.rrule.rruleset, name: str = "_rrule"
) -> list[dateutil.rrule.rrule]:
return getattr(rruleset, name, [])


def _rdates(
rrule: dateutil.rrule.rruleset, name: str = "_rdate"
) -> list[datetime.datetime]:
return getattr(rrule, name, [])


class RRuleSchedule(PrefectBaseModel):
"""
RRule schedule, based on the iCalendar standard
@@ -139,7 +178,7 @@ class RRuleSchedule(PrefectBaseModel):
timezone (str, optional): a valid timezone string
"""

model_config = ConfigDict(extra="forbid")
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

rrule: str
timezone: Optional[str] = Field(
@@ -148,58 +187,60 @@ class RRuleSchedule(PrefectBaseModel):

@field_validator("rrule")
@classmethod
def validate_rrule_str(cls, v):
def validate_rrule_str(cls, v: str) -> str:
return validate_rrule_string(v)

@classmethod
def from_rrule(cls, rrule: dateutil.rrule.rrule):
def from_rrule(
cls, rrule: Union[dateutil.rrule.rrule, dateutil.rrule.rruleset]
) -> "RRuleSchedule":
if isinstance(rrule, dateutil.rrule.rrule):
if rrule._dtstart.tzinfo is not None:
timezone = rrule._dtstart.tzinfo.name
dtstart = _rrule_dt(rrule)
if dtstart and dtstart.tzinfo is not None:
timezone = dtstart.tzinfo.tzname(dtstart)
else:
timezone = "UTC"
return RRuleSchedule(rrule=str(rrule), timezone=timezone)
elif isinstance(rrule, dateutil.rrule.rruleset):
dtstarts = [rr._dtstart for rr in rrule._rrule if rr._dtstart is not None]
unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts)
unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None)

if len(unique_timezones) > 1:
raise ValueError(
f"rruleset has too many dtstart timezones: {unique_timezones}"
)

if len(unique_dstarts) > 1:
raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}")

if unique_dstarts and unique_timezones:
timezone = dtstarts[0].tzinfo.name
else:
timezone = "UTC"

rruleset_string = ""
if rrule._rrule:
rruleset_string += "\n".join(str(r) for r in rrule._rrule)
if rrule._exrule:
rruleset_string += "\n" if rruleset_string else ""
rruleset_string += "\n".join(str(r) for r in rrule._exrule).replace(
"RRULE", "EXRULE"
)
if rrule._rdate:
rruleset_string += "\n" if rruleset_string else ""
rruleset_string += "RDATE:" + ",".join(
rd.strftime("%Y%m%dT%H%M%SZ") for rd in rrule._rdate
)
if rrule._exdate:
rruleset_string += "\n" if rruleset_string else ""
rruleset_string += "EXDATE:" + ",".join(
exd.strftime("%Y%m%dT%H%M%SZ") for exd in rrule._exdate
)
return RRuleSchedule(rrule=rruleset_string, timezone=timezone)
rrules = _rrule(rrule)
dtstarts = [dts for rr in rrules if (dts := _rrule_dt(rr)) is not None]
unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts)
unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None)

if len(unique_timezones) > 1:
raise ValueError(
f"rruleset has too many dtstart timezones: {unique_timezones}"
)

if len(unique_dstarts) > 1:
raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}")

if unique_dstarts and unique_timezones:
[unique_tz] = unique_timezones
timezone = unique_tz.tzname(dtstarts[0])
else:
raise ValueError(f"Invalid RRule object: {rrule}")

def to_rrule(self) -> dateutil.rrule.rrule:
timezone = "UTC"

rruleset_string = ""
if rrules:
rruleset_string += "\n".join(str(r) for r in rrules)
if exrule := _rrule(rrule, "_exrule"):
rruleset_string += "\n" if rruleset_string else ""
rruleset_string += "\n".join(str(r) for r in exrule).replace(
"RRULE", "EXRULE"
)
if rdates := _rdates(rrule):
rruleset_string += "\n" if rruleset_string else ""
rruleset_string += "RDATE:" + ",".join(
rd.strftime("%Y%m%dT%H%M%SZ") for rd in rdates
)
if exdates := _rdates(rrule, "_exdate"):
rruleset_string += "\n" if rruleset_string else ""
rruleset_string += "EXDATE:" + ",".join(
exd.strftime("%Y%m%dT%H%M%SZ") for exd in exdates
)
return RRuleSchedule(rrule=rruleset_string, timezone=timezone)

def to_rrule(self) -> Union[dateutil.rrule.rrule, dateutil.rrule.rruleset]:
"""
Since rrule doesn't properly serialize/deserialize timezones, we localize dates
here
@@ -211,51 +252,53 @@ def to_rrule(self) -> dateutil.rrule.rrule:
)
timezone = dateutil.tz.gettz(self.timezone)
if isinstance(rrule, dateutil.rrule.rrule):
kwargs = dict(dtstart=rrule._dtstart.replace(tzinfo=timezone))
if rrule._until:
dtstart = _rrule_dt(rrule)
assert dtstart is not None
kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
if until := _rrule_dt(rrule, "_until"):
kwargs.update(
until=rrule._until.replace(tzinfo=timezone),
until=until.replace(tzinfo=timezone),
)
return rrule.replace(**kwargs)
elif isinstance(rrule, dateutil.rrule.rruleset):
# update rrules
localized_rrules = []
for rr in rrule._rrule:
kwargs = dict(dtstart=rr._dtstart.replace(tzinfo=timezone))
if rr._until:
kwargs.update(
until=rr._until.replace(tzinfo=timezone),
)
localized_rrules.append(rr.replace(**kwargs))
rrule._rrule = localized_rrules

# update exrules
localized_exrules = []
for exr in rrule._exrule:
kwargs = dict(dtstart=exr._dtstart.replace(tzinfo=timezone))
if exr._until:
kwargs.update(
until=exr._until.replace(tzinfo=timezone),
)
localized_exrules.append(exr.replace(**kwargs))
rrule._exrule = localized_exrules

# update rdates
localized_rdates = []
for rd in rrule._rdate:
localized_rdates.append(rd.replace(tzinfo=timezone))
rrule._rdate = localized_rdates

# update exdates
localized_exdates = []
for exd in rrule._exdate:
localized_exdates.append(exd.replace(tzinfo=timezone))
rrule._exdate = localized_exdates

return rrule

# update rrules
localized_rrules: list[dateutil.rrule.rrule] = []
for rr in _rrule(rrule):
dtstart = _rrule_dt(rr)
assert dtstart is not None
kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
if until := _rrule_dt(rr, "_until"):
kwargs.update(until=until.replace(tzinfo=timezone))
localized_rrules.append(rr.replace(**kwargs))
setattr(rrule, "_rrule", localized_rrules)

# update exrules
localized_exrules: list[dateutil.rrule.rruleset] = []
for exr in _rrule(rrule, "_exrule"):
dtstart = _rrule_dt(exr)
assert dtstart is not None
kwargs = dict(dtstart=dtstart.replace(tzinfo=timezone))
if until := _rrule_dt(exr, "_until"):
kwargs.update(until=until.replace(tzinfo=timezone))
localized_exrules.append(exr.replace(**kwargs))
setattr(rrule, "_exrule", localized_exrules)

# update rdates
localized_rdates: list[datetime.datetime] = []
for rd in _rdates(rrule):
localized_rdates.append(rd.replace(tzinfo=timezone))
setattr(rrule, "_rdate", localized_rdates)

# update exdates
localized_exdates: list[datetime.datetime] = []
for exd in _rdates(rrule, "_exdate"):
localized_exdates.append(exd.replace(tzinfo=timezone))
setattr(rrule, "_exdate", localized_exdates)

return rrule

@field_validator("timezone")
def valid_timezone(cls, v):
def valid_timezone(cls, v: Optional[str]) -> str:
"""
Validate that the provided timezone is a valid IANA timezone.
@@ -277,7 +320,7 @@ def valid_timezone(cls, v):


class NoSchedule(PrefectBaseModel):
model_config = ConfigDict(extra="forbid")
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")


SCHEDULE_TYPES: TypeAlias = Union[
@@ -326,7 +369,7 @@ def construct_schedule(
if isinstance(interval, (int, float)):
interval = datetime.timedelta(seconds=interval)
if not anchor_date:
anchor_date = DateTime.now()
anchor_date = pendulum.DateTime.now()
schedule = IntervalSchedule(
interval=interval, anchor_date=anchor_date, timezone=timezone
)
16 changes: 9 additions & 7 deletions src/prefect/client/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar
from collections.abc import Iterable
from logging import Logger
from typing import Any, Generic, Optional, TypeVar

import orjson
import websockets
@@ -11,15 +13,15 @@
from prefect.logging import get_logger
from prefect.settings import PREFECT_API_KEY

logger = get_logger(__name__)
logger: Logger = get_logger(__name__)

S = TypeVar("S", bound=IDBaseModel)


class Subscription(Generic[S]):
def __init__(
self,
model: Type[S],
model: type[S],
path: str,
keys: Iterable[str],
client_id: Optional[str] = None,
@@ -28,9 +30,9 @@ def __init__(
self.model = model
self.client_id = client_id
base_url = base_url.replace("http", "ws", 1) if base_url else None
self.subscription_url = f"{base_url}{path}"
self.subscription_url: str = f"{base_url}{path}"

self.keys = list(keys)
self.keys: list[str] = list(keys)

self._connect = websockets.connect(
self.subscription_url,
@@ -78,10 +80,10 @@ async def _ensure_connected(self):
).decode()
)

auth: Dict[str, Any] = orjson.loads(await websocket.recv())
auth: dict[str, Any] = orjson.loads(await websocket.recv())
assert auth["type"] == "auth_success", auth.get("message")

message = {"type": "subscribe", "keys": self.keys}
message: dict[str, Any] = {"type": "subscribe", "keys": self.keys}
if self.client_id:
message.update({"client_id": self.client_id})

72 changes: 32 additions & 40 deletions src/prefect/client/utilities.py
Original file line number Diff line number Diff line change
@@ -5,32 +5,31 @@
# This module must not import from `prefect.client` when it is imported to avoid
# circular imports for decorators such as `inject_client` which are widely used.

from collections.abc import Awaitable, Coroutine
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Coroutine,
Optional,
Tuple,
TypeVar,
Union,
cast,
)

from typing_extensions import Concatenate, ParamSpec
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar

if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient, SyncPrefectClient

P = ParamSpec("P")
R = TypeVar("R")
R = TypeVar("R", infer_variance=True)


def _current_async_client(
client: Union["PrefectClient", "SyncPrefectClient"],
) -> TypeIs["PrefectClient"]:
from prefect._internal.concurrency.event_loop import get_running_loop

# Only a PrefectClient will have a _loop attribute that is the current loop
return getattr(client, "_loop", None) == get_running_loop()


def get_or_create_client(
client: Optional["PrefectClient"] = None,
) -> Tuple[Union["PrefectClient", "SyncPrefectClient"], bool]:
) -> tuple["PrefectClient", bool]:
"""
Returns provided client, infers a client from context if available, or creates a new client.
@@ -42,29 +41,22 @@ def get_or_create_client(
"""
if client is not None:
return client, True
from prefect._internal.concurrency.event_loop import get_running_loop

from prefect.context import AsyncClientContext, FlowRunContext, TaskRunContext

async_client_context = AsyncClientContext.get()
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if async_client_context and async_client_context.client._loop == get_running_loop(): # type: ignore[reportPrivateUsage]
return async_client_context.client, True
elif (
flow_run_context
and getattr(flow_run_context.client, "_loop", None) == get_running_loop()
):
return flow_run_context.client, True
elif (
task_run_context
and getattr(task_run_context.client, "_loop", None) == get_running_loop()
):
return task_run_context.client, True
else:
from prefect.client.orchestration import get_client as get_httpx_client
for context in (async_client_context, flow_run_context, task_run_context):
if context is None:
continue
if _current_async_client(context_client := context.client):
return context_client, True

from prefect.client.orchestration import get_client as get_httpx_client

return get_httpx_client(), False
return get_httpx_client(), False


def client_injector(
@@ -73,7 +65,7 @@ def client_injector(
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
client, _ = get_or_create_client()
return await func(cast("PrefectClient", client), *args, **kwargs)
return await func(client, *args, **kwargs)

return wrapper

@@ -91,18 +83,18 @@ def inject_client(

@wraps(fn)
async def with_injected_client(*args: P.args, **kwargs: P.kwargs) -> R:
client, inferred = get_or_create_client(
cast(Optional["PrefectClient"], kwargs.pop("client", None))
)
_client = cast("PrefectClient", client)
given = kwargs.pop("client", None)
if TYPE_CHECKING:
assert given is None or isinstance(given, PrefectClient)
client, inferred = get_or_create_client(given)
if not inferred:
context = _client
context = client
else:
from prefect.utilities.asyncutils import asyncnullcontext

context = asyncnullcontext()
context = asyncnullcontext(client)
async with context as new_client:
kwargs.setdefault("client", new_client or _client)
kwargs |= {"client": new_client}
return await fn(*args, **kwargs)

return with_injected_client
33 changes: 12 additions & 21 deletions src/prefect/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Import user-facing API
from typing import Any

from prefect.deployments import deploy
from prefect.states import State
from prefect.logging import get_run_logger
@@ -25,28 +27,17 @@
# Perform any forward-ref updates needed for Pydantic models
import prefect.client.schemas

prefect.context.FlowRunContext.model_rebuild(
_types_namespace={
"Flow": Flow,
"BaseResult": BaseResult,
"ResultRecordMetadata": ResultRecordMetadata,
}
)
prefect.context.TaskRunContext.model_rebuild(
_types_namespace={"Task": Task, "BaseResult": BaseResult}
)
prefect.client.schemas.State.model_rebuild(
_types_namespace={
"BaseResult": BaseResult,
"ResultRecordMetadata": ResultRecordMetadata,
}
)
prefect.client.schemas.StateCreate.model_rebuild(
_types_namespace={
"BaseResult": BaseResult,
"ResultRecordMetadata": ResultRecordMetadata,
}
_types: dict[str, Any] = dict(
Task=Task,
Flow=Flow,
BaseResult=BaseResult,
ResultRecordMetadata=ResultRecordMetadata,
)
prefect.context.FlowRunContext.model_rebuild(_types_namespace=_types)
prefect.context.TaskRunContext.model_rebuild(_types_namespace=_types)
prefect.client.schemas.State.model_rebuild(_types_namespace=_types)
prefect.client.schemas.StateCreate.model_rebuild(_types_namespace=_types)
prefect.client.schemas.OrchestrationResult.model_rebuild(_types_namespace=_types)
Transaction.model_rebuild()

# Configure logging

0 comments on commit 24c7f08

Please sign in to comment.