Skip to content

Commit

Permalink
counting and clients
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Jan 14, 2025
1 parent 2cb7eaf commit a421e51
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
23 changes: 18 additions & 5 deletions src/prefect/server/events/clients.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import abc
from textwrap import dedent
from types import TracebackType
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from uuid import UUID

import httpx
Expand All @@ -16,7 +26,10 @@
ResourceSpecification,
)

logger = get_logger(__name__)
if TYPE_CHECKING:
import logging

logger: "logging.Logger" = get_logger(__name__)

LabelValue: TypeAlias = Union[str, List[str]]

Expand Down Expand Up @@ -111,7 +124,7 @@ def assert_emitted_event_with(
resource: Optional[Dict[str, LabelValue]] = None,
related: Optional[List[Dict[str, LabelValue]]] = None,
payload: Optional[Dict[str, Any]] = None,
):
) -> None:
"""Assert that an event was emitted containing the given properties."""
assert cls.last is not None and cls.all, "No event client was created"

Expand Down Expand Up @@ -185,7 +198,7 @@ def assert_no_emitted_event_with(
resource: Optional[Dict[str, LabelValue]] = None,
related: Optional[List[Dict[str, LabelValue]]] = None,
payload: Optional[Dict[str, Any]] = None,
):
) -> None:
try:
cls.assert_emitted_event_with(event, resource, related, payload)
except AssertionError:
Expand Down Expand Up @@ -218,7 +231,7 @@ async def emit(self, event: Event) -> ReceivedEvent:
"context manager"
)
received_event = event.receive()
await self._publisher.publish_event(event)
await self._publisher.publish_event(received_event)
return received_event


Expand Down
16 changes: 9 additions & 7 deletions src/prefect/server/events/counting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generator

import pendulum
import sqlalchemy as sa
Expand Down Expand Up @@ -34,7 +34,7 @@ class TimeUnit(AutoEnum):
minute = AutoEnum.auto()
second = AutoEnum.auto()

def as_timedelta(self, interval) -> pendulum.Duration:
def as_timedelta(self, interval: float) -> pendulum.Duration:
if self == self.week:
return pendulum.Duration(days=7 * interval)
elif self == self.day:
Expand All @@ -50,7 +50,7 @@ def as_timedelta(self, interval) -> pendulum.Duration:

def validate_buckets(
self, start_datetime: DateTime, end_datetime: DateTime, interval: float
):
) -> None:
MAX_ALLOWED_BUCKETS = 1000

delta = self.as_timedelta(interval)
Expand All @@ -73,7 +73,7 @@ def get_interval_spans(
start_datetime: DateTime,
end_datetime: DateTime,
interval: float,
):
) -> Generator[int | tuple[pendulum.DateTime, pendulum.DateTime], None, None]:
"""Divide the given range of dates into evenly-sized spans of interval units"""
self.validate_buckets(start_datetime, end_datetime, interval)

Expand All @@ -100,7 +100,7 @@ def get_interval_spans(
yield (span_start, next_span_start - timedelta(microseconds=1))
span_start = next_span_start

def database_value_expression(self, time_interval: float):
def database_value_expression(self, time_interval: float) -> sa.Cast[str]:
"""Returns the SQL expression to place an event in a time bucket"""
# The date_bin function can do the bucketing for us:
# https://www.postgresql.org/docs/14/functions-datetime.html#FUNCTIONS-DATETIME-BIN
Expand Down Expand Up @@ -135,7 +135,9 @@ def database_value_expression(self, time_interval: float):
else:
raise NotImplementedError(f"Dialect {db.dialect.name} is not supported.")

def database_label_expression(self, db: PrefectDBInterface, time_interval: float):
def database_label_expression(
self, db: PrefectDBInterface, time_interval: float
) -> sa.Function[str]:
"""Returns the SQL expression to label a time bucket"""
time_delta = self.as_timedelta(time_interval)
if db.dialect.name == "postgresql":
Expand Down Expand Up @@ -176,7 +178,7 @@ def get_database_query(
filter: "EventFilter",
time_unit: TimeUnit,
time_interval: float,
) -> Select:
) -> Select[tuple[str, str, DateTime, DateTime, int]]:
db = provide_database_interface()
# The innermost SELECT pulls the matching events and groups them up by their
# buckets. At this point, there may be duplicate buckets for each value, since
Expand Down

0 comments on commit a421e51

Please sign in to comment.