-
Notifications
You must be signed in to change notification settings - Fork 561
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Permit async predictors First steps towards allowing `async def predict` method signatures. This commit adds support to Worker for starting an `asyncio` event loop if the `predict` function returns an awaitable or an async generator. For now, we drop support for output capture as well as cancelation. * Async-compatible stream interception In an async context, attempting to intercept stream writes at the file descriptor layer is futile. We can do it, but we will have no way of associating a write made from native code with a specific prediction -- and the only reason to intercept/swap out the STDOUT/STDERR file descriptors is so that we can catch writes from native code. This commit adds an altogether simpler implementation which can work for async code with that restriction. All it does is patch `sys.stdout` and `sys.stderr` with objects that can redirect (or tee) the output to a callback function. * Implement single-task cancelation for async predictors This implements basic cancelation for async predictors. Whereas regular predictors implement cancelation using a custom CancelationException, asyncio already has a concept of task cancelation, so we use that. When cancelation is requested, we send a `Cancel()` event down the events pipe to the child. Regular predictors ignore these, but async predictors cancel the currently-running task when they receive one. In future, these `Cancel()` events will specify which running prediction they are intended to cancel. * Ensure that graceful shutdown works as expected for async predictors When a `Shutdown()` event is sent, any running prediction should be allowed to completed. For now, we implement this by awaiting any task that is tracked when we break out of the child worker's event loop. * Update support-async-predictors branch for Pydantic v2 - Use renamed _ChildWorker type - Set initial `__url__` to `None` prior to URL parsing potentially throwing `ValueError` - Declare `pid` field in `FakeChildWorker` - Do not use nested redirectors --------- Co-authored-by: Nick Stenning <[email protected]> Co-authored-by: Dominic Baggott <[email protected]>
- Loading branch information
1 parent
72d7d50
commit a86adcd
Showing
9 changed files
with
407 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import asyncio | ||
import multiprocessing | ||
from multiprocessing.connection import Connection | ||
from typing import Any, Optional | ||
|
||
from typing_extensions import Buffer | ||
|
||
_spawn = multiprocessing.get_context("spawn") | ||
|
||
|
||
class AsyncConnection: | ||
def __init__(self, connection: Connection) -> None: | ||
self._connection = connection | ||
self._event = asyncio.Event() | ||
loop = asyncio.get_event_loop() | ||
loop.add_reader(self._connection.fileno(), self._event.set) | ||
|
||
def send(self, obj: Any) -> None: | ||
"""Send a (picklable) object""" | ||
|
||
self._connection.send(obj) | ||
|
||
async def _wait_for_input(self) -> None: | ||
"""Wait until there is an input available to be read""" | ||
|
||
while not self._connection.poll(): | ||
await self._event.wait() | ||
self._event.clear() | ||
|
||
async def recv(self) -> Any: | ||
"""Receive a (picklable) object""" | ||
|
||
await self._wait_for_input() | ||
return self._connection.recv() | ||
|
||
def fileno(self) -> int: | ||
"""File descriptor or handle of the connection""" | ||
return self._connection.fileno() | ||
|
||
def close(self) -> None: | ||
"""Close the connection""" | ||
self._connection.close() | ||
|
||
async def poll(self, timeout: float = 0.0) -> bool: | ||
"""Whether there is an input available to be read""" | ||
|
||
if self._connection.poll(): | ||
return True | ||
|
||
try: | ||
await asyncio.wait_for(self._wait_for_input(), timeout=timeout) | ||
except asyncio.TimeoutError: | ||
return False | ||
return self._connection.poll() | ||
|
||
def send_bytes( | ||
self, buf: Buffer, offset: int = 0, size: Optional[int] = None | ||
) -> None: | ||
"""Send the bytes data from a bytes-like object""" | ||
|
||
self._connection.send_bytes(buf, offset, size) | ||
|
||
async def recv_bytes(self, maxlength: Optional[int] = None) -> bytes: | ||
""" | ||
Receive bytes data as a bytes object. | ||
""" | ||
|
||
await self._wait_for_input() | ||
return self._connection.recv_bytes(maxlength) | ||
|
||
async def recv_bytes_into(self, buf: Buffer, offset: int = 0) -> int: | ||
""" | ||
Receive bytes data into a writeable bytes-like object. | ||
Return the number of bytes read. | ||
""" | ||
|
||
await self._wait_for_input() | ||
return self._connection.recv_bytes_into(buf, offset) | ||
|
||
|
||
class LockedConnection: | ||
def __init__(self, connection: Connection) -> None: | ||
self.connection = connection | ||
self._lock = _spawn.Lock() | ||
|
||
def send(self, obj: Any) -> None: | ||
with self._lock: | ||
self.connection.send(obj) | ||
|
||
def recv(self) -> Any: | ||
return self.connection.recv() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.