Skip to content

Commit

Permalink
Merge pull request #215 from kpn/fix/consumer-record
Browse files Browse the repository at this point in the history
fix: move consumer record to kstreams
  • Loading branch information
woile authored Oct 16, 2024
2 parents 8623734 + 06cb28f commit 84384dd
Show file tree
Hide file tree
Showing 12 changed files with 261 additions and 96 deletions.
4 changes: 2 additions & 2 deletions kstreams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from aiokafka.structs import ConsumerRecord, RecordMetadata, TopicPartition
from aiokafka.structs import RecordMetadata, TopicPartition

from .clients import Consumer, Producer
from .create import StreamEngine, create_engine
Expand All @@ -11,7 +11,7 @@
from .streams import Stream, stream
from .structs import TopicPartitionOffset
from .test_utils import TestStreamClient
from .types import Send
from .types import ConsumerRecord, Send

__all__ = [
"Consumer",
Expand Down
10 changes: 6 additions & 4 deletions kstreams/middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from aiokafka import errors

from kstreams import ConsumerRecord, types
from kstreams import types
from kstreams.streams_utils import StreamErrorPolicy

if typing.TYPE_CHECKING:
Expand All @@ -25,7 +25,9 @@ def __init__(
**kwargs: typing.Any,
) -> None: ... # pragma: no cover

async def __call__(self, cr: ConsumerRecord) -> typing.Any: ... # pragma: no cover
async def __call__(
self, cr: types.ConsumerRecord
) -> typing.Any: ... # pragma: no cover


class Middleware:
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(
self.send = send
self.stream = stream

async def __call__(self, cr: ConsumerRecord) -> typing.Any:
async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
raise NotImplementedError


Expand All @@ -76,7 +78,7 @@ def __init__(
self.engine = engine
self.error_policy = error_policy

async def __call__(self, cr: ConsumerRecord) -> typing.Any:
async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
try:
return await self.next_call(cr)
except errors.ConsumerStoppedError as exc:
Expand Down
8 changes: 4 additions & 4 deletions kstreams/middleware/udf_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import typing

from kstreams import ConsumerRecord, types
from kstreams import types
from kstreams.streams import Stream
from kstreams.streams_utils import UDFType, setup_type

Expand All @@ -21,18 +21,18 @@ def __init__(self, *args, **kwargs) -> None:
self.params = list(signature.parameters.values())
self.type: UDFType = setup_type(self.params)

def bind_udf_params(self, cr: ConsumerRecord) -> typing.List:
def bind_udf_params(self, cr: types.ConsumerRecord) -> typing.List:
# NOTE: When `no typing` support is deprecated then this can
# be more efficient as the CR will always be there.
ANNOTATIONS_TO_PARAMS = {
ConsumerRecord: cr,
types.ConsumerRecord: cr,
Stream: self.stream,
types.Send: self.send,
}

return [ANNOTATIONS_TO_PARAMS[param.annotation] for param in self.params]

async def __call__(self, cr: ConsumerRecord) -> typing.Any:
async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
"""
Call the coroutine `async def my_function(...)` defined by the end user
in a proper way according to its parameters. The `handler` is the
Expand Down
4 changes: 1 addition & 3 deletions kstreams/serializers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Any, Dict, Optional, Protocol

from kstreams import ConsumerRecord

from .types import Headers
from .types import ConsumerRecord, Headers


class Deserializer(Protocol):
Expand Down
6 changes: 3 additions & 3 deletions kstreams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from aiokafka import errors

from kstreams import ConsumerRecord, TopicPartition
from kstreams import TopicPartition
from kstreams.exceptions import BackendNotSet
from kstreams.middleware.middleware import ExceptionMiddleware
from kstreams.structs import TopicPartitionOffset
Expand All @@ -19,7 +19,7 @@
from .rebalance_listener import RebalanceListener
from .serializers import Deserializer
from .streams_utils import StreamErrorPolicy, UDFType
from .types import Deprecated, StreamFunc
from .types import ConsumerRecord, Deprecated, StreamFunc

if typing.TYPE_CHECKING:
from kstreams import StreamEngine
Expand Down Expand Up @@ -352,7 +352,7 @@ async def start(self) -> None:
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)

func = self.udf_handler.next_call(self)
func = self.udf_handler.next_call(self) # type: ignore
await func
else:
# Typing cases
Expand Down
28 changes: 17 additions & 11 deletions kstreams/test_utils/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Any, Coroutine, Dict, List, Optional, Sequence, Set
from typing import Any, Coroutine, Dict, List, Optional, Sequence, Set, Union

from kstreams import ConsumerRecord, RebalanceListener, TopicPartition
from kstreams import RebalanceListener, TopicPartition
from kstreams.clients import Consumer, Producer
from kstreams.serializers import Serializer
from kstreams.types import Headers
from kstreams.types import ConsumerRecord, EncodedHeaders, Headers

from .structs import RecordMetadata
from .topics import TopicManager
Expand All @@ -23,28 +23,34 @@ async def send(
value: Any = None,
key: Any = None,
partition: int = 0,
timestamp_ms: Optional[float] = None,
timestamp_ms: Optional[int] = None,
headers: Optional[Headers] = None,
serializer: Optional[Serializer] = None,
serializer_kwargs: Optional[Dict] = None,
) -> Coroutine:
topic, _ = TopicManager.get_or_create(topic_name)
timestamp_ms = timestamp_ms or datetime.now().timestamp()
timestamp_ms = timestamp_ms or datetime.now().toordinal()
total_partition_events = topic.offset(partition=partition)
partition = partition or 0

consumer_record = ConsumerRecord(
_headers: EncodedHeaders = []
if isinstance(headers, dict):
_headers = [(key, value.encode()) for key, value in headers.items()]

serialized_key_size = -1 if key is None else len(key)
serialized_value_size = -1 if value is None else len(value)
consumer_record: ConsumerRecord = ConsumerRecord(
topic=topic_name,
value=value,
key=key,
headers=headers,
headers=_headers,
partition=partition,
timestamp=timestamp_ms,
offset=total_partition_events + 1,
timestamp_type=None,
timestamp_type=0,
checksum=None,
serialized_key_size=None,
serialized_value_size=None,
serialized_key_size=serialized_key_size,
serialized_value_size=serialized_value_size,
)

await topic.put(consumer_record)
Expand Down Expand Up @@ -204,7 +210,7 @@ async def getmany(
*partitions: List[TopicPartition],
timeout_ms: int = 0,
max_records: int = 1,
) -> Dict[TopicPartition, List[ConsumerRecord]]:
) -> Dict[TopicPartition, List[Union[ConsumerRecord, None]]]:
"""
Basic getmany implementation.
`partitions` and `timeout_ms` could be added to the logic
Expand Down
4 changes: 2 additions & 2 deletions kstreams/test_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from types import TracebackType
from typing import Any, Dict, List, Optional, Type

from kstreams import Consumer, ConsumerRecord, Producer
from kstreams import Consumer, Producer
from kstreams.engine import StreamEngine
from kstreams.prometheus.monitor import PrometheusMonitor
from kstreams.serializers import Serializer
from kstreams.streams import Stream
from kstreams.types import Headers
from kstreams.types import ConsumerRecord, Headers

from .structs import RecordMetadata
from .test_clients import TestConsumer, TestProducer
Expand Down
2 changes: 1 addition & 1 deletion kstreams/test_utils/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from typing import ClassVar, DefaultDict, Dict, Optional, Sequence, Tuple

from kstreams import ConsumerRecord
from kstreams.types import ConsumerRecord

from . import test_clients

Expand Down
52 changes: 50 additions & 2 deletions kstreams/types.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import typing
from dataclasses import dataclass

from kstreams import ConsumerRecord, RecordMetadata
from aiokafka.structs import RecordMetadata

if typing.TYPE_CHECKING:
from .serializers import Serializer # pragma: no cover

Headers = typing.Dict[str, str]
EncodedHeaders = typing.Sequence[typing.Tuple[str, bytes]]
StreamFunc = typing.Callable
NextMiddlewareCall = typing.Callable[[ConsumerRecord], typing.Awaitable[None]]

EngineHooks = typing.Sequence[typing.Callable[[], typing.Any]]


Expand All @@ -28,3 +29,50 @@ def __call__(

D = typing.TypeVar("D")
Deprecated = typing.Annotated[D, "deprecated"]

KT = typing.TypeVar("KT")
VT = typing.TypeVar("VT")


@dataclass
class ConsumerRecord(typing.Generic[KT, VT]):
    """
    A single record consumed from a Kafka topic.

    The class is generic over the key type (`KT`) and the value type (`VT`);
    both `key` and `value` may be `None`.
    """

    topic: str
    "Name of the topic this record was received from."

    partition: int
    "Partition this record was received from."

    offset: int
    "Position of this record within its Kafka partition."

    timestamp: int
    "Timestamp of this record."

    timestamp_type: int
    "Timestamp type of this record."

    key: typing.Optional[KT]
    "Record key, or `None` when no key was specified."

    value: typing.Optional[VT]
    "Record value."

    checksum: typing.Optional[int]
    "Deprecated."

    serialized_key_size: int
    "Size, in bytes, of the serialized, uncompressed key."

    serialized_value_size: int
    "Size, in bytes, of the serialized, uncompressed value."

    headers: EncodedHeaders
    "Headers attached to this record."


NextMiddlewareCall = typing.Callable[[ConsumerRecord], typing.Awaitable[None]]

# 0 for CreateTime; 1 for LogAppendTime;
# aiokafka also supports None, which means it's unsupported, but
# we only support messages > 1
TimestampType = typing.Literal[0, 1]
Loading

0 comments on commit 84384dd

Please sign in to comment.