Skip to content

Commit

Permalink
Merge pull request #306 Fix handle stop partition request
Browse files Browse the repository at this point in the history
  • Loading branch information
rekby authored May 25, 2023
2 parents 759ccfc + 0a9d7fa commit cd93e4a
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 34 deletions.
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,26 @@ async def topic2_path(driver, topic_consumer, database) -> str:
return topic_path


@pytest.fixture()
@pytest.mark.asyncio()
async def topic_with_two_partitions_path(driver, topic_consumer, database) -> str:
topic_path = database + "/test-topic-two-partitions"

try:
await driver.topic_client.drop_topic(topic_path)
except issues.SchemeError:
pass

await driver.topic_client.create_topic(
path=topic_path,
consumers=[topic_consumer],
min_active_partitions=2,
partition_count_limit=2,
)

return topic_path


@pytest.fixture()
@pytest.mark.asyncio()
async def topic_with_messages(driver, topic_consumer, database):
Expand Down
44 changes: 44 additions & 0 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest

import ydb
Expand Down Expand Up @@ -161,3 +163,45 @@ def decode(b: bytes):
with driver_sync.topic_client.reader(topic_path, topic_consumer, decoders={codec: decode}) as reader:
batch = reader.receive_batch()
assert batch.messages[0].data.decode() == "123"


@pytest.mark.asyncio
class TestBugFixesAsync:
async def test_issue_297_bad_handle_stop_partition(
self, driver, topic_consumer, topic_with_two_partitions_path: str
):
async def wait(fut):
return await asyncio.wait_for(fut, timeout=10)

topic = topic_with_two_partitions_path # type: str

async with driver.topic_client.writer(topic, partition_id=0) as writer:
await writer.write_with_ack("00")

async with driver.topic_client.writer(topic, partition_id=1) as writer:
await writer.write_with_ack("01")

# Start first reader and receive messages from both partitions
reader0 = driver.topic_client.reader(topic, consumer=topic_consumer)
await wait(reader0.receive_message())
await wait(reader0.receive_message())

# Start second reader for same topic, same consumer, partition 1
reader1 = driver.topic_client.reader(topic, consumer=topic_consumer)

# receive uncommited message
await reader1.receive_message()

# write one message for every partition
async with driver.topic_client.writer(topic, partition_id=0) as writer:
await writer.write_with_ack("10")
async with driver.topic_client.writer(topic, partition_id=1) as writer:
await writer.write_with_ack("11")

msg0 = await wait(reader0.receive_message())
msg1 = await wait(reader1.receive_message())

datas = [msg0.data.decode(), msg1.data.decode()]
datas.sort()

assert datas == ["10", "11"]
69 changes: 61 additions & 8 deletions ydb/_grpc/grpcwrapper/ydb_topic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import enum
import typing
Expand All @@ -8,6 +10,7 @@

from . import ydb_topic_public_types
from ... import scheme
from ... import issues

# Workaround for good IDE and universal for runtime
if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -588,16 +591,32 @@ def from_proto(
)

@dataclass
class PartitionSessionStatusRequest:
class PartitionSessionStatusRequest(IToProto):
partition_session_id: int

def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest:
return ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest(
partition_session_id=self.partition_session_id
)

@dataclass
class PartitionSessionStatusResponse:
class PartitionSessionStatusResponse(IFromProto):
partition_session_id: int
partition_offsets: "OffsetsRange"
committed_offset: int
write_time_high_watermark: float

@staticmethod
def from_proto(
msg: ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusResponse,
) -> "StreamReadMessage.PartitionSessionStatusResponse":
return StreamReadMessage.PartitionSessionStatusResponse(
partition_session_id=msg.partition_session_id,
partition_offsets=OffsetsRange.from_proto(msg.partition_offsets),
committed_offset=msg.committed_offset,
write_time_high_watermark=msg.write_time_high_watermark,
)

@dataclass
class StartPartitionSessionRequest(IFromProto):
partition_session: "StreamReadMessage.PartitionSession"
Expand Down Expand Up @@ -632,15 +651,30 @@ def to_proto(
return res

@dataclass
class StopPartitionSessionRequest:
class StopPartitionSessionRequest(IFromProto):
partition_session_id: int
graceful: bool
committed_offset: int

@staticmethod
def from_proto(
msg: ydb_topic_pb2.StreamReadMessage.StopPartitionSessionRequest,
) -> StreamReadMessage.StopPartitionSessionRequest:
return StreamReadMessage.StopPartitionSessionRequest(
partition_session_id=msg.partition_session_id,
graceful=msg.graceful,
committed_offset=msg.committed_offset,
)

@dataclass
class StopPartitionSessionResponse:
class StopPartitionSessionResponse(IToProto):
partition_session_id: int

def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse:
return ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse(
partition_session_id=self.partition_session_id,
)

@dataclass
class FromClient(IToProto):
client_message: "ReaderMessagesFromClientToServer"
Expand All @@ -660,6 +694,10 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
res.update_token_request.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.StartPartitionSessionResponse):
res.start_partition_session_response.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.StopPartitionSessionResponse):
res.stop_partition_session_response.CopyFrom(self.client_message.to_proto())
elif isinstance(self.client_message, StreamReadMessage.PartitionSessionStatusRequest):
res.start_partition_session_response.CopyFrom(self.client_message.to_proto())
else:
raise NotImplementedError("Unknown message type: %s" % type(self.client_message))
return res
Expand Down Expand Up @@ -694,17 +732,32 @@ def from_proto(
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto(
msg.start_partition_session_request
msg.start_partition_session_request,
),
)
elif mess_type == "stop_partition_session_request":
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=StreamReadMessage.StopPartitionSessionRequest.from_proto(
msg.stop_partition_session_request
),
)
elif mess_type == "update_token_response":
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=UpdateTokenResponse.from_proto(msg.update_token_response),
)

# todo replace exception to log
raise NotImplementedError()
elif mess_type == "partition_session_status_response":
return StreamReadMessage.FromServer(
server_status=server_status,
server_message=StreamReadMessage.PartitionSessionStatusResponse.from_proto(
msg.partition_session_status_response
),
)
else:
raise issues.UnexpectedGrpcMessage(
"Unexpected message while parse ReaderMessagesFromServerToClient: '%s'" % mess_type
)


ReaderMessagesFromClientToServer = Union[
Expand Down
53 changes: 29 additions & 24 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
Codec,
)
from .._errors import check_retriable_error
import logging

logger = logging.getLogger(__name__)


class TopicReaderError(YdbError):
Expand Down Expand Up @@ -146,7 +149,6 @@ class ReaderReconnector:

def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
self._id = self._static_reader_reconnector_counter.inc_and_get()

self._settings = settings
self._driver = driver
self._background_tasks = set()
Expand Down Expand Up @@ -395,39 +397,42 @@ async def _read_messages_loop(self):
)
)
while True:
message = await self._stream.receive() # type: StreamReadMessage.FromServer
_process_response(message.server_status)
try:
message = await self._stream.receive() # type: StreamReadMessage.FromServer
_process_response(message.server_status)

if isinstance(message.server_message, StreamReadMessage.ReadResponse):
self._on_read_response(message.server_message)
if isinstance(message.server_message, StreamReadMessage.ReadResponse):
self._on_read_response(message.server_message)

elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse):
self._on_commit_response(message.server_message)
elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse):
self._on_commit_response(message.server_message)

elif isinstance(
message.server_message,
StreamReadMessage.StartPartitionSessionRequest,
):
self._on_start_partition_session(message.server_message)
elif isinstance(
message.server_message,
StreamReadMessage.StartPartitionSessionRequest,
):
self._on_start_partition_session(message.server_message)

elif isinstance(
message.server_message,
StreamReadMessage.StopPartitionSessionRequest,
):
self._on_partition_session_stop(message.server_message)
elif isinstance(
message.server_message,
StreamReadMessage.StopPartitionSessionRequest,
):
self._on_partition_session_stop(message.server_message)

elif isinstance(message.server_message, UpdateTokenResponse):
self._update_token_event.set()
elif isinstance(message.server_message, UpdateTokenResponse):
self._update_token_event.set()

else:
raise NotImplementedError(
"Unexpected type of StreamReadMessage.FromServer message: %s" % message.server_message
)
else:
raise issues.UnexpectedGrpcMessage(
"Unexpected message in _read_messages_loop: %s" % type(message.server_message)
)
except issues.UnexpectedGrpcMessage as e:
logger.exception("unexpected message in stream reader: %s" % e)

self._state_changed.set()
except Exception as e:
self._set_first_error(e)
raise
return

async def _update_token_loop(self):
while True:
Expand Down
23 changes: 23 additions & 0 deletions ydb/_topic_reader/topic_reader_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,29 @@ async def test_update_token(self, stream):

await reader.close()

async def test_read_unknown_message(self, stream, stream_reader, caplog):
class TestMessage:
pass

# noinspection PyTypeChecker
stream.from_server.put_nowait(
StreamReadMessage.FromServer(
server_status=ServerStatus(
status=issues.StatusCode.SUCCESS,
issues=[],
),
server_message=TestMessage(),
)
)

def logged():
for rec in caplog.records:
if TestMessage.__name__ in rec.message:
return True
return False

await wait_condition(logged)


@pytest.mark.asyncio
class TestReaderReconnector:
Expand Down
5 changes: 5 additions & 0 deletions ydb/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ class SessionPoolEmpty(Error, queue.Empty):
status = StatusCode.SESSION_POOL_EMPTY


class UnexpectedGrpcMessage(Error):
def __init__(self, message: str):
super().__init__(message)


def _format_issues(issues):
if not issues:
return ""
Expand Down
4 changes: 2 additions & 2 deletions ydb/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def reader(
if not decoder_executor:
decoder_executor = self._executor

args = locals()
args = locals().copy()
del args["self"]

settings = TopicReaderSettings(**args)
Expand All @@ -188,7 +188,7 @@ def writer(
encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None,
encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool
) -> TopicWriterAsyncIO:
args = locals()
args = locals().copy()
del args["self"]

settings = TopicWriterSettings(**args)
Expand Down

0 comments on commit cd93e4a

Please sign in to comment.