Skip to content

Commit

Permalink
Merge pull request #421 fix-hungup-bad-codec
Browse files Browse the repository at this point in the history
  • Loading branch information
rekby authored Apr 24, 2024
2 parents f6f591e + 8177d7f commit 70b8a81
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Fixed hungup topic reader on unknown codec

## 3.11.1 ##
* fixed unexpected require requests module on import

Expand Down
13 changes: 13 additions & 0 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ def decode(b: bytes):
batch = await reader.receive_batch()
assert batch.messages[0].data.decode() == "123"

async def test_error_unknown_codec(self, driver, topic_path, topic_consumer):
codec = 10001

def encode(b: bytes):
return bytes(reversed(b))

async with driver.topic_client.writer(topic_path, codec=codec, encoders={codec: encode}) as writer:
await writer.write("123")

async with driver.topic_client.reader(topic_path, topic_consumer) as reader:
with pytest.raises(ydb.TopicReaderUnexpectedCodecError):
await asyncio.wait_for(reader.receive_batch(), timeout=5)

async def test_read_from_two_topics(self, driver, topic_path, topic2_path, topic_consumer):
async with driver.topic_client.writer(topic_path) as writer:
await writer.write("1")
Expand Down
33 changes: 23 additions & 10 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import deque
from typing import Optional, Set, Dict, Union, Callable

import ydb
from .. import _apis, issues
from .._utilities import AtomicCounter
from ..aio import Driver
Expand Down Expand Up @@ -35,7 +36,7 @@ class TopicReaderError(YdbError):
pass


class TopicReaderUnexpectedCodec(YdbError):
class PublicTopicReaderUnexpectedCodecError(YdbError):
pass


Expand Down Expand Up @@ -222,9 +223,7 @@ def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.Co

async def close(self, flush: bool):
if self._stream_reader:
if flush:
await self.flush()
await self._stream_reader.close()
await self._stream_reader.close(flush)
for task in self._background_tasks:
task.cancel()

Expand Down Expand Up @@ -339,9 +338,12 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess
self._update_token_event.set()

self._background_tasks.add(asyncio.create_task(self._read_messages_loop(), name="read_messages_loop"))
self._background_tasks.add(asyncio.create_task(self._decode_batches_loop()))
self._background_tasks.add(asyncio.create_task(self._decode_batches_loop(), name="decode_batches"))
if self._get_token_function:
self._background_tasks.add(asyncio.create_task(self._update_token_loop(), name="update_token_loop"))
self._background_tasks.add(
asyncio.create_task(self._handle_background_errors(), name="handle_background_errors")
)

async def wait_error(self):
raise await self._first_error
Expand Down Expand Up @@ -411,6 +413,17 @@ def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.Co

return waiter

async def _handle_background_errors(self):
done, _ = await asyncio.wait(self._background_tasks, return_when=asyncio.FIRST_EXCEPTION)
for f in done:
f = f # type: asyncio.Future
err = f.exception()
if not isinstance(err, ydb.Error):
old_err = err
err = ydb.Error("Background process failed unexpected")
err.__cause__ = old_err
self._set_first_error(err)

async def _read_messages_loop(self):
try:
self._stream.write(
Expand Down Expand Up @@ -602,7 +615,7 @@ async def _decode_batch_inplace(self, batch):
try:
decode_func = self._decoders[batch._codec]
except KeyError:
raise TopicReaderUnexpectedCodec("Receive message with unexpected codec: %s" % batch._codec)
raise PublicTopicReaderUnexpectedCodecError("Receive message with unexpected codec: %s" % batch._codec)

decode_data_futures = []
for message in batch.messages:
Expand All @@ -628,22 +641,22 @@ def _get_first_error(self) -> Optional[YdbError]:
return self._first_error.result()

async def flush(self):
if self._closed:
raise RuntimeError("Flush on closed Stream")

futures = []
for session in self._partition_sessions.values():
futures.extend(w.future for w in session._ack_waiters)

if futures:
await asyncio.wait(futures)

async def close(self):
async def close(self, flush: bool):
if self._closed:
return

self._closed = True

if flush:
await self.flush()

self._set_first_error(TopicReaderStreamClosedError())
self._state_changed.set()
self._stream.close()
Expand Down
14 changes: 7 additions & 7 deletions ydb/_topic_reader/topic_reader_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ async def stream_reader(self, stream_reader_started: ReaderStream):
yield stream_reader_started

assert stream_reader_started._get_first_error() is None
await stream_reader_started.close()
await stream_reader_started.close(False)

@pytest.fixture()
async def stream_reader_finish_with_error(self, stream_reader_started: ReaderStream):
yield stream_reader_started

assert stream_reader_started._get_first_error() is not None
await stream_reader_started.close()
await stream_reader_started.close(False)

@staticmethod
def create_message(
Expand Down Expand Up @@ -372,7 +372,7 @@ async def test_close_ack_waiters_when_close_stream_reader(
self, stream_reader_started: ReaderStream, partition_session
):
waiter = partition_session.add_waiter(self.partition_session_committed_offset + 1)
await wait_for_fast(stream_reader_started.close())
await wait_for_fast(stream_reader_started.close(False))

with pytest.raises(topic_reader_asyncio.PublicTopicReaderPartitionExpiredError):
waiter.future.result()
Expand Down Expand Up @@ -402,7 +402,7 @@ async def test_flush(self, stream, stream_reader_started: ReaderStream, partitio
# don't raises
assert waiter.future.result() is None

await wait_for_fast(stream_reader_started.close())
await wait_for_fast(stream_reader_started.close(False))

async def test_commit_ranges_for_received_messages(
self, stream, stream_reader_started: ReaderStream, partition_session
Expand All @@ -422,7 +422,7 @@ async def test_commit_ranges_for_received_messages(
received = stream_reader_started.receive_batch_nowait().messages
assert received == [m2]

await stream_reader_started.close()
await stream_reader_started.close(False)

# noinspection PyTypeChecker
@pytest.mark.parametrize(
Expand Down Expand Up @@ -613,7 +613,7 @@ async def test_init_reader(self, stream, default_reader_settings):
)

assert reader._session_id == "test"
await reader.close()
await reader.close(False)

async def test_start_partition(
self,
Expand Down Expand Up @@ -1230,7 +1230,7 @@ async def test_update_token(self, stream):
got = await wait_for_fast(stream.from_client.get())
assert expected == got

await reader.close()
await reader.close(False)

async def test_read_unknown_message(self, stream, stream_reader, caplog):
class TestMessage:
Expand Down
2 changes: 2 additions & 0 deletions ydb/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"TopicReaderMessage",
"TopicReaderSelector",
"TopicReaderSettings",
"TopicReaderUnexpectedCodecError",
"TopicReaderPartitionExpiredError",
"TopicStatWindow",
"TopicWriteResult",
Expand Down Expand Up @@ -49,6 +50,7 @@
from ._topic_reader.topic_reader_asyncio import (
PublicAsyncIOReader as TopicReaderAsyncIO,
PublicTopicReaderPartitionExpiredError as TopicReaderPartitionExpiredError,
PublicTopicReaderUnexpectedCodecError as TopicReaderUnexpectedCodecError,
)

from ._topic_writer.topic_writer import ( # noqa: F401
Expand Down

0 comments on commit 70b8a81

Please sign in to comment.