diff --git a/aio_pika/queue.py b/aio_pika/queue.py index 2f0f9d45..939237ee 100644 --- a/aio_pika/queue.py +++ b/aio_pika/queue.py @@ -484,6 +484,11 @@ def __init__(self, queue: Queue, **kwargs: Any): self._queue = asyncio.Queue() self._consume_kwargs = kwargs + if kwargs.get("timeout"): + self.__anext_impl = self.__anext_with_timeout + else: + self.__anext_impl = self.__anext_without_timeout + self._amqp_queue.close_callbacks.add(self.close) async def on_message(self, message: AbstractIncomingMessage) -> None: @@ -511,6 +516,9 @@ async def __aexit__( await self.close() async def __anext__(self) -> IncomingMessage: + return await self.__anext_impl() + + async def __anext_with_timeout(self) -> IncomingMessage: if not hasattr(self, "_consumer_tag"): await self.consume() try: @@ -530,5 +538,14 @@ async def __anext__(self) -> IncomingMessage: await asyncio.wait_for(self.close(), timeout=timeout) raise + async def __anext_without_timeout(self) -> IncomingMessage: + if not hasattr(self, "_consumer_tag"): + await self.consume() + try: + return await self._queue.get() + except asyncio.CancelledError: + await self.close() + raise + __all__ = ("Queue", "QueueIterator", "ConsumerTag")