added take_events() function #661

Open · wants to merge 5 commits into master
94 changes: 94 additions & 0 deletions faust/streams.py
@@ -394,6 +394,100 @@ async def add_to_buffer(value: T) -> T:
            self.enable_acks = stream_enable_acks
            self._processors.remove(add_to_buffer)

    async def take_events(self, max_: int,
                          within: Seconds) -> AsyncIterable[Sequence[EventT]]:
        """Buffer n events at a time and yield a list of buffered events.

        Arguments:
            max_: Maximum number of events to buffer. Once this many
                events have been received within the specified number of
                seconds, the buffer is flushed immediately.
            within: Timeout for when we give up waiting for another value,
                and process the values we have.
                Warning: If there's no timeout (i.e. ``within=None``),
                the agent is likely to stall and block buffered events
                for an unreasonable length of time (!).
        """
        buffer: List[T_co] = []
        events: List[EventT] = []
        buffer_add = buffer.append
        event_add = events.append
        buffer_size = buffer.__len__
        buffer_full = asyncio.Event(loop=self.loop)
        buffer_consumed = asyncio.Event(loop=self.loop)
        timeout = want_seconds(within) if within else None
        stream_enable_acks: bool = self.enable_acks

        buffer_consuming: Optional[asyncio.Future] = None

        channel_it = aiter(self.channel)

        # We add this processor to populate the buffer, and the stream
        # is passively consumed in the background (enable_passive below).
        async def add_to_buffer(value: T) -> T:
            try:
                # buffer_consuming is set when consuming buffer after timeout.
                nonlocal buffer_consuming
                if buffer_consuming is not None:
                    try:
                        await buffer_consuming
                    finally:
                        buffer_consuming = None
                buffer_add(cast(T_co, value))
                event = self.current_event
                if event is None:
                    raise RuntimeError(
                        'Take buffer found current_event is None')
                event_add(event)
                if buffer_size() >= max_:
                    # signal that the buffer is full and should be emptied.
                    buffer_full.set()
                    # strict wait for buffer to be consumed after buffer full.
                    # If max is 1000, we are not allowed to return 1001 values.
                    buffer_consumed.clear()
                    await self.wait(buffer_consumed)
            except CancelledError:  # pragma: no cover
                raise
            except Exception as exc:
                self.log.exception('Error adding to take buffer: %r', exc)
                await self.crash(exc)
            return value

        # Disable acks so that this method can ack events manually,
        # only after they have been consumed by the user.
        self.enable_acks = False

        self.add_processor(add_to_buffer)
        self._enable_passive(cast(ChannelT, channel_it))
        try:
            while not self.should_stop:
                # wait until buffer full, or timeout
                await self.wait_for_stopped(buffer_full, timeout=timeout)
                if buffer:
                    # make sure background thread does not add new items to
                    # buffer while we read.
                    buffer_consuming = self.loop.create_future()
                    try:
                        yield list(events)
                    finally:
                        buffer.clear()
                        for event in events:
                            await self.ack(event)
                        events.clear()
                        # allow writing to buffer again
                        notify(buffer_consuming)
                        buffer_full.clear()
                        buffer_consumed.set()
                else:  # pragma: no cover
                    pass
            else:  # pragma: no cover
                pass

        finally:
            # Restore the previous value of ``enable_acks``.
            self.enable_acks = stream_enable_acks
            self._processors.remove(add_to_buffer)

    def enumerate(self, start: int = 0) -> AsyncIterable[Tuple[int, T_co]]:
        """Enumerate values received on this stream.

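For orientation, a minimal usage sketch of the new API (the app name, broker URL, and topic below are hypothetical, not part of this PR): unlike Stream.take(), which yields plain values, take_events() yields batches of faust.Event objects, so the handler can read message metadata, and each batch is acked only after it has been consumed.

import faust

app = faust.App('take-events-demo', broker='kafka://localhost:9092')
orders = app.topic('orders')


@app.agent(orders)
async def process(stream):
    # Up to 100 events per batch, flushed at the latest every 10 seconds.
    # The events in a batch are acked only after this loop body finishes.
    async for events in stream.take_events(100, within=10.0):
        for event in events:
            print(event.message.offset, event.value)
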
6 changes: 6 additions & 0 deletions faust/types/streams.py
@@ -174,6 +174,12 @@ async def take(self, max_: int,
                   within: Seconds) -> AsyncIterable[Sequence[T_co]]:
        ...

    @abc.abstractmethod
    @no_type_check
    async def take_events(self, max_: int,
                          within: Seconds) -> AsyncIterable[Sequence[EventT]]:
        ...

    @abc.abstractmethod
    def enumerate(self, start: int = 0) -> AsyncIterable[Tuple[int, T_co]]:
        ...
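Since StreamT now declares take_events() as part of the abstract interface, helpers can be written against StreamT rather than the concrete Stream class. A small hypothetical sketch (the batches() helper below is not part of this PR):

from typing import AsyncIterator, Sequence

from faust.types import EventT, StreamT


async def batches(stream: StreamT, size: int = 100,
                  within: float = 10.0) -> AsyncIterator[Sequence[EventT]]:
    # Re-yield event batches from any StreamT implementation,
    # using only the abstract interface.
    async for events in stream.take_events(size, within=within):
        yield events
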
25 changes: 25 additions & 0 deletions t/functional/test_streams.py
@@ -618,6 +618,31 @@ async def in_one_second_finalize():
        assert not event.message.refcount
        assert s.enable_acks is True

@pytest.mark.asyncio
async def test_take_events(app, loop):
    s = new_stream(app)
    async with s:
        assert s.enable_acks is True
        for i in range(10):
            await s.channel.send(value=i)

        event_processor = s.take_events(10, within=1)
        async for event_list in event_processor:
            assert all(isinstance(e, faust.Event) for e in event_list)
            assert list(range(10)) == [e.message.value for e in event_list]
            break

        try:
            await event_processor.athrow(asyncio.CancelledError())
        except asyncio.CancelledError:
            pass

        # need one sleep on Python 3.6.0-3.6.6 + 3.7.0
        # need two sleeps on Python 3.6.7 + 3.7.1 :-/
        await asyncio.sleep(0)  # needed for some reason
        await asyncio.sleep(0)  # needed for some reason
        await asyncio.sleep(0)  # needed for some reason


@pytest.mark.asyncio
async def test_take__no_event_crashes(app, loop):
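
As a side note on the shutdown pattern in test_take_events above: throwing CancelledError into the generator is what lets take_events() reach its outer finally block, which restores enable_acks and removes the add_to_buffer processor. A hedged sketch of an equivalent cleanup using aclose() instead, reusing the same app and new_stream fixtures:

@pytest.mark.asyncio
async def test_take_events__aclose_restores_acks(app, loop):
    s = new_stream(app)
    async with s:
        await s.channel.send(value=0)
        event_processor = s.take_events(1, within=1)
        async for event_list in event_processor:
            assert len(event_list) == 1
            break
        # Closing the generator raises GeneratorExit at the suspended yield,
        # so the finally blocks in take_events() ack the batch, restore
        # enable_acks and remove the buffering processor.
        await event_processor.aclose()
        assert s.enable_acks is True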