From 9865a7589196104ef941d8c9ef7e09f2d1b33536 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Thu, 9 Jan 2025 21:42:30 +0000 Subject: [PATCH] Enable inference serving capabilities on sagemaker endpoint using tornado --- template/v3/Dockerfile | 2 +- .../sagemaker-inference-server/__init__.py | 3 + .../dirs/etc/sagemaker-inference-server/serve | 2 + .../etc/sagemaker-inference-server/serve.py | 6 + .../tornado_server/__init__.py | 12 ++ .../tornado_server/async_handler.py | 76 +++++++++++ .../tornado_server/server.py | 124 ++++++++++++++++++ .../tornado_server/stream_handler.py | 54 ++++++++ .../tornado_server/sync_handler.py | 78 +++++++++++ .../utils/__init__.py | 1 + .../utils/environment.py | 59 +++++++++ .../utils/exception.py | 21 +++ .../utils/logger.py | 43 ++++++ 13 files changed, 480 insertions(+), 1 deletion(-) create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/__init__.py create mode 100755 template/v3/dirs/etc/sagemaker-inference-server/serve create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/serve.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/tornado_server/__init__.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/tornado_server/async_handler.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/tornado_server/server.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/tornado_server/stream_handler.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/tornado_server/sync_handler.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/utils/__init__.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/utils/exception.py create mode 100644 template/v3/dirs/etc/sagemaker-inference-server/utils/logger.py diff --git a/template/v3/Dockerfile b/template/v3/Dockerfile index 4cfb9e30..ec510dc1 
100644 --- a/template/v3/Dockerfile +++ b/template/v3/Dockerfile @@ -190,7 +190,7 @@ RUN mkdir -p $SAGEMAKER_LOGGING_DIR && \ && ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python \ && rm -rf ${HOME_DIR}/oss_compliance* -ENV PATH="/opt/conda/bin:/opt/conda/condabin:$PATH" +ENV PATH="/etc/sagemaker-inference-server:/opt/conda/bin:/opt/conda/condabin:$PATH" WORKDIR "/home/${NB_USER}" ENV SHELL=/bin/bash ENV OPENSSL_MODULES=/opt/conda/lib64/ossl-modules/ diff --git a/template/v3/dirs/etc/sagemaker-inference-server/__init__.py b/template/v3/dirs/etc/sagemaker-inference-server/__init__.py new file mode 100644 index 00000000..0427e383 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import + +import utils.logger diff --git a/template/v3/dirs/etc/sagemaker-inference-server/serve b/template/v3/dirs/etc/sagemaker-inference-server/serve new file mode 100755 index 00000000..bd604df3 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/serve @@ -0,0 +1,2 @@ +#!/bin/bash +python /etc/sagemaker-inference-server/serve.py diff --git a/template/v3/dirs/etc/sagemaker-inference-server/serve.py b/template/v3/dirs/etc/sagemaker-inference-server/serve.py new file mode 100644 index 00000000..d45cf256 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/serve.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import + +from tornado_server.server import TornadoServer + +inference_server = TornadoServer() +inference_server.serve() diff --git a/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/__init__.py b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/__init__.py new file mode 100644 index 00000000..28b0e2cc --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/__init__.py @@ -0,0 +1,12 @@ +from __future__ import absolute_import + +import pathlib +import sys + +# make the utils modules accessible to 
modules from within the tornado_server folder +utils_path = pathlib.Path(__file__).parent.parent / "utils" +sys.path.insert(0, str(utils_path.resolve())) + +# make the tornado_server modules accessible to each other +tornado_module_path = pathlib.Path(__file__).parent +sys.path.insert(0, str(tornado_module_path.resolve())) diff --git a/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/async_handler.py b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/async_handler.py new file mode 100644 index 00000000..e67ec277 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/async_handler.py @@ -0,0 +1,76 @@ +from __future__ import absolute_import + +import asyncio +import logging +from typing import AsyncGenerator, Generator + +import tornado.web +from stream_handler import StreamHandler + +from utils.environment import Environment +from utils.exception import AsyncInvocationsException +from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER + +logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) + + +class InvocationsHandler(tornado.web.RequestHandler, StreamHandler): + """Handler mapped to the /invocations POST route. + + This handler wraps the async handler retrieved from the inference script + and encapsulates it behind the post() method. The post() method is done + asynchronously. 
+ """ + + def initialize(self, handler: callable, environment: Environment): + """Initializes the handler function and the serving environment.""" + + self._handler = handler + self._environment = environment + + async def post(self): + """POST method used to encapsulate and invoke the async handle method asynchronously""" + + try: + response = await self._handler(self.request) + + if isinstance(response, Generator): + await self.stream(response) + elif isinstance(response, AsyncGenerator): + await self.astream(response) + else: + self.write(response) + except Exception as e: + raise AsyncInvocationsException(e) + + +class PingHandler(tornado.web.RequestHandler): + """Handler mapped to the /ping GET route. + + Ping handler to monitor the health of the Tornados server. + """ + + def get(self): + """Simple GET method to assess the health of the server.""" + + self.write("") + + +async def handle(handler: callable, environment: Environment): + """Serves the async handler function using Tornado. + + Opens the /invocations and /ping routes used by a SageMaker Endpoint + for inference serving capabilities. 
+ """ + + logger.info("Starting inference server in asynchronous mode...") + + app = tornado.web.Application( + [ + (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)), + (r"/ping", PingHandler), + ] + ) + app.listen(environment.port) + logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`") + await asyncio.Event().wait() diff --git a/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/server.py b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/server.py new file mode 100644 index 00000000..69310038 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/server.py @@ -0,0 +1,124 @@ +from __future__ import absolute_import + +import asyncio +import importlib +import logging +import subprocess +import sys +from pathlib import Path + +from utils.environment import Environment +from utils.exception import ( + InferenceCodeLoadException, + RequirementsInstallException, + ServerStartException, +) +from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER + +logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) + + +class TornadoServer: + """Holds serving logic using the Tornado framework. + + The serve.py script will invoke TornadoServer.serve() to start the serving process. + The TornadoServer will install the runtime requirements specified through a requirements file. + It will then load an handler function within an inference script and then front it will an /invocations + route using the Tornado framework. + """ + + def __init__(self): + """Initialize the serving behaviors. + + Defines the serving behavior through Environment() and locate where + the inference code is contained. 
+ """ + + self._environment = Environment() + logger.setLevel(self._environment.logging_level) + logger.debug(f"Environment: {str(self._environment)}") + + self._path_to_inference_code = ( + Path(self._environment.base_directory).joinpath(self._environment.code_directory) + if self._environment.code_directory + else Path(self._environment.base_directory) + ) + logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`") + + self._handler = None + + def initialize(self): + """Initialize the serving artifacts and dependencies. + + Install the runtime requirements and then locate the handler function from + the inference script. + """ + + logger.info("Initializing inference server...") + self._install_runtime_requirements() + self._handler = self._load_inference_handler() + + def serve(self): + """Orchestrate the initialization and server startup behavior. + + Call the initalize() method, determine the right Tornado serving behavior (async or sync), + and then start the Tornado server through asyncio + """ + + logger.info("Serving inference requests using Tornado...") + self.initialize() + + if asyncio.iscoroutinefunction(self._handler): + import async_handler as inference_handler + else: + import sync_handler as inference_handler + + try: + asyncio.run(inference_handler.handle(self._handler, self._environment)) + except Exception as e: + raise ServerStartException(e) + + def _install_runtime_requirements(self): + """Install the runtime requirements.""" + + logger.info("Installing runtime requirements...") + requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements) + if requirements_txt.is_file(): + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", str(requirements_txt)]) + except Exception as e: + raise RequirementsInstallException(e) + else: + logger.debug(f"No requirements file was found at `{str(requirements_txt)}`") + + def _load_inference_handler(self) -> callable: + """Load the handler 
function from the inference script.""" + + logger.info("Loading inference handler...") + inference_module_name, _, handle_name = self._environment.code.partition(".") + if inference_module_name and handle_name: + inference_module_file = f"{inference_module_name}.py" + module_spec = importlib.util.spec_from_file_location( + inference_module_file, str(self._path_to_inference_code.joinpath(inference_module_file)) + ) + if module_spec: + sys.path.insert(0, str(self._path_to_inference_code.resolve())) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + + if hasattr(module, handle_name): + handler = getattr(module, handle_name) + else: + logger.info(dir(module)) + raise InferenceCodeLoadException( + f"Handler `{handle_name}` could not be found in module `{inference_module_file}`" + ) + logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`") + return handler + else: + raise InferenceCodeLoadException( + f"Inference code could not be found at `{str(self._path_to_inference_code.joinpath(inference_module_file))}`" + ) + raise InferenceCodeLoadException( + f"Inference code expected in the format of `.` but was provided as {self._environment.code}" + ) diff --git a/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/stream_handler.py b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/stream_handler.py new file mode 100644 index 00000000..c3c9ea4b --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/stream_handler.py @@ -0,0 +1,59 @@ +from __future__ import absolute_import + +import logging +from typing import AsyncGenerator, Generator + +from tornado.ioloop import IOLoop + +from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER + +logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) + + +class StreamHandler: + """Mixin that enables async and sync streaming capabilities to the async and sync handlers + + stream() runs a provided generator fn in an async manner. + astream() runs a provided async generator fn in an async manner. 
+ """ + + async def stream(self, generator: Generator): + """Streams the response from a sync response generator + + A sync generator must be manually iterated through asynchronously. + In a loop, iterate through each next(generator) call in an async execution. + """ + + self._set_stream_headers() + + while True: + try: + chunk = await IOLoop.current().run_in_executor(None, next, generator) + # Some generators do not throw a StopIteration upon exhaustion. + # Instead, they return an empty response. Account for this case. + if not chunk: + raise StopIteration() + + self.write(chunk) + await self.flush() + except StopIteration: + break + except Exception as e: + logger.error("Unexpected exception occurred when streaming response...") + break + + async def astream(self, agenerator: AsyncGenerator): + """Streams the response from an async response generator""" + + self._set_stream_headers() + + async for chunk in agenerator: + self.write(chunk) + await self.flush() + + def _set_stream_headers(self): + """Set the headers in preparation for the streamed response""" + + self.set_header("Content-Type", "text/event-stream") + self.set_header("Cache-Control", "no-cache") + self.set_header("Connection", "keep-alive") diff --git a/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/sync_handler.py b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/sync_handler.py new file mode 100644 index 00000000..ff69dce5 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/tornado_server/sync_handler.py @@ -0,0 +1,78 @@ +from __future__ import absolute_import + +import asyncio +import logging +from typing import AsyncGenerator, Generator + +import tornado.web +from stream_handler import StreamHandler +from tornado.ioloop import IOLoop + +from utils.environment import Environment +from utils.exception import SyncInvocationsException +from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER + +logger = 
logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) + + +class InvocationsHandler(tornado.web.RequestHandler, StreamHandler): + """Handler mapped to the /invocations POST route. + + This handler wraps the sync handler retrieved from the inference script + and encapsulates it behind the post() method. The post() method is done + asynchronously. + """ + + def initialize(self, handler: callable, environment: Environment): + """Initializes the handler function and the serving environment.""" + + self._handler = handler + self._environment = environment + + async def post(self): + """POST method used to encapsulate and invoke the sync handle method asynchronously""" + + try: + response = await IOLoop.current().run_in_executor(None, self._handler, self.request) + + if isinstance(response, Generator): + logger.warning("Sync handler returned a generator; streaming the response.") + await self.stream(response) + elif isinstance(response, AsyncGenerator): + await self.astream(response) + else: + self.write(response) + except Exception as e: + raise SyncInvocationsException(e) + + +class PingHandler(tornado.web.RequestHandler): + """Handler mapped to the /ping GET route. + + Ping handler to monitor the health of the Tornado server. + """ + + def get(self): + """Simple GET method to assess the health of the server.""" + + self.write("") + + +async def handle(handler: callable, environment: Environment): + """Serves the sync handler function using Tornado. + + Opens the /invocations and /ping routes used by a SageMaker Endpoint + for inference serving capabilities. 
+ """ + + logger.info("Starting inference server in synchronous mode...") + + app = tornado.web.Application( + [ + (r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)), + (r"/ping", PingHandler), + ] + ) + app.listen(environment.port) + logger.debug(f"Synchronous inference server listening on port: `{environment.port}`") + await asyncio.Event().wait() diff --git a/template/v3/dirs/etc/sagemaker-inference-server/utils/__init__.py b/template/v3/dirs/etc/sagemaker-inference-server/utils/__init__.py new file mode 100644 index 00000000..c3961685 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/utils/__init__.py @@ -0,0 +1 @@ +from __future__ import absolute_import diff --git a/template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py b/template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py new file mode 100644 index 00000000..0cda0c09 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py @@ -0,0 +1,59 @@ +from __future__ import absolute_import + +import json +import os +from enum import Enum + + +class SageMakerInference(str, Enum): + """Simple enum to define the mapping between dictionary key and environement variable.""" + + BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY" + REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS" + CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY" + CODE = "SAGEMAKER_INFERENCE_CODE" + LOGGING_LEVEL = "SAGEMAKER_INFERENCE_LOGGING_LEVEL" + PORT = "SAGEMAKER_INFERENCE_PORT" + + +class Environment: + """Retrieves and encapsulates SAGEMAKER_INFERENCE prefixed environment variables.""" + + def __init__(self): + """Initialize the environment variable mapping""" + + self._environment_variables = { + SageMakerInference.BASE_DIRECTORY: "/opt/ml/model", + SageMakerInference.REQUIREMENTS: "requirements.txt", + SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None), + SageMakerInference.CODE: 
os.getenv(SageMakerInference.CODE, "inference.handler"), + SageMakerInference.LOGGING_LEVEL: int(os.getenv(SageMakerInference.LOGGING_LEVEL, 10)), + SageMakerInference.PORT: int(os.getenv(SageMakerInference.PORT, 8080)), + } + + def __str__(self): + return json.dumps(self._environment_variables) + + @property + def base_directory(self): + return self._environment_variables.get(SageMakerInference.BASE_DIRECTORY) + + @property + def requirements(self): + return self._environment_variables.get(SageMakerInference.REQUIREMENTS) + + @property + def code_directory(self): + return self._environment_variables.get(SageMakerInference.CODE_DIRECTORY) + + @property + def code(self): + return self._environment_variables.get(SageMakerInference.CODE) + + @property + def logging_level(self): + return self._environment_variables.get(SageMakerInference.LOGGING_LEVEL) + + @property + def port(self): + return self._environment_variables.get(SageMakerInference.PORT) diff --git a/template/v3/dirs/etc/sagemaker-inference-server/utils/exception.py b/template/v3/dirs/etc/sagemaker-inference-server/utils/exception.py new file mode 100644 index 00000000..eb961889 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/utils/exception.py @@ -0,0 +1,21 @@ +from __future__ import absolute_import + + +class RequirementsInstallException(Exception): + pass + + +class InferenceCodeLoadException(Exception): + pass + + +class ServerStartException(Exception): + pass + + +class SyncInvocationsException(Exception): + pass + + +class AsyncInvocationsException(Exception): + pass diff --git a/template/v3/dirs/etc/sagemaker-inference-server/utils/logger.py b/template/v3/dirs/etc/sagemaker-inference-server/utils/logger.py new file mode 100644 index 00000000..c8800868 --- /dev/null +++ b/template/v3/dirs/etc/sagemaker-inference-server/utils/logger.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import + +import logging.config + +SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER = 
"sagemaker_distribution.inference_server" +LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": True, + "formatters": { + "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"}, + }, + "handlers": { + "default": { + "level": "DEBUG", + "formatter": "standard", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER: { + "level": "DEBUG", + "handlers": ["default"], + "propagate": True, + }, + "tornado.application": { + "level": "DEBUG", + "handlers": ["default"], + "propagate": True, + }, + "tornado.general": { + "level": "DEBUG", + "handlers": ["default"], + "propagate": True, + }, + "tornado.access": { + "level": "DEBUG", + "handlers": ["default"], + "propagate": True, + }, + }, +} +logging.config.dictConfig(LOGGING_CONFIG)