diff --git a/python/cog/server/connection.py b/python/cog/server/connection.py deleted file mode 100644 index c0d85722d7..0000000000 --- a/python/cog/server/connection.py +++ /dev/null @@ -1,91 +0,0 @@ -import asyncio -import multiprocessing -from multiprocessing.connection import Connection -from typing import Any, Optional - -from typing_extensions import Buffer - -_spawn = multiprocessing.get_context("spawn") - - -class AsyncConnection: - def __init__(self, connection: Connection) -> None: - self._connection = connection - self._event = asyncio.Event() - loop = asyncio.get_event_loop() - loop.add_reader(self._connection.fileno(), self._event.set) - - def send(self, obj: Any) -> None: - """Send a (picklable) object""" - - self._connection.send(obj) - - async def _wait_for_input(self) -> None: - """Wait until there is an input available to be read""" - - while not self._connection.poll(): - await self._event.wait() - self._event.clear() - - async def recv(self) -> Any: - """Receive a (picklable) object""" - - await self._wait_for_input() - return self._connection.recv() - - def fileno(self) -> int: - """File descriptor or handle of the connection""" - return self._connection.fileno() - - def close(self) -> None: - """Close the connection""" - self._connection.close() - - async def poll(self, timeout: float = 0.0) -> bool: - """Whether there is an input available to be read""" - - if self._connection.poll(): - return True - - try: - await asyncio.wait_for(self._wait_for_input(), timeout=timeout) - except asyncio.TimeoutError: - return False - return self._connection.poll() - - def send_bytes( - self, buf: Buffer, offset: int = 0, size: Optional[int] = None - ) -> None: - """Send the bytes data from a bytes-like object""" - - self._connection.send_bytes(buf, offset, size) - - async def recv_bytes(self, maxlength: Optional[int] = None) -> bytes: - """ - Receive bytes data as a bytes object. - """ - - await self._wait_for_input() - return self._connection.recv_bytes(maxlength) - - async def recv_bytes_into(self, buf: Buffer, offset: int = 0) -> int: - """ - Receive bytes data into a writeable bytes-like object. - Return the number of bytes read. - """ - - await self._wait_for_input() - return self._connection.recv_bytes_into(buf, offset) - - -class LockedConnection: - def __init__(self, connection: Connection) -> None: - self.connection = connection - self._lock = _spawn.Lock() - - def send(self, obj: Any) -> None: - with self._lock: - self.connection.send(obj) - - def recv(self) -> Any: - return self.connection.recv() diff --git a/python/cog/server/eventtypes.py b/python/cog/server/eventtypes.py index c8300cd5e5..ff39bd072e 100644 --- a/python/cog/server/eventtypes.py +++ b/python/cog/server/eventtypes.py @@ -5,12 +5,6 @@ # From worker parent process # -@define -class Cancel: - # TODO: identify which prediction! - pass - - @define class PredictionInput: payload: Dict[str, Any] diff --git a/python/cog/server/helpers.py b/python/cog/server/helpers.py index 74cc59b2bd..71dde4d788 100644 --- a/python/cog/server/helpers.py +++ b/python/cog/server/helpers.py @@ -10,7 +10,7 @@ import threading import uuid from types import TracebackType -from typing import Any, BinaryIO, Callable, Dict, List, Sequence, TextIO, Union +from typing import Any, Callable, Dict, List, Sequence, TextIO, Union import pydantic from typing_extensions import Self @@ -19,45 +19,6 @@ from .errors import CogRuntimeError, CogTimeoutError -class _SimpleStreamWrapper(io.TextIOWrapper): - """ - _SimpleStreamWrapper wraps a binary I/O buffer and provides a TextIOWrapper - interface (primarily write and flush methods) which call a provided - callback function instead of (or, if `tee` is True, in addition to) writing - to the underlying buffer. - """ - - def __init__( - self, - buffer: BinaryIO, - callback: Callable[[str, str], None], - tee: bool = False, - ) -> None: - super().__init__(buffer, line_buffering=True) - - self._callback = callback - self._tee = tee - self._buffer = [] - - def write(self, s: str) -> int: - length = len(s) - self._buffer.append(s) - if self._tee: - super().write(s) - else: - # If we're not teeing, we have to handle automatic flush on - # newline. When `tee` is true, this is handled by the write method. - if "\n" in s or "\r" in s: - self.flush() - return length - - def flush(self) -> None: - self._callback(self.name, "".join(self._buffer)) - self._buffer.clear() - if self._tee: - super().flush() - - class _StreamWrapper: def __init__(self, name: str, stream: TextIO) -> None: self.name = name @@ -125,66 +86,6 @@ def original(self) -> TextIO: return self._original_fp -if sys.version_info < (3, 9): - - class _AsyncStreamRedirectorBase(contextlib.AbstractContextManager): - pass -else: - - class _AsyncStreamRedirectorBase( - contextlib.AbstractContextManager["AsyncStreamRedirector"] - ): - pass - - -class AsyncStreamRedirector(_AsyncStreamRedirectorBase): - """ - AsyncStreamRedirector is a context manager that redirects I/O streams to a - callback function. If `tee` is True, it also writes output to the original - streams. - - Unlike StreamRedirector, the underlying stream file descriptors are not - modified, which means that only stream writes from Python code will be - captured. Writes from native code will not be captured. - - Unlike StreamRedirector, the streams redirected cannot be configured. The - context manager is only able to redirect STDOUT and STDERR. - """ - - def __init__( - self, - callback: Callable[[str, str], None], - tee: bool = False, - ) -> None: - self._callback = callback - self._tee = tee - - stdout_wrapper = _SimpleStreamWrapper(sys.stdout.buffer, callback, tee) - stderr_wrapper = _SimpleStreamWrapper(sys.stderr.buffer, callback, tee) - self._stdout_ctx = contextlib.redirect_stdout(stdout_wrapper) - self._stderr_ctx = contextlib.redirect_stderr(stderr_wrapper) - - def __enter__(self) -> Self: - self._stdout_ctx.__enter__() - self._stderr_ctx.__enter__() - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - self._stdout_ctx.__exit__(exc_type, exc_value, traceback) - self._stderr_ctx.__exit__(exc_type, exc_value, traceback) - - def drain(self, timeout: float = 0.0) -> None: - # Draining isn't complicated for AsyncStreamRedirector, since we're not - # moving data between threads. We just need to flush the streams. - sys.stdout.flush() - sys.stderr.flush() - - if sys.version_info < (3, 9): class _StreamRedirectorBase(contextlib.AbstractContextManager): diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 760f43cb95..e213b8acf6 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -1,6 +1,3 @@ -import asyncio -import contextlib -import inspect import multiprocessing import os import signal @@ -12,16 +9,14 @@ from concurrent.futures import Future, ThreadPoolExecutor from enum import Enum, auto, unique from multiprocessing.connection import Connection -from typing import Any, Callable, Dict, Iterator, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import structlog from ..json import make_encodeable from ..predictor import BasePredictor, get_predict, load_predictor_from_ref, run_setup from ..types import PYDANTIC_V2, URLPath -from .connection import AsyncConnection, LockedConnection from .eventtypes import ( - Cancel, Done, Log, PredictionInput, @@ -34,7 +29,7 @@ FatalWorkerException, InvalidStateException, ) -from .helpers import AsyncStreamRedirector, StreamRedirector +from .helpers import StreamRedirector if PYDANTIC_V2: from .helpers import unwrap_pydantic_serialization_iterators @@ -56,7 +51,7 @@ class WorkerState(Enum): class Worker: - def __init__(self, child: "_ChildWorker", events: Connection) -> None: + def __init__(self, child: "ChildWorker", events: Connection) -> None: self._child = child self._events = events @@ -135,7 +130,6 @@ def terminate(self) -> None: def cancel(self) -> None: if self._allow_cancel: self._child.send_cancel() - self._events.send(Cancel()) self._allow_cancel = False def _assert_state(self, state: WorkerState) -> None: @@ -266,7 +260,7 @@ def recv(self) -> Any: return self.conn.recv() -class _ChildWorker(_spawn.Process): # type: ignore +class ChildWorker(_spawn.Process): # type: ignore def __init__( self, predictor_ref: str, @@ -275,9 +269,7 @@ def __init__( ) -> None: self._predictor_ref = predictor_ref self._predictor: Optional[BasePredictor] = None - self._events: Union[AsyncConnection, LockedConnection] = LockedConnection( - events - ) + self._events = LockedConn(events) self._tee_output = tee_output self._cancelable = False @@ -289,38 +281,17 @@ def run(self) -> None: # shutdown is coordinated by the parent process. signal.signal(signal.SIGINT, signal.SIG_IGN) - # Initially, we ignore SIGUSR1. - signal.signal(signal.SIGUSR1, signal.SIG_IGN) + # We use SIGUSR1 to signal an interrupt for cancelation. + signal.signal(signal.SIGUSR1, self._signal_handler) redirector = StreamRedirector( - callback=self._stream_write_hook, tee=self._tee_output, + callback=self._stream_write_hook, ) - # TODO: support async setup? see where `redirector` is redefined below if the predict is async with redirector: self._setup(redirector) - - # If setup didn't set the predictor, we're done here. - if not self._predictor: - return - - predict = get_predict(self._predictor) - if inspect.iscoroutinefunction(predict) or inspect.isasyncgenfunction(predict): - # Replace the stream redirector with one that will work in an async - # context. - redirector = AsyncStreamRedirector( - callback=self._stream_write_hook, - tee=self._tee_output, - ) - - asyncio.run(self._aloop(predict, redirector)) - else: - # We use SIGUSR1 to signal an interrupt for cancelation. - signal.signal(signal.SIGUSR1, self._signal_handler) - - with redirector: - self._loop(predict, redirector) + self._loop(redirector) def send_cancel(self) -> None: if self.is_alive() and self.pid: @@ -357,57 +328,27 @@ def _setup(self, redirector: StreamRedirector) -> None: raise self._events.send(done) - def _loop( - self, - predict: Callable[..., Any], - redirector: StreamRedirector, - ) -> None: - with redirector: - while True: - ev = self._events.recv() - if isinstance(ev, Cancel): - continue # Ignored in sync predictors. - elif isinstance(ev, Shutdown): - break - elif isinstance(ev, PredictionInput): - self._predict(ev.payload, predict, redirector) - else: - print(f"Got unexpected event: {ev}", file=sys.stderr) - - async def _aloop( - self, - predict: Callable[..., Any], - redirector: AsyncStreamRedirector, - ) -> None: - # Unwrap and replace the events connection with an async one. - assert isinstance(self._events, LockedConnection) - self._events = AsyncConnection(self._events.connection) - - task = None - - with redirector: - while True: - ev = await self._events.recv() - if isinstance(ev, Cancel) and task and self._cancelable: - task.cancel() - elif isinstance(ev, Shutdown): - break - elif isinstance(ev, PredictionInput): - task = asyncio.create_task( - self._apredict(ev.payload, predict, redirector) - ) - else: - print(f"Got unexpected event: {ev}", file=sys.stderr) - if task: - await task + def _loop(self, redirector: StreamRedirector) -> None: + while True: + ev = self._events.recv() + if isinstance(ev, Shutdown): + break + if isinstance(ev, PredictionInput): + self._predict(ev.payload, redirector) + else: + print(f"Got unexpected event: {ev}", file=sys.stderr) def _predict( self, payload: Dict[str, Any], - predict: Callable[..., Any], redirector: StreamRedirector, ) -> None: - with self._handle_predict_error(redirector): + assert self._predictor + done = Done() + send_done = True + self._cancelable = True + try: + predict = get_predict(self._predictor) result = predict(**payload) if result: @@ -430,46 +371,8 @@ def _predict( else: payload = make_encodeable(result) self._events.send(PredictionOutput(payload=payload)) - - async def _apredict( - self, - payload: Dict[str, Any], - predict: Callable[..., Any], - redirector: AsyncStreamRedirector, - ) -> None: - with self._handle_predict_error(redirector): - result = predict(**payload) - - if result: - if inspect.isasyncgen(result): - self._events.send(PredictionOutputType(multi=True)) - async for r in result: - self._events.send(PredictionOutput(payload=make_encodeable(r))) - else: - output = await result - self._events.send(PredictionOutputType(multi=False)) - self._events.send(PredictionOutput(payload=make_encodeable(output))) - - @contextlib.contextmanager - def _handle_predict_error( - self, redirector: Union[AsyncStreamRedirector, StreamRedirector] - ) -> Iterator[None]: - done = Done() - send_done = True - self._cancelable = True - try: - yield - # regular cancelation except CancelationException: done.canceled = True - # async cancelation - except asyncio.CancelledError: - done.canceled = True - # We've handled the requested cancelation, so we uncancel the task. - # This ensures that any cleanup work we do won't be interrupted. - task = asyncio.current_task() - assert task - task.uncancel() except Exception as e: # pylint: disable=broad-exception-caught traceback.print_exc() done.error = True @@ -514,7 +417,7 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None: def make_worker(predictor_ref: str, tee_output: bool = True) -> Worker: parent_conn, child_conn = _spawn.Pipe() - child = _ChildWorker(predictor_ref, events=child_conn, tee_output=tee_output) + child = ChildWorker(predictor_ref, events=child_conn, tee_output=tee_output) parent = Worker(child=child, events=parent_conn) return parent diff --git a/python/cog/types.py b/python/cog/types.py index 0e0175cf21..f4110e68ca 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -288,8 +288,6 @@ class URLFile(io.IOBase): __slots__ = ("__target__", "__url__") def __init__(self, url: str) -> None: - object.__setattr__(self, "__url__", None) - parsed = urllib.parse.urlparse(url) if parsed.scheme not in { "http", diff --git a/python/tests/server/fixtures/hello_world_async.py b/python/tests/server/fixtures/hello_world_async.py deleted file mode 100644 index a79c2ae331..0000000000 --- a/python/tests/server/fixtures/hello_world_async.py +++ /dev/null @@ -1,3 +0,0 @@ -class Predictor: - async def predict(self, name): - return f"hello, {name}" diff --git a/python/tests/server/fixtures/logging_async.py b/python/tests/server/fixtures/logging_async.py deleted file mode 100644 index 8193e5065c..0000000000 --- a/python/tests/server/fixtures/logging_async.py +++ /dev/null @@ -1,34 +0,0 @@ -import ctypes -import logging -import sys -import time - -libc = ctypes.CDLL(None) - -# test that we can still capture type signature even if we write -# a bunch of stuff at import time. -libc.puts(b"writing some stuff from C at import time") -libc.fflush(None) -sys.stdout.write("writing to stdout at import time\n") -sys.stderr.write("writing to stderr at import time\n") - - -class Predictor: - def setup(self): - print("setting up predictor") - self.foo = "foo" - - async def predict(self) -> str: - time.sleep(0.1) - logging.warn("writing log message") - time.sleep(0.1) - libc.puts(b"writing from C") # not expected to be seen - libc.fflush(None) - time.sleep(0.1) - sys.stderr.write("writing to stderr\n") - time.sleep(0.1) - sys.stderr.flush() - time.sleep(0.1) - print("writing with print") - time.sleep(0.1) - return "output" diff --git a/python/tests/server/fixtures/sleep_async.py b/python/tests/server/fixtures/sleep_async.py deleted file mode 100644 index 0113bd040a..0000000000 --- a/python/tests/server/fixtures/sleep_async.py +++ /dev/null @@ -1,10 +0,0 @@ -import asyncio - -from cog import BasePredictor - - -class Predictor(BasePredictor): - async def predict(self, sleep: float = 0) -> str: - print("starting") - await asyncio.sleep(sleep) - return f"done in {sleep} seconds" diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index aa02b7e319..a0947addd6 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -63,11 +63,6 @@ {"name": ST_NAMES}, lambda x: f"hello, {x['name']}", ), - ( - WorkerConfig("hello_world_async"), - {"name": ST_NAMES}, - lambda x: f"hello, {x['name']}", - ), ( WorkerConfig("count_up"), {"upto": st.integers(min_value=0, max_value=100)}, @@ -82,36 +77,20 @@ SETUP_LOGS_FIXTURES = [ ( - WorkerConfig("logging", setup=False), - ( - "writing some stuff from C at import time\n" - "writing to stdout at import time\n" - "setting up predictor\n" - ), - "writing to stderr at import time\n", - ), - ( - WorkerConfig("logging_async", setup=False), ( "writing some stuff from C at import time\n" "writing to stdout at import time\n" "setting up predictor\n" ), "writing to stderr at import time\n", - ), + ) ] PREDICT_LOGS_FIXTURES = [ ( - WorkerConfig("logging"), ("writing from C\n" "writing with print\n"), ("WARNING:root:writing log message\n" "writing to stderr\n"), - ), - ( - WorkerConfig("logging_async"), - ("writing with print\n"), - ("WARNING:root:writing log message\n" "writing to stderr\n"), - ), + ) ] @@ -213,7 +192,7 @@ def test_no_exceptions_from_recoverable_failures(worker): @uses_worker("stream_redirector_race_condition") def test_stream_redirector_race_condition(worker): """ - StreamRedirector and _ChildWorker are using the same pipe to send data. When + StreamRedirector and ChildWorker are using the same pipe to send data. When there are multiple threads trying to write to the same pipe, it can cause data corruption by race condition. The data corruption will cause pipe receiver to raise an exception due to unpickling error. @@ -243,11 +222,8 @@ def test_output(worker, payloads, output_generator, data): assert result.output == expected_output -@pytest.mark.parametrize( - "worker,expected_stdout,expected_stderr", - SETUP_LOGS_FIXTURES, - indirect=["worker"], -) +@uses_worker("logging", setup=False) +@pytest.mark.parametrize("expected_stdout,expected_stderr", SETUP_LOGS_FIXTURES) def test_setup_logging(worker, expected_stdout, expected_stderr): """ We should get the logs we expect from predictors that generate logs during @@ -260,11 +236,8 @@ def test_setup_logging(worker, expected_stdout, expected_stderr): assert result.stderr == expected_stderr -@pytest.mark.parametrize( - "worker,expected_stdout,expected_stderr", - PREDICT_LOGS_FIXTURES, - indirect=["worker"], -) +@uses_worker("logging") +@pytest.mark.parametrize("expected_stdout,expected_stderr", PREDICT_LOGS_FIXTURES) def test_predict_logging(worker, expected_stdout, expected_stderr): """ We should get the logs we expect from predictors that generate logs during @@ -276,7 +249,7 @@ def test_predict_logging(worker, expected_stdout, expected_stderr): assert result.stderr == expected_stderr -@uses_worker(["sleep", "sleep_async"], setup=False) +@uses_worker("sleep", setup=False) def test_cancel_is_safe(worker): """ Calls to cancel at any time should not result in unexpected things @@ -310,7 +283,7 @@ def test_cancel_is_safe(worker): assert result2.output == "done in 0.1 seconds" -@uses_worker(["sleep", "sleep_async"], setup=False) +@uses_worker("sleep", setup=False) def test_cancel_idempotency(worker): """ Multiple calls to cancel within the same prediction, while not necessary or @@ -342,7 +315,7 @@ def cancel_a_bunch(_): assert result2.output == "done in 0.1 seconds" -@uses_worker(["sleep", "sleep_async"]) +@uses_worker("sleep") def test_cancel_multiple_predictions(worker): """ Multiple predictions cancelled in a row shouldn't be a problem. This test @@ -360,7 +333,7 @@ def test_cancel_multiple_predictions(worker): assert not worker.predict({"sleep": 0}).result().canceled -@uses_worker(["sleep", "sleep_async"]) +@uses_worker("sleep") def test_graceful_shutdown(worker): """ On shutdown, the worker should finish running the current prediction, and @@ -402,7 +375,6 @@ class FakeChildWorker: exitcode = None cancel_sent = False alive = True - pid: int = 0 def start(self): pass