From a86f9006ed7addeedb4e232b0c6671616e526441 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sun, 6 Oct 2024 22:20:00 -0700 Subject: [PATCH] chore: streaming sink Signed-off-by: Sidhant Kohli --- pynumaflow/proto/sinker/sink.proto | 40 ++- pynumaflow/proto/sinker/sink_pb2.py | 38 ++- pynumaflow/proto/sinker/sink_pb2.pyi | 89 +++-- pynumaflow/proto/sinker/sink_pb2_grpc.py | 6 +- pynumaflow/shared/asynciter.py | 4 +- pynumaflow/shared/server.py | 10 + pynumaflow/shared/servicer.py | 3 + pynumaflow/sinker/__init__.py | 5 +- pynumaflow/sinker/_dtypes.py | 20 ++ pynumaflow/sinker/server.py | 240 +++++++------- pynumaflow/sinker/servicer/async_servicer.py | 97 ++++-- pynumaflow/sinker/servicer/sync_servicer.py | 138 ++++---- pynumaflow/sourcer/servicer/async_servicer.py | 24 +- tests/sink/test_async_sink.py | 82 +++-- tests/sink/test_server.py | 312 +++++++++--------- 15 files changed, 623 insertions(+), 485 deletions(-) create mode 100644 pynumaflow/shared/servicer.py diff --git a/pynumaflow/proto/sinker/sink.proto b/pynumaflow/proto/sinker/sink.proto index df599f03..a6ce024a 100644 --- a/pynumaflow/proto/sinker/sink.proto +++ b/pynumaflow/proto/sinker/sink.proto @@ -1,4 +1,5 @@ syntax = "proto3"; + import "google/protobuf/empty.proto"; import "google/protobuf/timestamp.proto"; @@ -7,7 +8,7 @@ package sink.v1; service Sink { // SinkFn writes the request to a user defined sink. - rpc SinkFn(stream SinkRequest) returns (SinkResponse); + rpc SinkFn(stream SinkRequest) returns (stream SinkResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); @@ -17,12 +18,32 @@ service Sink { * SinkRequest represents a request element. 
*/ message SinkRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - string id = 5; - map<string, string> headers = 6; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + string id = 5; + map<string, string> headers = 6; + } + message Status { + bool eot = 1; + } + // Required field indicating the request. + Request request = 1; + // Required field indicating the status of the request. + // If eot is set to true, it indicates the end of transmission. + Status status = 2; + // optional field indicating the handshake message. + optional Handshake handshake = 3; +} + +/* + * Handshake message between client and server to indicate the start of transmission. + */ +message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; } /** @@ -53,5 +74,6 @@ message SinkResponse { // err_msg is the error message, set it if success is set to false. 
string err_msg = 3; } - repeated Result results = 1; -} \ No newline at end of file + Result result = 1; + optional Handshake handshake = 2; +} diff --git a/pynumaflow/proto/sinker/sink_pb2.py b/pynumaflow/proto/sinker/sink_pb2.py index 00b8326e..85053d59 100644 --- a/pynumaflow/proto/sinker/sink_pb2.py +++ b/pynumaflow/proto/sinker/sink_pb2.py @@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\nsink.proto\x12\x07sink.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xf9\x01\n\x0bSinkRequest\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\n\n\x02id\x18\x05 \x01(\t\x12\x32\n\x07headers\x18\x06 \x03(\x0b\x32!.sink.v1.SinkRequest.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"\x85\x01\n\x0cSinkResponse\x12-\n\x07results\x18\x01 \x03(\x0b\x32\x1c.sink.v1.SinkResponse.Result\x1a\x46\n\x06Result\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1f\n\x06status\x18\x02 \x01(\x0e\x32\x0f.sink.v1.Status\x12\x0f\n\x07\x65rr_msg\x18\x03 \x01(\t*0\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\x12\x0c\n\x08\x46\x41LLBACK\x10\x02\x32z\n\x04Sink\x12\x37\n\x06SinkFn\x12\x14.sink.v1.SinkRequest\x1a\x15.sink.v1.SinkResponse(\x01\x12\x39\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x16.sink.v1.ReadyResponseb\x06proto3' + b'\n\nsink.proto\x12\x07sink.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xba\x03\n\x0bSinkRequest\x12-\n\x07request\x18\x01 \x01(\x0b\x32\x1c.sink.v1.SinkRequest.Request\x12+\n\x06status\x18\x02 \x01(\x0b\x32\x1b.sink.v1.SinkRequest.Status\x12*\n\thandshake\x18\x03 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x1a\xfd\x01\n\x07Request\x12\x0c\n\x04keys\x18\x01 
\x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\n\n\x02id\x18\x05 \x01(\t\x12:\n\x07headers\x18\x06 \x03(\x0b\x32).sink.v1.SinkRequest.Request.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x15\n\x06Status\x12\x0b\n\x03\x65ot\x18\x01 \x01(\x08\x42\x0c\n\n_handshake"\x18\n\tHandshake\x12\x0b\n\x03sot\x18\x01 \x01(\x08"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"\xbe\x01\n\x0cSinkResponse\x12,\n\x06result\x18\x01 \x01(\x0b\x32\x1c.sink.v1.SinkResponse.Result\x12*\n\thandshake\x18\x02 \x01(\x0b\x32\x12.sink.v1.HandshakeH\x00\x88\x01\x01\x1a\x46\n\x06Result\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1f\n\x06status\x18\x02 \x01(\x0e\x32\x0f.sink.v1.Status\x12\x0f\n\x07\x65rr_msg\x18\x03 \x01(\tB\x0c\n\n_handshake*0\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07\x46\x41ILURE\x10\x01\x12\x0c\n\x08\x46\x41LLBACK\x10\x02\x32|\n\x04Sink\x12\x39\n\x06SinkFn\x12\x14.sink.v1.SinkRequest\x1a\x15.sink.v1.SinkResponse(\x01\x30\x01\x12\x39\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x16.sink.v1.ReadyResponseb\x06proto3' ) _globals = globals() @@ -26,20 +26,26 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sink_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals["_SINKREQUEST_HEADERSENTRY"]._options = None - _globals["_SINKREQUEST_HEADERSENTRY"]._serialized_options = b"8\001" - _globals["_STATUS"]._serialized_start = 505 - _globals["_STATUS"]._serialized_end = 553 + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._options = None + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._serialized_options = b"8\001" + _globals["_STATUS"]._serialized_start = 781 + _globals["_STATUS"]._serialized_end = 829 _globals["_SINKREQUEST"]._serialized_start = 86 - _globals["_SINKREQUEST"]._serialized_end = 335 - 
_globals["_SINKREQUEST_HEADERSENTRY"]._serialized_start = 289 - _globals["_SINKREQUEST_HEADERSENTRY"]._serialized_end = 335 - _globals["_READYRESPONSE"]._serialized_start = 337 - _globals["_READYRESPONSE"]._serialized_end = 367 - _globals["_SINKRESPONSE"]._serialized_start = 370 - _globals["_SINKRESPONSE"]._serialized_end = 503 - _globals["_SINKRESPONSE_RESULT"]._serialized_start = 433 - _globals["_SINKRESPONSE_RESULT"]._serialized_end = 503 - _globals["_SINK"]._serialized_start = 555 - _globals["_SINK"]._serialized_end = 677 + _globals["_SINKREQUEST"]._serialized_end = 528 + _globals["_SINKREQUEST_REQUEST"]._serialized_start = 238 + _globals["_SINKREQUEST_REQUEST"]._serialized_end = 491 + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._serialized_start = 445 + _globals["_SINKREQUEST_REQUEST_HEADERSENTRY"]._serialized_end = 491 + _globals["_SINKREQUEST_STATUS"]._serialized_start = 493 + _globals["_SINKREQUEST_STATUS"]._serialized_end = 514 + _globals["_HANDSHAKE"]._serialized_start = 530 + _globals["_HANDSHAKE"]._serialized_end = 554 + _globals["_READYRESPONSE"]._serialized_start = 556 + _globals["_READYRESPONSE"]._serialized_end = 586 + _globals["_SINKRESPONSE"]._serialized_start = 589 + _globals["_SINKRESPONSE"]._serialized_end = 779 + _globals["_SINKRESPONSE_RESULT"]._serialized_start = 695 + _globals["_SINKRESPONSE_RESULT"]._serialized_end = 765 + _globals["_SINK"]._serialized_start = 831 + _globals["_SINK"]._serialized_end = 955 # @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/sinker/sink_pb2.pyi b/pynumaflow/proto/sinker/sink_pb2.pyi index 71dcdf69..4ae2e2e3 100644 --- a/pynumaflow/proto/sinker/sink_pb2.pyi +++ b/pynumaflow/proto/sinker/sink_pb2.pyi @@ -25,37 +25,64 @@ FAILURE: Status FALLBACK: Status class SinkRequest(_message.Message): - __slots__ = ("keys", "value", "event_time", "watermark", "id", "headers") + __slots__ = ("request", "status", "handshake") - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - 
KEY_FIELD_NUMBER: _ClassVar[int] + class Request(_message.Message): + __slots__ = ("keys", "value", "event_time", "watermark", "id", "headers") + + class HeadersEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + KEYS_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - EVENT_TIME_FIELD_NUMBER: _ClassVar[int] - WATERMARK_FIELD_NUMBER: _ClassVar[int] - ID_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedScalarFieldContainer[str] - value: bytes - event_time: _timestamp_pb2.Timestamp - watermark: _timestamp_pb2.Timestamp - id: str - headers: _containers.ScalarMap[str, str] + EVENT_TIME_FIELD_NUMBER: _ClassVar[int] + WATERMARK_FIELD_NUMBER: _ClassVar[int] + ID_FIELD_NUMBER: _ClassVar[int] + HEADERS_FIELD_NUMBER: _ClassVar[int] + keys: _containers.RepeatedScalarFieldContainer[str] + value: bytes + event_time: _timestamp_pb2.Timestamp + watermark: _timestamp_pb2.Timestamp + id: str + headers: _containers.ScalarMap[str, str] + def __init__( + self, + keys: _Optional[_Iterable[str]] = ..., + value: _Optional[bytes] = ..., + event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + id: _Optional[str] = ..., + headers: _Optional[_Mapping[str, str]] = ..., + ) -> None: ... + + class Status(_message.Message): + __slots__ = ("eot",) + EOT_FIELD_NUMBER: _ClassVar[int] + eot: bool + def __init__(self, eot: bool = ...) -> None: ... 
+ REQUEST_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + HANDSHAKE_FIELD_NUMBER: _ClassVar[int] + request: SinkRequest.Request + status: SinkRequest.Status + handshake: Handshake def __init__( self, - keys: _Optional[_Iterable[str]] = ..., - value: _Optional[bytes] = ..., - event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., - id: _Optional[str] = ..., - headers: _Optional[_Mapping[str, str]] = ..., + request: _Optional[_Union[SinkRequest.Request, _Mapping]] = ..., + status: _Optional[_Union[SinkRequest.Status, _Mapping]] = ..., + handshake: _Optional[_Union[Handshake, _Mapping]] = ..., ) -> None: ... +class Handshake(_message.Message): + __slots__ = ("sot",) + SOT_FIELD_NUMBER: _ClassVar[int] + sot: bool + def __init__(self, sot: bool = ...) -> None: ... + class ReadyResponse(_message.Message): __slots__ = ("ready",) READY_FIELD_NUMBER: _ClassVar[int] @@ -63,7 +90,7 @@ class ReadyResponse(_message.Message): def __init__(self, ready: bool = ...) -> None: ... class SinkResponse(_message.Message): - __slots__ = ("results",) + __slots__ = ("result", "handshake") class Result(_message.Message): __slots__ = ("id", "status", "err_msg") @@ -79,8 +106,12 @@ class SinkResponse(_message.Message): status: _Optional[_Union[Status, str]] = ..., err_msg: _Optional[str] = ..., ) -> None: ... - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[SinkResponse.Result] + RESULT_FIELD_NUMBER: _ClassVar[int] + HANDSHAKE_FIELD_NUMBER: _ClassVar[int] + result: SinkResponse.Result + handshake: Handshake def __init__( - self, results: _Optional[_Iterable[_Union[SinkResponse.Result, _Mapping]]] = ... + self, + result: _Optional[_Union[SinkResponse.Result, _Mapping]] = ..., + handshake: _Optional[_Union[Handshake, _Mapping]] = ..., ) -> None: ... 
diff --git a/pynumaflow/proto/sinker/sink_pb2_grpc.py b/pynumaflow/proto/sinker/sink_pb2_grpc.py index ef673e9d..9609c76e 100644 --- a/pynumaflow/proto/sinker/sink_pb2_grpc.py +++ b/pynumaflow/proto/sinker/sink_pb2_grpc.py @@ -15,7 +15,7 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.SinkFn = channel.stream_unary( + self.SinkFn = channel.stream_stream( "/sink.v1.Sink/SinkFn", request_serializer=sink__pb2.SinkRequest.SerializeToString, response_deserializer=sink__pb2.SinkResponse.FromString, @@ -45,7 +45,7 @@ def IsReady(self, request, context): def add_SinkServicer_to_server(servicer, server): rpc_method_handlers = { - "SinkFn": grpc.stream_unary_rpc_method_handler( + "SinkFn": grpc.stream_stream_rpc_method_handler( servicer.SinkFn, request_deserializer=sink__pb2.SinkRequest.FromString, response_serializer=sink__pb2.SinkResponse.SerializeToString, @@ -77,7 +77,7 @@ def SinkFn( timeout=None, metadata=None, ): - return grpc.experimental.stream_unary( + return grpc.experimental.stream_stream( request_iterator, target, "/sink.v1.Sink/SinkFn", diff --git a/pynumaflow/shared/asynciter.py b/pynumaflow/shared/asynciter.py index 3ab6135b..91155b93 100644 --- a/pynumaflow/shared/asynciter.py +++ b/pynumaflow/shared/asynciter.py @@ -8,8 +8,8 @@ class NonBlockingIterator: __slots__ = "_queue" - def __init__(self): - self._queue = asyncio.Queue() + def __init__(self, size=0): + self._queue = asyncio.Queue(maxsize=size) async def read_iterator(self): item = await self._queue.get() diff --git a/pynumaflow/shared/server.py b/pynumaflow/shared/server.py index 2e9de168..888cc7cc 100644 --- a/pynumaflow/shared/server.py +++ b/pynumaflow/shared/server.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import io import multiprocessing @@ -278,3 +279,12 @@ def get_exception_traceback_str(exc) -> str: file = io.StringIO() traceback.print_exception(exc, value=exc, tb=exc.__traceback__, file=file) return file.getvalue().rstrip() + + +async def 
handle_exception(context, exception): + """Handle exceptions by updating the context and exiting.""" + handle_error(context, exception) + await asyncio.gather( + context.abort(grpc.StatusCode.UNKNOWN, details=repr(exception)), return_exceptions=True + ) + exit_on_error(err=repr(exception), parent=False, context=context, update_context=False) diff --git a/pynumaflow/shared/servicer.py b/pynumaflow/shared/servicer.py new file mode 100644 index 00000000..988330a1 --- /dev/null +++ b/pynumaflow/shared/servicer.py @@ -0,0 +1,3 @@ +def is_valid_handshake(req): + """Check if the handshake message is valid.""" + return req.handshake and req.handshake.sot diff --git a/pynumaflow/sinker/__init__.py b/pynumaflow/sinker/__init__.py index 4df6f270..edf1f5cf 100644 --- a/pynumaflow/sinker/__init__.py +++ b/pynumaflow/sinker/__init__.py @@ -1,6 +1,7 @@ from pynumaflow.sinker.async_server import SinkAsyncServer -from pynumaflow.sinker.server import SinkServer + +# from pynumaflow.sinker.server import SinkServer from pynumaflow.sinker._dtypes import Response, Responses, Datum, Sinker -__all__ = ["Response", "Responses", "Datum", "Sinker", "SinkServer", "SinkAsyncServer"] +__all__ = ["Response", "Responses", "Datum", "Sinker", "SinkAsyncServer"] diff --git a/pynumaflow/sinker/_dtypes.py b/pynumaflow/sinker/_dtypes.py index 1f436a85..e9154fb4 100644 --- a/pynumaflow/sinker/_dtypes.py +++ b/pynumaflow/sinker/_dtypes.py @@ -215,6 +215,26 @@ def handler(self, datums: Iterator[Datum]) -> Responses: pass +@dataclass +class EndOfStreamTransmission: + """ + Basic datatype for UDSink response. + + Args: + """ + + eos: bool + __slots__ = "eos" + + @classmethod + def as_completed(cls): + return EndOfStreamTransmission(eos=True) + + @classmethod + def as_failure(cls): + return EndOfStreamTransmission(eos=False) + + # SyncSinkCallable is a callable which can be used as a handler for the Synchronous UDSink. 
SinkHandlerCallable = Callable[[Iterator[Datum]], Responses] SyncSinkCallable = Union[Sinker, SinkHandlerCallable] diff --git a/pynumaflow/sinker/server.py b/pynumaflow/sinker/server.py index 10d3dade..d8d58cb4 100644 --- a/pynumaflow/sinker/server.py +++ b/pynumaflow/sinker/server.py @@ -1,120 +1,120 @@ -import os - -from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION -from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer - -from pynumaflow._constants import ( - SINK_SOCK_PATH, - MAX_MESSAGE_SIZE, - NUM_THREADS_DEFAULT, - _LOGGER, - UDFType, - SINK_SERVER_INFO_FILE_PATH, - ENV_UD_CONTAINER_TYPE, - UD_CONTAINER_FALLBACK_SINK, - FALLBACK_SINK_SOCK_PATH, - FALLBACK_SINK_SERVER_INFO_FILE_PATH, - MAX_NUM_THREADS, -) - -from pynumaflow.shared.server import NumaflowServer, sync_server_start -from pynumaflow.sinker._dtypes import SyncSinkCallable - - -class SinkServer(NumaflowServer): - """ - SinkServer is the main class to start a gRPC server for a sinker. - """ - - def __init__( - self, - sinker_instance: SyncSinkCallable, - sock_path=SINK_SOCK_PATH, - max_message_size=MAX_MESSAGE_SIZE, - max_threads=NUM_THREADS_DEFAULT, - server_info_file=SINK_SERVER_INFO_FILE_PATH, - ): - """ - Create a new grpc Sink Server instance. - A new servicer instance is created and attached to the server. - The server instance is returned. 
- Args: - sinker_instance: The sinker instance to be used for Sink UDF - sock_path: The UNIX socket path to be used for the server - max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to 4 and max capped at 16 - Example invocation: - import os - from collections.abc import Iterator - - from pynumaflow.sinker import Datum, Responses, Response, SinkServer - from pynumaflow.sinker import Sinker - from pynumaflow._constants import _LOGGER - - class UserDefinedSink(Sinker): - def handler(self, datums: Iterator[Datum]) -> Responses: - responses = Responses() - for msg in datums: - _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) - responses.append(Response.as_success(msg.id)) - return responses - - def udsink_handler(datums: Iterator[Datum]) -> Responses: - responses = Responses() - for msg in datums: - _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) - responses.append(Response.as_success(msg.id)) - return responses - - if __name__ == "__main__": - invoke = os.getenv("INVOKE", "func_handler") - if invoke == "class": - sink_handler = UserDefinedSink() - else: - sink_handler = udsink_handler - grpc_server = SinkServer(sink_handler) - grpc_server.start() - - """ - # If the container type is fallback sink, then use the fallback sink address and path. 
- if os.getenv(ENV_UD_CONTAINER_TYPE, "") == UD_CONTAINER_FALLBACK_SINK: - _LOGGER.info("Using Fallback Sink") - sock_path = FALLBACK_SINK_SOCK_PATH - server_info_file = FALLBACK_SINK_SERVER_INFO_FILE_PATH - - self.sock_path = f"unix://{sock_path}" - self.max_threads = min(max_threads, MAX_NUM_THREADS) - self.max_message_size = max_message_size - self.server_info_file = server_info_file - - self.sinker_instance = sinker_instance - - self._server_options = [ - ("grpc.max_send_message_length", self.max_message_size), - ("grpc.max_receive_message_length", self.max_message_size), - ] - self.servicer = SyncSinkServicer(sinker_instance) - - def start(self): - """ - Starts the Synchronous gRPC server on the - given UNIX socket with given max threads. - """ - _LOGGER.info( - "Sync GRPC Sink listening on: %s with max threads: %s", - self.sock_path, - self.max_threads, - ) - serv_info = ServerInfo.get_default_server_info() - serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sinker] - # Start the server - sync_server_start( - servicer=self.servicer, - bind_address=self.sock_path, - max_threads=self.max_threads, - server_info_file=self.server_info_file, - server_options=self._server_options, - udf_type=UDFType.Sink, - server_info=serv_info, - ) +# import os +# +# from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION +# from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer +# +# from pynumaflow._constants import ( +# SINK_SOCK_PATH, +# MAX_MESSAGE_SIZE, +# NUM_THREADS_DEFAULT, +# _LOGGER, +# UDFType, +# SINK_SERVER_INFO_FILE_PATH, +# ENV_UD_CONTAINER_TYPE, +# UD_CONTAINER_FALLBACK_SINK, +# FALLBACK_SINK_SOCK_PATH, +# FALLBACK_SINK_SERVER_INFO_FILE_PATH, +# MAX_NUM_THREADS, +# ) +# +# from pynumaflow.shared.server import NumaflowServer, sync_server_start +# from pynumaflow.sinker._dtypes import SyncSinkCallable +# +# +# class SinkServer(NumaflowServer): +# """ +# SinkServer is the main class to start a 
gRPC server for a sinker. +# """ +# +# def __init__( +# self, +# sinker_instance: SyncSinkCallable, +# sock_path=SINK_SOCK_PATH, +# max_message_size=MAX_MESSAGE_SIZE, +# max_threads=NUM_THREADS_DEFAULT, +# server_info_file=SINK_SERVER_INFO_FILE_PATH, +# ): +# """ +# Create a new grpc Sink Server instance. +# A new servicer instance is created and attached to the server. +# The server instance is returned. +# Args: +# sinker_instance: The sinker instance to be used for Sink UDF +# sock_path: The UNIX socket path to be used for the server +# max_message_size: The max message size in bytes the server can receive and send +# max_threads: The max number of threads to be spawned; +# defaults to 4 and max capped at 16 +# Example invocation: +# import os +# from collections.abc import Iterator +# +# from pynumaflow.sinker import Datum, Responses, Response, SinkServer +# from pynumaflow.sinker import Sinker +# from pynumaflow._constants import _LOGGER +# +# class UserDefinedSink(Sinker): +# def handler(self, datums: Iterator[Datum]) -> Responses: +# responses = Responses() +# for msg in datums: +# _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) +# responses.append(Response.as_success(msg.id)) +# return responses +# +# def udsink_handler(datums: Iterator[Datum]) -> Responses: +# responses = Responses() +# for msg in datums: +# _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) +# responses.append(Response.as_success(msg.id)) +# return responses +# +# if __name__ == "__main__": +# invoke = os.getenv("INVOKE", "func_handler") +# if invoke == "class": +# sink_handler = UserDefinedSink() +# else: +# sink_handler = udsink_handler +# grpc_server = SinkServer(sink_handler) +# grpc_server.start() +# +# """ +# # If the container type is fallback sink, then use the fallback sink address and path. 
+# if os.getenv(ENV_UD_CONTAINER_TYPE, "") == UD_CONTAINER_FALLBACK_SINK: +# _LOGGER.info("Using Fallback Sink") +# sock_path = FALLBACK_SINK_SOCK_PATH +# server_info_file = FALLBACK_SINK_SERVER_INFO_FILE_PATH +# +# self.sock_path = f"unix://{sock_path}" +# self.max_threads = min(max_threads, MAX_NUM_THREADS) +# self.max_message_size = max_message_size +# self.server_info_file = server_info_file +# +# self.sinker_instance = sinker_instance +# +# self._server_options = [ +# ("grpc.max_send_message_length", self.max_message_size), +# ("grpc.max_receive_message_length", self.max_message_size), +# ] +# self.servicer = SyncSinkServicer(sinker_instance) +# +# def start(self): +# """ +# Starts the Synchronous gRPC server on the +# given UNIX socket with given max threads. +# """ +# _LOGGER.info( +# "Sync GRPC Sink listening on: %s with max threads: %s", +# self.sock_path, +# self.max_threads, +# ) +# serv_info = ServerInfo.get_default_server_info() +# serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sinker] +# # Start the server +# sync_server_start( +# servicer=self.servicer, +# bind_address=self.sock_path, +# max_threads=self.max_threads, +# server_info_file=self.server_info_file, +# server_options=self._server_options, +# udf_type=UDFType.Sink, +# server_info=serv_info, +# ) diff --git a/pynumaflow/sinker/servicer/async_servicer.py b/pynumaflow/sinker/servicer/async_servicer.py index 9f02d005..f643c36f 100644 --- a/pynumaflow/sinker/servicer/async_servicer.py +++ b/pynumaflow/sinker/servicer/async_servicer.py @@ -1,14 +1,17 @@ +import asyncio from collections.abc import AsyncIterable from google.protobuf import empty_pb2 as _empty_pb2 +from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.shared.server import exit_on_error -from pynumaflow.sinker._dtypes import Datum +from pynumaflow.shared.servicer import is_valid_handshake +from pynumaflow.sinker._dtypes import Datum, AsyncSinkCallable from pynumaflow.sinker._dtypes 
import SyncSinkCallable from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 from pynumaflow.sinker.servicer.utils import build_sink_response from pynumaflow.types import NumaflowServicerContext -from pynumaflow._constants import _LOGGER +from pynumaflow._constants import _LOGGER, STREAM_EOF async def datum_generator( @@ -16,16 +19,24 @@ async def datum_generator( ) -> AsyncIterable[Datum]: async for d in request_iterator: datum = Datum( - keys=list(d.keys), - sink_msg_id=d.id, - value=d.value, - event_time=d.event_time.ToDatetime(), - watermark=d.watermark.ToDatetime(), - headers=dict(d.headers), + keys=list(d.request.keys), + sink_msg_id=d.request.id, + value=d.request.value, + event_time=d.request.event_time.ToDatetime(), + watermark=d.request.watermark.ToDatetime(), + headers=dict(d.request.headers), ) yield datum +def _create_read_handshake_response(): + """Create a handshake response for the Sink function.""" + return sink_pb2.SinkResponse( + result=sink_pb2.SinkResponse.Result(status=sink_pb2.SUCCESS), + handshake=sink_pb2.Handshake(sot=True), + ) + + class AsyncSinkServicer(sink_pb2_grpc.SinkServicer): """ This class is used to create a new grpc Sink servicer instance. @@ -37,7 +48,8 @@ def __init__( self, handler: SyncSinkCallable, ): - self.__sink_handler: SyncSinkCallable = handler + self.background_tasks = set() + self.__sink_handler: AsyncSinkCallable = handler self.cleanup_coroutines = [] async def SinkFn( @@ -49,32 +61,71 @@ async def SinkFn( Applies a sink function to a list of datum elements. The pascal case function name comes from the proto sink_pb2_grpc.py file. 
""" - # if there is an exception, we will mark all the responses as a failure - datum_iterator = datum_generator(request_iterator=request_iterator) try: - results = await self.__invoke_sink(datum_iterator, context) + # The first message to be received should be a valid handshake + req = await request_iterator.__anext__() + if not is_valid_handshake(req): + raise Exception("ReadFn: expected handshake message") + await context.write(_create_read_handshake_response()) + + # cur_task is used to track the task (coroutine) processing + # the current batch of messages. + cur_task = None + # iterate over the incoming messages to the sink + async for d in request_iterator: + # if we do not have any active task currently processing the batch + # we need to create one and call the User function for processing the same. + if cur_task is None: + req_queue = NonBlockingIterator() + cur_task = asyncio.create_task( + self.__invoke_sink(req_queue.read_iterator(), context) + ) + self.background_tasks.add(cur_task) + cur_task.add_done_callback(self.background_tasks.discard) + + # when we receive an end of transmission message, we need to stop processing the + # current batch and wait for the next batch of messages. + # We will also wait for the current task to finish processing the current batch. + # We mark the current task as None to indicate that we are + # ready to process the next batch. + if d.status and d.status.eot: + await req_queue.put(STREAM_EOF) + await cur_task + cur_task = None + continue + + # if we have a valid message, we will add it to the request queue for processing. 
+ datum = Datum( + keys=list(d.request.keys), + sink_msg_id=d.request.id, + value=d.request.value, + event_time=d.request.event_time.ToDatetime(), + watermark=d.request.watermark.ToDatetime(), + headers=dict(d.request.headers), + ) + await req_queue.put(datum) except BaseException as err: + # if there is an exception, we will mark all the responses as a failure err_msg = f"UDSinkError: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) exit_on_error(context, err_msg) return - return sink_pb2.SinkResponse(results=results) - async def __invoke_sink( - self, datum_iterator: AsyncIterable[Datum], context: NumaflowServicerContext + self, request_queue: AsyncIterable[Datum], context: NumaflowServicerContext ): try: - rspns = await self.__sink_handler(datum_iterator) + # invoke the user function with the request queue + rspns = await self.__sink_handler(request_queue) + # for each response, we will write the response back from the rpc. + for rspn in rspns: + sink_rsp = build_sink_response(rspn) + await context.write(sink_pb2.SinkResponse(result=sink_rsp)) except BaseException as err: err_msg = f"UDSinkError: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) exit_on_error(context, err_msg) raise err - responses = [] - for rspn in rspns: - responses.append(build_sink_response(rspn)) - return responses async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext @@ -84,3 +135,9 @@ async def IsReady( The pascal case function name comes from the proto sink_pb2_grpc.py file. 
""" return sink_pb2.ReadyResponse(ready=True) + + def clean_background(self, task): + """ + Remove the task from the background tasks collection + """ + self.background_tasks.remove(task) diff --git a/pynumaflow/sinker/servicer/sync_servicer.py b/pynumaflow/sinker/servicer/sync_servicer.py index a1f307d1..f1830b1c 100644 --- a/pynumaflow/sinker/servicer/sync_servicer.py +++ b/pynumaflow/sinker/servicer/sync_servicer.py @@ -1,69 +1,69 @@ -from collections.abc import Iterator, Iterable - -from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow._constants import _LOGGER -from pynumaflow.shared.server import exit_on_error -from pynumaflow.sinker._dtypes import Datum -from pynumaflow.sinker._dtypes import SyncSinkCallable -from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 -from pynumaflow.sinker.servicer.utils import build_sink_response -from pynumaflow.types import NumaflowServicerContext - - -def datum_generator(request_iterator: Iterable[sink_pb2.SinkRequest]) -> Iterable[Datum]: - for d in request_iterator: - datum = Datum( - keys=list(d.keys), - sink_msg_id=d.id, - value=d.value, - event_time=d.event_time.ToDatetime(), - watermark=d.watermark.ToDatetime(), - headers=dict(d.headers), - ) - yield datum - - -class SyncSinkServicer(sink_pb2_grpc.SinkServicer): - """ - This class is used to create a new grpc Sink servicer instance. - It implements the SinkServicer interface from the proto sink.proto file. - Provides the functionality for the required rpc methods. - """ - - def __init__( - self, - handler: SyncSinkCallable, - ): - self.__sink_handler: SyncSinkCallable = handler - - def SinkFn( - self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext - ) -> sink_pb2.SinkResponse: - """ - Applies a sink function to a list of datum elements. - The pascal case function name comes from the proto sink_pb2_grpc.py file. 
- """ - # if there is an exception, we will mark all the responses as a failure - datum_iterator = datum_generator(request_iterator) - try: - rspns = self.__sink_handler(datum_iterator) - except BaseException as err: - err_msg = f"UDSinkError: {repr(err)}" - _LOGGER.critical(err_msg, exc_info=True) - exit_on_error(context, err_msg) - return - - responses = [] - for rspn in rspns: - responses.append(build_sink_response(rspn)) - - return sink_pb2.SinkResponse(results=responses) - - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> sink_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto sink_pb2_grpc.py file. - """ - return sink_pb2.ReadyResponse(ready=True) +# from collections.abc import Iterator, Iterable +# +# from google.protobuf import empty_pb2 as _empty_pb2 +# from pynumaflow._constants import _LOGGER +# from pynumaflow.shared.server import exit_on_error +# from pynumaflow.sinker._dtypes import Datum +# from pynumaflow.sinker._dtypes import SyncSinkCallable +# from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 +# from pynumaflow.sinker.servicer.utils import build_sink_response +# from pynumaflow.types import NumaflowServicerContext +# +# +# def datum_generator(request_iterator: Iterable[sink_pb2.SinkRequest]) -> Iterable[Datum]: +# for d in request_iterator: +# datum = Datum( +# keys=list(d.keys), +# sink_msg_id=d.id, +# value=d.value, +# event_time=d.event_time.ToDatetime(), +# watermark=d.watermark.ToDatetime(), +# headers=dict(d.headers), +# ) +# yield datum +# +# +# class SyncSinkServicer(sink_pb2_grpc.SinkServicer): +# """ +# This class is used to create a new grpc Sink servicer instance. +# It implements the SinkServicer interface from the proto sink.proto file. +# Provides the functionality for the required rpc methods. 
+# """ +# +# def __init__( +# self, +# handler: SyncSinkCallable, +# ): +# self.__sink_handler: SyncSinkCallable = handler +# +# def SinkFn( +# self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext +# ) -> sink_pb2.SinkResponse: +# """ +# Applies a sink function to a list of datum elements. +# The pascal case function name comes from the proto sink_pb2_grpc.py file. +# """ +# # if there is an exception, we will mark all the responses as a failure +# datum_iterator = datum_generator(request_iterator) +# try: +# rspns = self.__sink_handler(datum_iterator) +# except BaseException as err: +# err_msg = f"UDSinkError: {repr(err)}" +# _LOGGER.critical(err_msg, exc_info=True) +# exit_on_error(context, err_msg) +# return +# +# responses = [] +# for rspn in rspns: +# responses.append(build_sink_response(rspn)) +# +# return sink_pb2.SinkResponse(results=responses) +# +# def IsReady( +# self, request: _empty_pb2.Empty, context: NumaflowServicerContext +# ) -> sink_pb2.ReadyResponse: +# """ +# IsReady is the heartbeat endpoint for gRPC. +# The pascal case function name comes from the proto sink_pb2_grpc.py file. 
+# """ +# return sink_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/sourcer/servicer/async_servicer.py b/pynumaflow/sourcer/servicer/async_servicer.py index dd07478b..23778c75 100644 --- a/pynumaflow/sourcer/servicer/async_servicer.py +++ b/pynumaflow/sourcer/servicer/async_servicer.py @@ -1,12 +1,12 @@ import asyncio from collections.abc import AsyncIterable -import grpc from google.protobuf import timestamp_pb2 as _timestamp_pb2 from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import exit_on_error, handle_error +from pynumaflow.shared.server import exit_on_error, handle_exception +from pynumaflow.shared.servicer import is_valid_handshake from pynumaflow.sourcer._dtypes import ReadRequest from pynumaflow.sourcer._dtypes import AckRequest, SourceCallable from pynumaflow.proto.sourcer import source_pb2 @@ -15,15 +15,6 @@ from pynumaflow._constants import _LOGGER, STREAM_EOF -async def _handle_exception(context, exception): - """Handle exceptions by updating the context and exiting.""" - handle_error(context, exception) - await asyncio.gather( - context.abort(grpc.StatusCode.UNKNOWN, details=repr(exception)), return_exceptions=True - ) - exit_on_error(err=repr(exception), parent=False, context=context, update_context=False) - - def _create_read_handshake_response(): """Create a handshake response for the Read function.""" return source_pb2.ReadResponse( @@ -99,7 +90,7 @@ async def ReadFn( try: # The first message to be received should be a valid handshake req = await request_iterator.__anext__() - if not _is_valid_handshake(req): + if not is_valid_handshake(req): raise Exception("ReadFn: expected handshake message") yield _create_read_handshake_response() @@ -117,7 +108,7 @@ async def ReadFn( async for resp in riter: if isinstance(resp, BaseException): - await _handle_exception(context, resp) + await handle_exception(context, resp) return yield 
_create_read_response(resp) @@ -157,7 +148,7 @@ async def AckFn( try: # The first message to be received should be a valid handshake req = await request_iterator.__anext__() - if not _is_valid_handshake(req): + if not is_valid_handshake(req): raise Exception("AckFn: expected handshake message") yield _create_ack_handshake_response() @@ -214,8 +205,3 @@ def clean_background(self, task): Remove the task from the background tasks collection """ self.background_tasks.remove(task) - - -def _is_valid_handshake(req): - """Check if the handshake message is valid.""" - return req.handshake and req.handshake.sot diff --git a/tests/sink/test_async_sink.py b/tests/sink/test_async_sink.py index f04230cd..73ec135b 100644 --- a/tests/sink/test_async_sink.py +++ b/tests/sink/test_async_sink.py @@ -44,13 +44,7 @@ async def udsink_handler(datums: AsyncIterable[Datum]) -> Responses: return responses -def request_generator(count, request): - for i in range(count): - request.id = str(i) - yield request - - -def start_sink_streaming_request(req_type="success") -> (Datum, tuple): +def start_sink_streaming_request(_id: str, req_type) -> (Datum, tuple): event_time_timestamp, watermark_timestamp = get_time_args() value = mock_message() if req_type == "err": @@ -59,12 +53,20 @@ def start_sink_streaming_request(req_type="success") -> (Datum, tuple): if req_type == "fallback": value = mock_fallback_message() - request = sink_pb2.SinkRequest( - value=value, - event_time=event_time_timestamp, - watermark=watermark_timestamp, + request = sink_pb2.SinkRequest.Request( + value=value, event_time=event_time_timestamp, watermark=watermark_timestamp, id=_id ) - return request + return sink_pb2.SinkRequest(request=request) + + +def request_generator(count, req_type="success", session=1): + yield sink_pb2.SinkRequest(handshake=sink_pb2.Handshake(sot=True)) + + for j in range(session): + for i in range(count): + yield start_sink_streaming_request(str(i), req_type) + + yield 
sink_pb2.SinkRequest(status=sink_pb2.SinkRequest.Status(eot=True)) _s: Server = None @@ -137,50 +139,62 @@ def test_run_server(self) -> None: def test_sink(self) -> None: stub = self.__stub() - request = start_sink_streaming_request() - print(request) generator_response = None try: generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, request=request) + request_iterator=request_generator(count=10, req_type="success", session=1) ) + handshake = next(generator_response) + # assert that handshake response is received. + self.assertTrue(handshake.handshake.sot) + cnt = 0 + for r in generator_response: + # capture the output from the SinkFn generator and assert. + self.assertEqual(r.result.status, sink_pb2.Status.SUCCESS) + cnt += 1 + # 10 sink responses + self.assertEqual(10, cnt) except grpc.RpcError as e: logging.error(e) - # capture the output from the ReduceFn generator and assert. - self.assertEqual(10, len(generator_response.results)) - for x in generator_response.results: - self.assertEqual(x.status, sink_pb2.Status.SUCCESS) - def test_sink_err(self) -> None: stub = self.__stub() - request = start_sink_streaming_request(req_type="err") - grpcException = None + grpc_exception = None try: - stub.SinkFn(request_iterator=request_generator(count=10, request=request)) + generator_response = stub.SinkFn( + request_iterator=request_generator(count=10, req_type="err") + ) + for _ in generator_response: + pass + except BaseException as e: + self.assertTrue("UDSinkError: ValueError('test_mock_err_message')" in e.__str__()) + return except grpc.RpcError as e: - grpcException = e + grpc_exception = e self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) - logging.error(e) + print(e.details()) - self.assertIsNotNone(grpcException) + self.assertIsNotNone(grpc_exception) def test_sink_fallback(self) -> None: stub = self.__stub() - request = start_sink_streaming_request(req_type="fallback") - generator_response = None try: generator_response = 
stub.SinkFn( - request_iterator=request_generator(count=10, request=request) + request_iterator=request_generator(count=10, req_type="fallback", session=1) ) + cnt = 0 + handshake = next(generator_response) + # assert that handshake response is received. + self.assertTrue(handshake.handshake.sot) + for r in generator_response: + # capture the output from the SinkFn generator and assert. + self.assertEqual(r.result.status, sink_pb2.Status.FALLBACK) + cnt += 1 + # 10 sink responses + self.assertEqual(10, cnt) except grpc.RpcError as e: logging.error(e) - # capture the output from the ReduceFn generator and assert. - self.assertEqual(10, len(generator_response.results)) - for x in generator_response.results: - self.assertEqual(x.status, sink_pb2.Status.FALLBACK) - def __stub(self): return sink_pb2_grpc.SinkStub(_channel) diff --git a/tests/sink/test_server.py b/tests/sink/test_server.py index d5fa8b17..adf95dfe 100644 --- a/tests/sink/test_server.py +++ b/tests/sink/test_server.py @@ -1,23 +1,10 @@ import os import unittest -from datetime import datetime, timezone from collections.abc import Iterator +from datetime import datetime, timezone from unittest import mock -from unittest.mock import patch - -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from grpc import StatusCode -from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow._constants import ( - UD_CONTAINER_FALLBACK_SINK, - FALLBACK_SINK_SERVER_INFO_FILE_PATH, - FALLBACK_SINK_SOCK_PATH, -) -from pynumaflow.sinker import Responses, Datum, Response, SinkServer -from pynumaflow.proto.sinker import sink_pb2 -from tests.testing_utils import mock_terminate_on_stop +from pynumaflow.sinker import Responses, Datum, Response def mockenv(**envvars): @@ -65,153 +52,154 @@ def mock_watermark(): return t -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", 
mock_terminate_on_stop) -class TestServer(unittest.TestCase): - def setUp(self) -> None: - server = SinkServer(sinker_instance=udsink_handler) - my_servicer = server.servicer - services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( - method_descriptor=( - sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["IsReady"] - ), - invocation_metadata={}, - request=_empty_pb2.Empty(), - timeout=1, - ) - - response, metadata, code, details = method.termination() - expected = sink_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_udsink_err(self): - server = SinkServer(sinker_instance=err_udsink_handler) - my_servicer = server.servicer - services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=mock_event_time()) - watermark_timestamp = _timestamp_pb2.Timestamp() - watermark_timestamp.FromDatetime(dt=mock_watermark()) - - test_datums = [ - sink_pb2.SinkRequest( - id="test_id_0", - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ), - sink_pb2.SinkRequest( - id="test_id_1", - value=mock_err_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ), - sink_pb2.SinkRequest( - id="test_id_2", - value=mock_fallback_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ), - ] - - method = self.test_server.invoke_stream_unary( - method_descriptor=( - sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] - ), - invocation_metadata={}, - timeout=1, - ) - - method.send_request(test_datums[0]) - method.send_request(test_datums[1]) - 
method.send_request(test_datums[2]) - method.requests_closed() - - response, metadata, code, details = method.termination() - self.assertEqual(StatusCode.UNKNOWN, code) - - def test_forward_message(self): - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=mock_event_time()) - watermark_timestamp = _timestamp_pb2.Timestamp() - watermark_timestamp.FromDatetime(dt=mock_watermark()) - - test_datums = [ - sink_pb2.SinkRequest( - id="test_id_0", - value=mock_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ), - sink_pb2.SinkRequest( - id="test_id_1", - value=mock_err_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ), - sink_pb2.SinkRequest( - id="test_id_2", - value=mock_fallback_message(), - event_time=event_time_timestamp, - watermark=watermark_timestamp, - ), - ] - - method = self.test_server.invoke_stream_unary( - method_descriptor=( - sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] - ), - invocation_metadata={}, - timeout=1, - ) - - method.send_request(test_datums[0]) - method.send_request(test_datums[1]) - method.send_request(test_datums[2]) - method.requests_closed() - - response, metadata, code, details = method.termination() - self.assertEqual(3, len(response.results)) - self.assertEqual("test_id_0", response.results[0].id) - self.assertEqual("test_id_1", response.results[1].id) - self.assertEqual("test_id_2", response.results[2].id) - self.assertEqual(response.results[0].status, sink_pb2.Status.SUCCESS) - self.assertEqual(response.results[1].status, sink_pb2.Status.FAILURE) - self.assertEqual(response.results[2].status, sink_pb2.Status.FALLBACK) - self.assertEqual("", response.results[0].err_msg) - self.assertEqual("mock sink message error", response.results[1].err_msg) - self.assertEqual("", response.results[2].err_msg) - self.assertEqual(code, StatusCode.OK) - - def test_invalid_init(self): - with self.assertRaises(TypeError): - 
SinkServer() - - @mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_FALLBACK_SINK) - def test_start_fallback_sink(self): - server = SinkServer(sinker_instance=udsink_handler) - self.assertEqual(server.sock_path, f"unix://{FALLBACK_SINK_SOCK_PATH}") - self.assertEqual(server.server_info_file, FALLBACK_SINK_SERVER_INFO_FILE_PATH) - - def test_max_threads(self): - # max cap at 16 - server = SinkServer(sinker_instance=udsink_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = SinkServer(sinker_instance=udsink_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = SinkServer(sinker_instance=udsink_handler) - self.assertEqual(server.max_threads, 4) +# +# # We are mocking the terminate function from the psutil to not exit the program during testing +# @patch("psutil.Process.kill", mock_terminate_on_stop) +# class TestServer(unittest.TestCase): +# def setUp(self) -> None: +# server = SinkServer(sinker_instance=udsink_handler) +# my_servicer = server.servicer +# services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} +# self.test_server = server_from_dictionary(services, strict_real_time()) +# +# def test_is_ready(self): +# method = self.test_server.invoke_unary_unary( +# method_descriptor=( +# sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["IsReady"] +# ), +# invocation_metadata={}, +# request=_empty_pb2.Empty(), +# timeout=1, +# ) +# +# response, metadata, code, details = method.termination() +# expected = sink_pb2.ReadyResponse(ready=True) +# self.assertEqual(expected, response) +# self.assertEqual(code, StatusCode.OK) +# +# def test_udsink_err(self): +# server = SinkServer(sinker_instance=err_udsink_handler) +# my_servicer = server.servicer +# services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} +# self.test_server = server_from_dictionary(services, strict_real_time()) +# +# event_time_timestamp = _timestamp_pb2.Timestamp() +# 
event_time_timestamp.FromDatetime(dt=mock_event_time()) +# watermark_timestamp = _timestamp_pb2.Timestamp() +# watermark_timestamp.FromDatetime(dt=mock_watermark()) +# +# test_datums = [ +# sink_pb2.SinkRequest( +# id="test_id_0", +# value=mock_message(), +# event_time=event_time_timestamp, +# watermark=watermark_timestamp, +# ), +# sink_pb2.SinkRequest( +# id="test_id_1", +# value=mock_err_message(), +# event_time=event_time_timestamp, +# watermark=watermark_timestamp, +# ), +# sink_pb2.SinkRequest( +# id="test_id_2", +# value=mock_fallback_message(), +# event_time=event_time_timestamp, +# watermark=watermark_timestamp, +# ), +# ] +# +# method = self.test_server.invoke_stream_unary( +# method_descriptor=( +# sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] +# ), +# invocation_metadata={}, +# timeout=1, +# ) +# +# method.send_request(test_datums[0]) +# method.send_request(test_datums[1]) +# method.send_request(test_datums[2]) +# method.requests_closed() +# +# response, metadata, code, details = method.termination() +# self.assertEqual(StatusCode.UNKNOWN, code) +# +# def test_forward_message(self): +# event_time_timestamp = _timestamp_pb2.Timestamp() +# event_time_timestamp.FromDatetime(dt=mock_event_time()) +# watermark_timestamp = _timestamp_pb2.Timestamp() +# watermark_timestamp.FromDatetime(dt=mock_watermark()) +# +# test_datums = [ +# sink_pb2.SinkRequest( +# id="test_id_0", +# value=mock_message(), +# event_time=event_time_timestamp, +# watermark=watermark_timestamp, +# ), +# sink_pb2.SinkRequest( +# id="test_id_1", +# value=mock_err_message(), +# event_time=event_time_timestamp, +# watermark=watermark_timestamp, +# ), +# sink_pb2.SinkRequest( +# id="test_id_2", +# value=mock_fallback_message(), +# event_time=event_time_timestamp, +# watermark=watermark_timestamp, +# ), +# ] +# +# method = self.test_server.invoke_stream_unary( +# method_descriptor=( +# sink_pb2.DESCRIPTOR.services_by_name["Sink"].methods_by_name["SinkFn"] +# ), +# 
invocation_metadata={}, +# timeout=1, +# ) +# +# method.send_request(test_datums[0]) +# method.send_request(test_datums[1]) +# method.send_request(test_datums[2]) +# method.requests_closed() +# +# response, metadata, code, details = method.termination() +# self.assertEqual(3, len(response.results)) +# self.assertEqual("test_id_0", response.results[0].id) +# self.assertEqual("test_id_1", response.results[1].id) +# self.assertEqual("test_id_2", response.results[2].id) +# self.assertEqual(response.results[0].status, sink_pb2.Status.SUCCESS) +# self.assertEqual(response.results[1].status, sink_pb2.Status.FAILURE) +# self.assertEqual(response.results[2].status, sink_pb2.Status.FALLBACK) +# self.assertEqual("", response.results[0].err_msg) +# self.assertEqual("mock sink message error", response.results[1].err_msg) +# self.assertEqual("", response.results[2].err_msg) +# self.assertEqual(code, StatusCode.OK) +# +# def test_invalid_init(self): +# with self.assertRaises(TypeError): +# SinkServer() +# +# @mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_FALLBACK_SINK) +# def test_start_fallback_sink(self): +# server = SinkServer(sinker_instance=udsink_handler) +# self.assertEqual(server.sock_path, f"unix://{FALLBACK_SINK_SOCK_PATH}") +# self.assertEqual(server.server_info_file, FALLBACK_SINK_SERVER_INFO_FILE_PATH) +# +# def test_max_threads(self): +# # max cap at 16 +# server = SinkServer(sinker_instance=udsink_handler, max_threads=32) +# self.assertEqual(server.max_threads, 16) +# +# # use argument provided +# server = SinkServer(sinker_instance=udsink_handler, max_threads=5) +# self.assertEqual(server.max_threads, 5) +# +# # defaults to 4 +# server = SinkServer(sinker_instance=udsink_handler) +# self.assertEqual(server.max_threads, 4) if __name__ == "__main__":