diff --git a/vllm/utils.py b/vllm/utils.py index 1fd395c04ca24..4137aaec8a93c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -405,7 +405,7 @@ async def iterate_with_cancellation( async def merge_async_iterators( *iterators: AsyncGenerator[T, None], - is_cancelled: Callable[[], Awaitable[bool]], + is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None, ) -> AsyncGenerator[Tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. @@ -413,8 +413,8 @@ async def merge_async_iterators( When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. - It also polls the provided function at least once per second to check - for client cancellation. + It also optionally polls a provided function at least once per second + to check for client cancellation. """ # Can use anext() in python >= 3.10 @@ -422,12 +422,13 @@ async def merge_async_iterators( ensure_future(pair[1].__anext__()): pair for pair in enumerate(iterators) } + timeout = None if is_cancelled is None else 1 try: while awaits: done, pending = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED, - timeout=1) - if await is_cancelled(): + timeout=timeout) + if is_cancelled is not None and await is_cancelled(): raise asyncio.CancelledError("client cancelled") for d in done: pair = awaits.pop(d)