diff --git a/example.py b/example.py index e2b9cd4..e615dd7 100644 --- a/example.py +++ b/example.py @@ -1,6 +1,8 @@ +import time from potassium import Potassium, Request, Response from transformers import pipeline import torch +import os app = Potassium("my_app") @@ -28,5 +30,19 @@ def handler(context: dict, request: Request) -> Response: status=200 ) +@app.handler("/stream") +def stream(context: dict, request: Request): + def stream(): + for i in range(100): + yield f"{i}\n" + time.sleep(1) + + return Response( + body=stream(), + status=200, + headers={"Content-Type": "text/plain"} + ) + + if __name__ == "__main__": - app.serve() \ No newline at end of file + app.serve() diff --git a/potassium/__init__.py b/potassium/__init__.py index 8006620..a50db17 100644 --- a/potassium/__init__.py +++ b/potassium/__init__.py @@ -1,3 +1,4 @@ from .potassium import * from .hooks import * -from .store import Store, RedisConfig \ No newline at end of file +from .store import Store, RedisConfig +from .types import Request, Response diff --git a/potassium/exceptions.py b/potassium/exceptions.py new file mode 100644 index 0000000..93c42f6 --- /dev/null +++ b/potassium/exceptions.py @@ -0,0 +1,10 @@ +class InvalidEndpointTypeException(Exception): + def __init__(self): + super().__init__("Invalid endpoint type. Must be 'handler' or 'background'") + + +class RouteAlreadyInUseException(Exception): + def __init__(self): + super().__init__("Route already in use") + + diff --git a/potassium/potassium.py b/potassium/potassium.py index 9ad355b..7f60564 100644 --- a/potassium/potassium.py +++ b/potassium/potassium.py @@ -1,70 +1,70 @@ import time +import os from types import GeneratorType -from typing import Generator, Optional, Union from flask import Flask, request, make_response, abort, Response as FlaskResponse +from huggingface_hub.file_download import uuid from werkzeug.serving import make_server -from werkzeug.datastructures.headers import EnvironHeaders -from threading import Thread, Lock, Condition +from threading import Thread, Lock +from queue import Queue as ThreadQueue import functools -import traceback -import json as jsonlib from termcolor import colored - - -class Endpoint(): - def __init__(self, type, func): - self.type = type - self.func = func - -class Request(): - def __init__(self, id: str, headers: EnvironHeaders, json: dict): - self.id = id - self.headers = headers - self.json = json - -ResponseBody = Union[bytes, Generator[bytes, None, None]] - -class Response(): - def __init__(self, status: int = 200, json: Optional[dict] = None, headers: Optional[dict] = None, body: Optional[ResponseBody] = None): - assert json == None or body == None, "Potassium Response object cannot have both json and body set" - - - self.headers = headers if headers != None else {} - - # convert json to body if not None - if json != None: - self.body = jsonlib.dumps(json).encode("utf-8") - self.headers["Content-Type"] = "application/json" - else: - self.body = body - - self.status = status - - @property - def json(self): - if self.body == None: - return None - if type(self.body) == bytes: - try: - return jsonlib.loads(self.body.decode("utf-8")) - except: - return None - return None - - @json.setter - def json(self, json): - self.body = jsonlib.dumps(json).encode("utf-8") - self.headers["Content-Type"] = "application/json" - - -class InvalidEndpointTypeException(Exception): - def __init__(self): - super().__init__("Invalid endpoint type. Must be 'handler' or 'background'") - - -class RouteAlreadyInUseException(Exception): - def __init__(self): - super().__init__("Route already in use") +from multiprocessing import Pool as ProcessPool, Queue as ProcessQueue +from multiprocessing.pool import ThreadPool +from .status import PotassiumStatus, StatusEvent +from .worker import run_worker, init_worker +from .exceptions import RouteAlreadyInUseException, InvalidEndpointTypeException +from .types import Request, Endpoint, RequestHeaders, Response + +class ResponseMailbox(): + def __init__(self, response_queue): + self._response_queue = response_queue + self._mailbox = {} + self._lock = Lock() + + t = Thread(target=self._response_handler, daemon=True) + t.start() + + def _response_handler(self): + while True: + request_id, payload = self._response_queue.get() + with self._lock: + if request_id not in self._mailbox: + self._mailbox[request_id] = ThreadQueue() + self._mailbox[request_id].put(payload) + + def get_response(self, request_id): + with self._lock: + if request_id not in self._mailbox: + self._mailbox[request_id] = ThreadQueue() + result, stream_id = self._mailbox[request_id].get() + + if stream_id is not None: + result.body = self._stream_body(stream_id) + + with self._lock: + del self._mailbox[request_id] + + return result + + def _stream_body(self, stream_id): + with self._lock: + if stream_id not in self._mailbox: + self._mailbox[stream_id] = ThreadQueue() + queue = self._mailbox[stream_id] + + while True: + result = queue.get() + if isinstance(result, Exception): + with self._lock: + del self._mailbox[stream_id] + raise result + elif result == None: + break + else: + yield result + + with self._lock: + del self._mailbox[stream_id] class Potassium(): @@ -74,19 +74,35 @@ def __init__(self, name): self.name = name # default init function, if the user doesn't specify one - self._init_func = lambda: {} + self._init_func = lambda _: {} # dictionary to store unlimited Endpoints, by unique route self._endpoints = {} self._context = {} - self._gpu_lock = Lock() - self._background_task_cv = Condition() - self._sequence_number = 0 - self._sequence_number_lock = Lock() - self._idle_start_time = 0 - self._last_inference_start_time = None self._flask_app = self._create_flask_app() + self._event_queue = ProcessQueue() + self._response_queue = ProcessQueue() + self._response_mailbox = ResponseMailbox(self._response_queue) + + self._num_workers = int(os.environ.get("POTASSIUM_NUM_WORKERS", 1)) + + self._worker_pool = None + + self.event_handler_thread = Thread(target=self._event_handler, daemon=True) + + self._status = PotassiumStatus( + num_started_inference_requests=0, + num_completed_inference_requests=0, + num_workers=self._num_workers, + num_workers_started=0, + idle_start_timestamp=time.time(), + in_flight_request_start_times=[] + ) + + def _event_handler(self): + while True: + event = self._event_queue.get() + self._status = self._status.update(event) - # def init(self, func): """init runs once on server start, and is used to initialize the app's context. You can use this to load models onto the GPU, set up connections, etc. @@ -97,14 +113,14 @@ def init(self, func): - the context is not shared between multiple replicas of the app """ - def wrapper(): - print(colored("Running init()", 'yellow')) - self._context = func() - if not isinstance(self._context, dict): - raise Exception("Potassium init() must return a dictionary") + # def wrapper(worker_num): + # print(colored("Running init()", 'yellow')) + # self._context = func(worker_num) + # if not isinstance(self._context, dict): + # raise Exception("Potassium init() must return a dictionary") - self._init_func = wrapper - return wrapper + self._init_func = func + return func @staticmethod def _standardize_route(route): @@ -128,9 +144,9 @@ def handler(self, route: str = "/"): def actual_decorator(func): @functools.wraps(func) - def wrapper(request): + def wrapper(context, request): # send in app's stateful context if GPU, and the request - out = func(self._context, request) + out = func(context, request) if type(out) != Response: raise Exception("Potassium Response object not returned") @@ -166,86 +182,9 @@ def wrapper(request): def test_client(self): "test_client returns a Flask test client for the app" + self._init_server() return self._flask_app.test_client() - # _handle_generic takes in a request and the endpoint it was routed to and handles it as expected by that endpoint - def _handle_generic(self, endpoint, flask_request): - # potassium rejects if lock already in use - try: - self._gpu_lock.acquire(blocking=False) - except: - res = make_response() - res.status_code = 423 - return res - - res = None - self._last_inference_start_time = time.time() - - try: - req = Request( - headers=flask_request.headers, - json=flask_request.get_json(), - id=flask_request.headers.get("X-Banana-Request-Id", "") - ) - except: - res = make_response() - res.status_code = 400 - self._gpu_lock.release() - return res - - if endpoint.type == "handler": - try: - out = endpoint.func(req) - - # create flask response - res = make_response() - res = FlaskResponse( - out.body, status=out.status, headers=out.headers) - except: - tb_str = traceback.format_exc() - print(colored(tb_str, "red")) - res = make_response(tb_str) - res.status_code = 500 - self._idle_start_time = time.time() - self._last_inference_start_time = None - self._gpu_lock.release() - elif endpoint.type == "background": - # run as threaded task - def task(endpoint, lock, req): - try: - endpoint.func(req) - except Exception as e: - # do any cleanup before re-raising user error - raise e - finally: - with self._background_task_cv: - self._background_task_cv.notify_all() - - self._idle_start_time = time.time() - self._last_inference_start_time = None - lock.release() - - thread = Thread(target=task, args=(endpoint, self._gpu_lock, req)) - thread.start() - - # send task start success message - res = make_response({'started': True}) - else: - raise InvalidEndpointTypeException() - - return res - - # WARNING: cover depends on this being called so it should not be changed - def _read_event_chan(self) -> bool: - """ - _read_event_chan essentially waits for a background task to finish, - and then returns True - """ - with self._background_task_cv: - # wait until the background task is done - self._background_task_cv.wait() - return True - def _create_flask_app(self): flask_app = Flask(__name__) @@ -253,20 +192,57 @@ def _create_flask_app(self): @flask_app.route('/', defaults={'path': ''}, methods=["POST"]) @flask_app.route('/', methods=["POST"]) def handle(path): - with self._sequence_number_lock: - self._sequence_number += 1 - route = "/" + path if route not in self._endpoints: abort(404) endpoint = self._endpoints[route] - return self._handle_generic(endpoint, request) - + request_id = request.headers.get("X-Banana-Request-Id", None) + if request_id is None: + request_id = str(uuid.uuid4()) + try: + req = Request( + headers=RequestHeaders(dict(request.headers.items())), + json=request.get_json(), + id=request_id + ) + except: + res = make_response() + res.status_code = 400 + return res + + self._event_queue.put((StatusEvent.INFERENCE_REQUEST_RECEIVED,)) + + assert self._worker_pool is not None, "Worker pool not initialized" + # use an internal id for critical path to prevent user from accidentally + # breaking things by sending multiple requests with the same id + internal_id = str(uuid.uuid4()) + if endpoint.type == "handler": + self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id, True)) + resp = self._response_mailbox.get_response(internal_id) + + flask_response = FlaskResponse( + resp.body, + status=resp.status, + headers=resp.headers + ) + elif endpoint.type == "background": + self._worker_pool.apply_async(run_worker, args=(endpoint.func, req, internal_id)) + + flask_response = make_response({'started': True}) + else: + raise InvalidEndpointTypeException() + + return flask_response + @flask_app.route('/_k/warmup', methods=["POST"]) def warm(): - with self._sequence_number_lock: - self._sequence_number += 1 + request_id = str(uuid.uuid4()) + + # a bit of a hack but we need to send a start and end event to the event queue + # in order to update the status the way the load balancer expects + self._event_queue.put((StatusEvent.INFERENCE_START, request_id)) + self._event_queue.put((StatusEvent.INFERENCE_END, request_id)) res = make_response({ "warm": True, }) @@ -276,33 +252,37 @@ def warm(): @flask_app.route('/_k/status', methods=["GET"]) @flask_app.route('/__status__', methods=["GET"]) def status(): - idle_time = 0 - inference_time = 0 - gpu_available = not self._gpu_lock.locked() - - if self._last_inference_start_time != None: - inference_time = int((time.time() - self._last_inference_start_time)*1000) - - if gpu_available: - idle_time = int((time.time() - self._idle_start_time)*1000) + cur_status = self._status res = make_response({ - "gpu_available": gpu_available, - "sequence_number": self._sequence_number, - "idle_time": idle_time, - "inference_time": inference_time, + "gpu_available": cur_status.gpu_available, + "sequence_number": cur_status.sequence_number, + "idle_time": cur_status.idle_time, + "inference_time": cur_status.longest_inference_time, }) res.status_code = 200 return res return flask_app + + def _init_server(self): + self._idle_start_time = time.time() + index_queue = ProcessQueue() + for i in range(self._num_workers): + index_queue.put(i) + if self._num_workers == 1: + Pool = ThreadPool + else: + Pool = ProcessPool + self._worker_pool = Pool(self._num_workers, init_worker, (index_queue, self._event_queue, self._response_queue, self._init_func)) # serve runs the http server def serve(self, host="0.0.0.0", port=8000): print(colored("------\nStarting Potassium Server 🍌", 'yellow')) - self._init_func() server = make_server(host, port, self._flask_app, threaded=True) print(colored(f"Serving at http://{host}:{port}\n------", 'green')) - self._idle_start_time = time.time() + self._init_server() + server.serve_forever() + diff --git a/potassium/status.py b/potassium/status.py new file mode 100644 index 0000000..72b4da7 --- /dev/null +++ b/potassium/status.py @@ -0,0 +1,97 @@ +from enum import Enum +import time +from typing import List, Tuple +from dataclasses import dataclass + +from .types import RequestID + +class StatusEvent(Enum): + INFERENCE_REQUEST_RECEIVED = "INFERENCE_REQUEST_RECEIVED" + INFERENCE_START = "INFERENCE_START" + INFERENCE_END = "INFERENCE_END" + WORKER_STARTED = "WORKER_STARTED" + +@dataclass +class PotassiumStatus(): + """PotassiumStatus is a simple class that represents the status of a Potassium app.""" + num_started_inference_requests: int + num_completed_inference_requests: int + num_workers: int + num_workers_started: int + idle_start_timestamp: float + in_flight_request_start_times: List[Tuple[RequestID, float]] + + @property + def requests_in_progress(self): + return self.num_started_inference_requests - self.num_completed_inference_requests + + @property + def gpu_available(self): + return self.num_workers - self.requests_in_progress > 0 + + @property + def sequence_number(self): + return self.num_started_inference_requests + + @property + def idle_time(self): + if not self.gpu_available: + return 0 + return time.time() - self.idle_start_timestamp + + @property + def longest_inference_time(self): + if self.in_flight_request_start_times == []: + return 0 + + oldest_start_time = min([start_time for _, start_time in self.in_flight_request_start_times]) + + return time.time() - oldest_start_time + + def update(self, event): + event_type = event[0] + event_data = event[1:] + if event_type not in event_handlers: + raise Exception(f"Invalid event {event}") + return event_handlers[event](self.clone(), *event_data) + + + def clone(self): + return PotassiumStatus( + self.num_started_inference_requests, + self.num_completed_inference_requests, + self.num_workers, + self.num_workers_started, + self.idle_start_timestamp, + self.in_flight_request_start_times + ) + +def handle_start_inference(status: PotassiumStatus, request_id: RequestID): + status.in_flight_request_start_times.append((request_id, time.time())) + return status + +def handle_end_inference(status: PotassiumStatus, request_id: RequestID): + status.num_completed_inference_requests += 1 + status.in_flight_request_start_times = [t for t in status.in_flight_request_start_times if t[0] != request_id] + + if status.gpu_available: + status.idle_start_timestamp = time.time() + + return status + +def handle_inference_request_received(status: PotassiumStatus): + status.num_started_inference_requests += 1 + return status + +def handle_worker_started(status: PotassiumStatus): + status.num_workers_started += 1 + return status + +event_handlers = { + StatusEvent.INFERENCE_REQUEST_RECEIVED: handle_inference_request_received, + StatusEvent.INFERENCE_START: handle_start_inference, + StatusEvent.INFERENCE_END: handle_end_inference, + StatusEvent.WORKER_STARTED: lambda status: status, +} + + diff --git a/potassium/types.py b/potassium/types.py new file mode 100644 index 0000000..6358cff --- /dev/null +++ b/potassium/types.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generator, Optional, Union, Generator, Optional, Union +import json as jsonlib + +@dataclass +class Endpoint(): + type: str + func: Callable + +class RequestHeaders(): + def __init__(self, headers: Dict[str, str]): + self._headers = headers + + def __getitem__(self, key): + if not isinstance(key, str): + raise KeyError(key) + key = key.upper().replace("-", "_") + + return self._headers[key] + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + +@dataclass +class Request(): + id: str + headers: RequestHeaders + json: Dict[str, Any] + +ResponseBody = Union[bytes, Generator[bytes, None, None]] +RequestID = str + +class Response(): + def __init__(self, status: int = 200, json: Optional[dict] = None, headers: Optional[dict] = None, body: Optional[ResponseBody] = None): + assert json == None or body == None, "Potassium Response object cannot have both json and body set" + + + self.headers = headers if headers != None else {} + + # convert json to body if not None + if json != None: + self.body = jsonlib.dumps(json).encode("utf-8") + self.headers["Content-Type"] = "application/json" + else: + self.body = body + + self.status = status + + @property + def json(self): + if self.body == None: + return None + if type(self.body) == bytes: + try: + return jsonlib.loads(self.body.decode("utf-8")) + except: + return None + return None + + @json.setter + def json(self, json): + self.body = jsonlib.dumps(json).encode("utf-8") + self.headers["Content-Type"] = "application/json" + + diff --git a/potassium/worker.py b/potassium/worker.py new file mode 100644 index 0000000..6c6f19c --- /dev/null +++ b/potassium/worker.py @@ -0,0 +1,117 @@ +from multiprocessing import Queue +import os +import threading +from typing import Dict, Any, Generator +from dataclasses import dataclass +from flask import make_response, Response as FlaskResponse +from termcolor import colored +import traceback +import inspect + +from .status import StatusEvent +from .types import Response + +worker = None + +class FDRedirect(): + def __init__(self, fd: int): + self._fd = fd + self._fd_copy = os.dup(fd) + self._redirect_w = None + + def _run_redirect_loop(self, redirect_r, prefix): + redirect_r = os.fdopen(redirect_r, "r") + + for line in redirect_r: + os.write(self._fd_copy, (prefix + line).encode("utf-8")) + redirect_r.close() + + def set_prefix(self, prefix): + if self._redirect_w is not None: + os.dup2(self._fd_copy, self._fd) + os.close(self._redirect_w) + + fd = self._fd + redirect_r, redirect_w = os.pipe() + + self._fd_copy = os.dup(fd) + os.dup2(redirect_w, fd) + self._redirect_w = redirect_w + + t = threading.Thread(target=self._run_redirect_loop, args=(redirect_r, prefix)) + t.daemon = True + t.start() + + +@dataclass +class Worker(): + context: Dict[Any, Any] + event_queue: Queue + response_queue: Queue + stderr_redirect: FDRedirect + stdout_redirect: FDRedirect + + +def init_worker(index_queue, event_queue, response_queue, init_func): + global worker + worker_num = index_queue.get() + + # check if the init function takes in a worker number + if len(inspect.signature(init_func).parameters) == 0: + context = init_func() + else: + context = init_func(worker_num) + + event_queue.put((StatusEvent.WORKER_STARTED, worker_num)) + + worker = Worker( + context, + event_queue, + response_queue, + FDRedirect(1), + FDRedirect(2) + ) + +def run_worker(func, request, internal_id, use_response=False): + assert worker is not None, "worker is not initialized" + + worker.stderr_redirect.set_prefix(f"[requestID {request.id}] ") + worker.stdout_redirect.set_prefix(f"[requestID {request.id}] ") + + resp = None + worker.event_queue.put((StatusEvent.INFERENCE_START, internal_id)) + + try: + resp = func(worker.context, request) + except: + tb_str = traceback.format_exc() + print(colored(tb_str, "red")) + resp = Response( + status=500, + body=tb_str.encode("utf-8"), + headers={ + "Content-Type": "text/plain" + } + ) + + if use_response: + generator = None + stream_id = None + if inspect.isgenerator(resp.body): + stream_id = 'stream-' + internal_id + generator = resp.body + resp.body = None + print("has generator: ", generator is not None) + worker.response_queue.put((internal_id, (resp, stream_id))) + + # if the response is a generator, we need to iterate through it + if stream_id: + assert generator is not None + for chunk in generator: + worker.response_queue.put((stream_id, chunk)) + worker.response_queue.put((stream_id, None)) + + + + worker.event_queue.put((StatusEvent.INFERENCE_END, internal_id)) + diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 7622b02..edecd06 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -8,6 +8,7 @@ def test_handler(): app = potassium.Potassium("my_app") + global init @app.init def init(): return {} @@ -114,6 +115,7 @@ def handler5(context: dict, request: potassium.Request) -> potassium.Response: def test_path_collision(paths): app = potassium.Potassium("my_app") + global init @app.init def init(): return {} @@ -139,6 +141,7 @@ def test_status(): resolve_background_condition = threading.Condition() + global init @app.init def init(): return {} @@ -209,6 +212,7 @@ def test_wait_for_background_task(): order_of_execution_queue = queue.Queue() resolve_background_condition = threading.Condition() + global init @app.init def init(): return {} @@ -246,6 +250,8 @@ def wait_for_background_task(): def test_warmup(): app = potassium.Potassium("my_app") + + global init @app.init def init(): return {}