diff --git a/faust/streams.py b/faust/streams.py index 4a2a6b5a0..3ac2fddd1 100644 --- a/faust/streams.py +++ b/faust/streams.py @@ -394,6 +394,100 @@ async def add_to_buffer(value: T) -> T: self.enable_acks = stream_enable_acks self._processors.remove(add_to_buffer) + async def take_events(self, max_: int, + within: Seconds) -> AsyncIterable[Sequence[EventT]]: + """Buffer n events at a time and yield a list of buffered events. + + Arguments: + max_: Max number of messages to receive. When more than this + number of messages are received within the specified number of + seconds then we flush the buffer immediately. + within: Timeout for when we give up waiting for another value, + and process the values we have. + Warning: If there's no timeout (i.e. `timeout=None`), + the agent is likely to stall and block buffered events for an + unreasonable length of time(!). + """ + buffer: List[T_co] = [] + events: List[EventT] = [] + buffer_add = buffer.append + event_add = events.append + buffer_size = buffer.__len__ + buffer_full = asyncio.Event(loop=self.loop) + buffer_consumed = asyncio.Event(loop=self.loop) + timeout = want_seconds(within) if within else None + stream_enable_acks: bool = self.enable_acks + + buffer_consuming: Optional[asyncio.Future] = None + + channel_it = aiter(self.channel) + + # We add this processor to populate the buffer, and the stream + # is passively consumed in the background (enable_passive below). + async def add_to_buffer(value: T) -> T: + try: + # buffer_consuming is set when consuming buffer after timeout. + nonlocal buffer_consuming + if buffer_consuming is not None: + try: + await buffer_consuming + finally: + buffer_consuming = None + buffer_add(cast(T_co, value)) + event = self.current_event + if event is None: + raise RuntimeError( + 'Take buffer found current_event is None') + event_add(event) + if buffer_size() >= max_: + # signal that the buffer is full and should be emptied. + buffer_full.set() + # strict wait for buffer to be consumed after buffer full. + # If max is 1000, we are not allowed to return 1001 values. + buffer_consumed.clear() + await self.wait(buffer_consumed) + except CancelledError: # pragma: no cover + raise + except Exception as exc: + self.log.exception('Error adding to take buffer: %r', exc) + await self.crash(exc) + return value + + # Disable acks to ensure this method acks manually + # events only after they are consumed by the user + self.enable_acks = False + + self.add_processor(add_to_buffer) + self._enable_passive(cast(ChannelT, channel_it)) + try: + while not self.should_stop: + # wait until buffer full, or timeout + await self.wait_for_stopped(buffer_full, timeout=timeout) + if buffer: + # make sure background thread does not add new items to + # buffer while we read. + buffer_consuming = self.loop.create_future() + try: + yield list(events) + finally: + buffer.clear() + for event in events: + await self.ack(event) + events.clear() + # allow writing to buffer again + notify(buffer_consuming) + buffer_full.clear() + buffer_consumed.set() + else: # pragma: no cover + pass + else: # pragma: no cover + pass + + finally: + # Restore last behaviour of "enable_acks" + self.enable_acks = stream_enable_acks + self._processors.remove(add_to_buffer) + def enumerate(self, start: int = 0) -> AsyncIterable[Tuple[int, T_co]]: """Enumerate values received on this stream. diff --git a/faust/types/streams.py b/faust/types/streams.py index bde351d40..66227203f 100644 --- a/faust/types/streams.py +++ b/faust/types/streams.py @@ -174,6 +174,12 @@ async def take(self, max_: int, within: Seconds) -> AsyncIterable[Sequence[T_co]]: ... + @abc.abstractmethod + @no_type_check + async def take_events(self, max_: int, + within: Seconds) -> AsyncIterable[Sequence[EventT]]: + ... + @abc.abstractmethod def enumerate(self, start: int = 0) -> AsyncIterable[Tuple[int, T_co]]: ... diff --git a/t/functional/test_streams.py b/t/functional/test_streams.py index 9df87f79d..cc5c27762 100644 --- a/t/functional/test_streams.py +++ b/t/functional/test_streams.py @@ -618,6 +618,31 @@ async def in_one_second_finalize(): assert not event.message.refcount assert s.enable_acks is True +@pytest.mark.asyncio +async def test_take_events(app, loop): + s = new_stream(app) + async with s: + assert s.enable_acks is True + for i in range(10): + await s.channel.send(value=i) + + event_processor = s.take_events(10, within=1) + async for event_list in event_processor: + assert all(isinstance(e, faust.Event) for e in event_list) + assert list(range(10)) == [e.message.value for e in event_list] + break + + try: + await event_processor.athrow(asyncio.CancelledError()) + except asyncio.CancelledError: + pass + + # need one sleep on Python 3.6.0-3.6.6 + 3.7.0 + # need two sleeps on Python 3.6.7 + 3.7.1 :-/ + await asyncio.sleep(0) # needed for some reason + await asyncio.sleep(0) # needed for some reason + await asyncio.sleep(0) # needed for some reason + @pytest.mark.asyncio async def test_take__no_event_crashes(app, loop):