diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index d7acc71a99..7273b61406 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -17,24 +17,26 @@ """ +import logging import os import re from pathlib import Path from tempfile import NamedTemporaryFile -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy import onnx +import transformers from onnx import ModelProto from deepsparse.log import get_main_logger -from deepsparse.utils.onnx import _MODEL_DIR_ONNX_NAME, truncate_onnx_model +from deepsparse.utils.onnx import MODEL_ONNX_NAME, truncate_onnx_model from sparsezoo import Model from sparsezoo.utils import save_onnx __all__ = [ - "get_deployment_path", + "setup_transformers_pipeline", "overwrite_transformer_onnx_model_inputs", "fix_numpy_types", "get_transformer_layer_init_names", @@ -44,7 +46,94 @@ _LOGGER = get_main_logger() -def get_deployment_path(model_path: str) -> Tuple[str, str]: +def setup_transformers_pipeline( + model_path: str, + sequence_length: int, + tokenizer_padding_side: str = "left", + engine_kwargs: Optional[Dict] = None, + onnx_model_name: Optional[str] = None, +) -> Tuple[ + str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer, Dict[str, Any] +]: + """ + A helper function that sets up the model path, config, tokenizer, + and engine kwargs for a transformers model. + :param model_path: The path to the model to load + :param sequence_length: The sequence length to use for the model + :param tokenizer_padding_side: The side to pad on for the tokenizer, + either "left" or "right" + :param engine_kwargs: The kwargs to pass to the engine + :param onnx_model_name: The name of the onnx model to be loaded. + If not specified, defaults are used (see setup_onnx_file_path) + :return The model path, config, tokenizer, and engine kwargs + """ + model_path, config, tokenizer = setup_onnx_file_path( + model_path, sequence_length, onnx_model_name + ) + + tokenizer.padding_side = tokenizer_padding_side + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + engine_kwargs = engine_kwargs or {} + if engine_kwargs.get("model_path"): + raise ValueError( + "The engine kwargs already specify " + f"a model path: {engine_kwargs['model_path']}, " + f"but a model path was also provided: {model_path}. " + "Please only provide one." + ) + engine_kwargs["model_path"] = model_path + return model_path, config, tokenizer, engine_kwargs + + +def setup_onnx_file_path( + model_path: str, + sequence_length: int, + onnx_model_name: Optional[str] = None, + task: Optional[str] = None, +) -> Tuple[str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer]: + """ + Parses ONNX model from the `model_path` provided. It additionally + creates config and tokenizer objects from the `deployment path`, + derived from the `model_path` provided. + :param model_path: path to the model to be parsed + :param sequence_length: maximum sequence length of the model + :param onnx_model_name: optionally, the precise name of the ONNX model + of interest may be specified. 
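As a rough usage sketch of the new `setup_transformers_pipeline` helper defined above (the stub and settings below are placeholders, not values taken from this change):

```python
from deepsparse.transformers.helpers import setup_transformers_pipeline

# placeholder stub and sequence length, for illustration only
model_path, config, tokenizer, engine_kwargs = setup_transformers_pipeline(
    model_path="zoo:example/opt-stub",
    sequence_length=128,
    tokenizer_padding_side="left",
)
# engine_kwargs now carries "model_path" and can be forwarded to an engine
```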
If not specified, the default ONNX model + name will be used (refer to `get_deployment_path` for details) + :return: file path to the processed ONNX file for the engine to compile + """ + deployment_path, onnx_path = get_deployment_path(model_path, onnx_model_name) + + hf_logger = logging.getLogger("transformers") + hf_logger_level = hf_logger.level + hf_logger.setLevel(logging.ERROR) + + config = transformers.PretrainedConfig.from_pretrained( + deployment_path, finetuning_task=task + ) + hf_logger.setLevel(hf_logger_level) + + trust_remote_code = False + tokenizer = transformers.AutoTokenizer.from_pretrained( + deployment_path, + trust_remote_code=trust_remote_code, + model_max_length=sequence_length, + ) + + if not config or not tokenizer: + raise RuntimeError( + "Invalid config or tokenizer provided. Please provide " + "paths to the files or ensure they exist in the `model_path` provided. " + "See `tokenizer` and `config` arguments for details." + ) + return onnx_path, config, tokenizer + + +def get_deployment_path( + model_path: str, onnx_model_name: Optional[str] = None +) -> Tuple[str, str]: """ Returns the path to the deployment directory for the given model path and the path to the mandatory @@ -53,9 +142,12 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: for running the transformers model in the deepsparse pipeline :param model_path: path to model directory, sparsezoo stub, or ONNX file + :param onnx_model_name: name of the ONNX file to look for in the deployment + directory. Defaults to MODEL_ONNX_NAME :return: path to the deployment directory and path to the ONNX file inside the deployment directory """ + onnx_model_name = onnx_model_name or MODEL_ONNX_NAME if os.path.isfile(model_path): # return the parent directory of the ONNX file return os.path.dirname(model_path), model_path @@ -63,26 +155,26 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: if os.path.isdir(model_path): model_files = os.listdir(model_path) - if _MODEL_DIR_ONNX_NAME not in model_files: + if onnx_model_name not in model_files: raise ValueError( - f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory " + f"{onnx_model_name} not found in transformers model directory " f"{model_path}. Be sure that an export of the model is written to " - f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}" + f"{os.path.join(model_path, onnx_model_name)}" ) - return model_path, os.path.join(model_path, _MODEL_DIR_ONNX_NAME) + return model_path, os.path.join(model_path, onnx_model_name) elif model_path.startswith("zoo:"): zoo_model = Model(model_path) deployment_path = zoo_model.deployment_directory_path - return deployment_path, os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) + return deployment_path, os.path.join(deployment_path, onnx_model_name) elif model_path.startswith("hf:"): from huggingface_hub import snapshot_download deployment_path = snapshot_download(repo_id=model_path.replace("hf:", "", 1)) - onnx_path = os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) + onnx_path = os.path.join(deployment_path, onnx_model_name) if not os.path.isfile(onnx_path): raise ValueError( - f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory " + f"{onnx_model_name} not found in transformers model directory " f"{deployment_path}. 
Be sure that an export of the model is written to " f"{onnx_path}" ) diff --git a/src/deepsparse/transformers/pipelines/pipeline.py b/src/deepsparse/transformers/pipelines/pipeline.py index 065a26ce71..ac54c4a3db 100644 --- a/src/deepsparse/transformers/pipelines/pipeline.py +++ b/src/deepsparse/transformers/pipelines/pipeline.py @@ -16,19 +16,18 @@ Base Pipeline class for transformers inference pipeline """ -import logging + import warnings from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Union import numpy import transformers -from transformers.models.auto import AutoTokenizer from deepsparse import Bucketable, Pipeline +from deepsparse.transformers.helpers import overwrite_transformer_onnx_model_inputs from deepsparse.transformers.helpers import ( - get_deployment_path, - overwrite_transformer_onnx_model_inputs, + setup_onnx_file_path as setup_onnx_file_path_v2, ) @@ -124,24 +123,15 @@ def setup_onnx_file_path(self) -> str: :return: file path to the processed ONNX file for the engine to compile """ - deployment_path, onnx_path = get_deployment_path(self.model_path) - - # temporarily set transformers logger to ERROR to avoid - # printing misleading warnings - hf_logger = logging.getLogger("transformers") - hf_logger_level = hf_logger.level - hf_logger.setLevel(logging.ERROR) - self.config = transformers.PretrainedConfig.from_pretrained( - deployment_path, - finetuning_task=self.task if hasattr(self, "task") else None, - ) - hf_logger.setLevel(hf_logger_level) - - self.tokenizer = AutoTokenizer.from_pretrained( - deployment_path, - trust_remote_code=self._trust_remote_code, - model_max_length=self.sequence_length, + # we will be soon retiring V1 pipelines. This is why I am deciding + # to reuse the functions from V2 pipelines in the (soon) legacy pipelines + onnx_path, config, tokenizer = setup_onnx_file_path_v2( + model_path=self.model_path, + sequence_length=self.sequence_length, + task=self.task if hasattr(self, "task") else None, ) + self.config = config + self.tokenizer = tokenizer if not self._delay_overwriting_inputs: # overwrite onnx graph to given required input shape @@ -153,12 +143,6 @@ def setup_onnx_file_path(self) -> str: onnx_path, max_length=self.sequence_length ) - if not self.config or not self.tokenizer: - raise RuntimeError( - "Invalid config or tokenizer provided. Please provide " - "paths to the files or ensure they exist in the `model_path` provided. " - "See `tokenizer` and `config` arguments for details." - ) return onnx_path def tokens_to_engine_input( diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index 38e3ec4a4c..648bdef9cf 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -14,7 +14,7 @@ import logging import pathlib import uuid -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy from transformers import AutoTokenizer, GenerationConfig @@ -33,6 +33,7 @@ "override_config", "process_generation_config", "validate_session_ids", + "compute_engine_inputs", "set_generated_length", ] @@ -82,6 +83,95 @@ def set_generated_length( ) +def compute_engine_inputs(onnx_input_names: str, **kwargs) -> List[numpy.ndarray]: + """ + Given the names of the onnx inputs, compute the inputs + to the engine. The inputs will be calculating from the + passed kwargs. 
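A hedged sketch of how `compute_engine_inputs` might be called during prompt processing; the token ids and lengths here are made up:

```python
engine_inputs = compute_engine_inputs(
    onnx_input_names=["input_ids", "attention_mask", "positions", "causal_mask"],
    token_batch=[101, 2057, 2024],   # placeholder token ids
    sequence_length=128,
    prompt_sequence_length=3,
    num_total_processed_tokens=0,
)
# -> [input_ids, attention_mask, positions, causal_mask] as numpy arrays
```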
The information about the required kwargs + can be found in the docstring of the individual compute + functions. + + :param onnx_input_names: The names of the onnx inputs + :param kwargs: The kwargs to compute the inputs from + :return: The computed inputs to the engine + """ + engine_inputs = [] + for input_name in onnx_input_names: + if input_name == "causal_mask": + # delay the computation of the causal mask + continue + # fetch the compute function for the + # given input_name + compute_func = _get_compute_func(input_name) + # compute the engine input from the kwargs + # and append it to the engine_inputs + engine_inputs.append(compute_func(**kwargs)) + + if "causal_mask" in onnx_input_names: + # compute the causal mask and append it to the engine_inputs + input_ids, attention_mask, *_ = engine_inputs + engine_inputs.append(create_causal_mask(input_ids, attention_mask)) + + return engine_inputs + + +def _get_compute_func(input_name: str) -> Callable[..., numpy.ndarray]: + # given the input_name, return the appropriate compute function + compute_func = { + "input_ids": _compute_input_ids, + "attention_mask": _compute_attention_mask, + "positions": _compute_positions, + }.get(input_name) + if compute_func is None: + raise ValueError( + "Could not find compute function " f"for the input_name: {input_name}" + ) + return compute_func + + +def _compute_input_ids(token_batch: List[int], **kwargs) -> numpy.ndarray: + # convert the token_batch to a numpy array + return numpy.array([token_batch]) + + +def _compute_attention_mask( + sequence_length: int, + prompt_sequence_length: int, + num_total_processed_tokens: int, + **kwargs, +) -> numpy.ndarray: + # create a fully masked attention mask with the appropriate + # shape (equal to the sequence_length) + attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64) + # unmask the appropriate number of tokens, the sum of + # - the number of tokens already processed and cached (num_total_processed_tokens) + # - the number of tokens currently processed (prompt_sequence_length) + # the sum cannot exceed the maximum length of the attention_mask + num_attention_entries_to_unmask = min( + num_total_processed_tokens + prompt_sequence_length, sequence_length + ) + # unmask the bits from the right-hand side + attention_mask[:, -num_attention_entries_to_unmask:] = 1 + return attention_mask + + +def _compute_positions( + num_total_processed_tokens: int, prompt_sequence_length: int, **kwargs +): + # create the positions array with the appropriate shape + # positions count starts from the number of tokens already processed + # and ends at the number of tokens already processed + the number of tokens + # currently processed + return ( + numpy.arange( + num_total_processed_tokens, + num_total_processed_tokens + prompt_sequence_length, + ) + .reshape(1, -1) + .astype(numpy.int64) + ) + + def validate_session_ids( session_ids: Optional[str], other_attributes: Dict[str, Any] ) -> Optional[List[str]]: diff --git a/src/deepsparse/transformers/utils/token_generator.py b/src/deepsparse/transformers/utils/token_generator.py index 5fa82b7bc4..0421da06e2 100644 --- a/src/deepsparse/transformers/utils/token_generator.py +++ b/src/deepsparse/transformers/utils/token_generator.py @@ -77,16 +77,17 @@ def generate(self, logits: numpy.ndarray) -> numpy.ndarray: :param logits: the logits from the model with shape (vocab_size,) :return: the sampled token """ - if self.top_k: - logits = self.apply_top_k(logits) - if self.top_p: - logits = self.apply_top_p(logits) - if 
self.deterministic: token = numpy.argmax(logits) self.tokens.append(token) return token + if self.top_k: + logits = self.apply_top_k(logits) + + if self.top_p: + logits = self.apply_top_p(logits) + if self.sampling_temperature != 1.0: logits /= self.sampling_temperature diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index e69bf67321..f518620c2f 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -56,12 +56,12 @@ "has_model_kv_cache", "CACHE_INPUT_PREFIX", "CACHE_OUTPUT_PREFIX", - "_MODEL_DIR_ONNX_NAME", + "MODEL_ONNX_NAME", ] _LOGGER = logging.getLogger(__name__) -_MODEL_DIR_ONNX_NAME = "model.onnx" +MODEL_ONNX_NAME = "model.onnx" CACHE_INPUT_PREFIX = "past_key_values" CACHE_OUTPUT_PREFIX = "present" @@ -132,7 +132,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: model.deployment_directory_path # default to the main onnx file for the model - model = model.deployment.get_file(_MODEL_DIR_ONNX_NAME).path + model = model.deployment.get_file(MODEL_ONNX_NAME).path elif File is not object and isinstance(model, File): # get the downloaded_path -- will auto download if not on local system @@ -146,7 +146,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: model_path = Path(model) if model_path.is_dir(): - return str(model_path / _MODEL_DIR_ONNX_NAME) + return str(model_path / MODEL_ONNX_NAME) return model diff --git a/src/deepsparse/v2/__init__.py b/src/deepsparse/v2/__init__.py new file mode 100644 index 0000000000..5fd33a9503 --- /dev/null +++ b/src/deepsparse/v2/__init__.py @@ -0,0 +1,22 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .operators import * +from .pipeline import * +from .routers import * +from .schedulers import * +from .task import * +from .utils import * diff --git a/src/deepsparse/v2/image_classification/__init__.py b/src/deepsparse/v2/image_classification/__init__.py new file mode 100644 index 0000000000..8668227df7 --- /dev/null +++ b/src/deepsparse/v2/image_classification/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# flake8: noqa +from .postprocess_operator import * +from .preprocess_operator import * + + +from .pipeline import * # isort:skip diff --git a/src/deepsparse/v2/image_classification/pipeline.py b/src/deepsparse/v2/image_classification/pipeline.py new file mode 100644 index 0000000000..3d7887a701 --- /dev/null +++ b/src/deepsparse/v2/image_classification/pipeline.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import warnings +from typing import Dict, Optional, Tuple, Union + +from deepsparse.v2.image_classification.postprocess_operator import ( + ImageClassificationPostProcess, +) +from deepsparse.v2.image_classification.preprocess_operator import ( + ImageClassificationPreProcess, +) +from deepsparse.v2.operators.engine_operator import EngineOperator +from deepsparse.v2.pipeline import Pipeline +from deepsparse.v2.routers.router import LinearRouter +from deepsparse.v2.schedulers.scheduler import OperatorScheduler + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["ImageClassificationPipeline"] + + +class ImageClassificationPipeline(Pipeline): + def __init__( + self, + model_path: str, + engine_kwargs: Optional[Dict] = None, + class_names: Union[None, str, Dict[str, str]] = None, + image_size: Optional[Tuple[int]] = None, + top_k: int = 1, + ): + if not engine_kwargs: + engine_kwargs = {} + engine_kwargs["model_path"] = model_path + elif engine_kwargs.get("model_path") != model_path: + warnings.warn(f"Updating engine_kwargs to include {model_path}") + + engine = EngineOperator(**engine_kwargs) + preproces = ImageClassificationPreProcess( + model_path=engine.model_path, image_size=image_size + ) + postprocess = ImageClassificationPostProcess( + top_k=top_k, class_names=class_names + ) + + ops = [preproces, engine, postprocess] + router = LinearRouter(end_route=len(ops)) + scheduler = [OperatorScheduler()] + super().__init__(ops=ops, router=router, schedulers=scheduler) diff --git a/src/deepsparse/v2/image_classification/postprocess_operator.py b/src/deepsparse/v2/image_classification/postprocess_operator.py new file mode 100644 index 0000000000..9231113368 --- /dev/null +++ b/src/deepsparse/v2/image_classification/postprocess_operator.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from typing import Dict, List, Union + +import numpy +from pydantic import BaseModel, Field + +from deepsparse.v2.operators import Operator + + +class ImageClassificationOutput(BaseModel): + """ + Output model for image classification + """ + + labels: List[Union[int, str, List[int], List[str]]] = Field( + description="List of labels, one for each prediction" + ) + scores: List[Union[float, List[float]]] = Field( + description="List of scores, one for each prediction" + ) + + +__all__ = ["ImageClassificationPostProcess"] + + +class ImageClassificationPostProcess(Operator): + """ + Image Classification post-processing Operator. This Operator is responsible for + processing outputs from the engine and returning the classification results to + the user, using the ImageClassifcationOutput structure. + """ + + input_schema = None + output_schema = ImageClassificationOutput + + def __init__( + self, top_k: int = 1, class_names: Union[None, str, Dict[str, str]] = None + ): + self.top_k = top_k + if isinstance(class_names, str) and class_names.endswith(".json"): + self._class_names = json.load(open(class_names)) + elif isinstance(class_names, dict): + self._class_names = class_names + else: + self._class_names = None + + def run(self, inp: "EngineOperatorOutputs", **kwargs) -> Dict: # noqa: F821 + labels, scores = [], [] + inp = inp.engine_outputs + for prediction_batch in inp[0]: + label = (-prediction_batch).argsort()[: self.top_k] + score = prediction_batch[label] + labels.append(label) + scores.append(score.tolist()) + + if self._class_names is not None: + labels = numpy.vectorize(self._class_names.__getitem__)(labels) + labels = labels.tolist() + + if isinstance(labels[0], numpy.ndarray): + labels = [label.tolist() for label in labels] + + if len(labels) == 1: + labels = labels[0] + scores = scores[0] + + return {"scores": scores, "labels": labels} diff --git a/src/deepsparse/v2/image_classification/preprocess_operator.py b/src/deepsparse/v2/image_classification/preprocess_operator.py new file mode 100644 index 0000000000..9b4517a44c --- /dev/null +++ b/src/deepsparse/v2/image_classification/preprocess_operator.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple + +import numpy +import onnx +from PIL import Image +from torchvision import transforms + +from deepsparse.image_classification.constants import ( + IMAGENET_RGB_MEANS, + IMAGENET_RGB_STDS, +) +from deepsparse.pipelines.computer_vision import ComputerVisionSchema +from deepsparse.v2.operators import Operator + + +class ImageClassificationInput(ComputerVisionSchema): + """ + Input model for image classification + """ + + +__all__ = ["ImageClassificationPreProcess"] + + +class ImageClassificationPreProcess(Operator): + """ + Image Classification pre-processing operator. This Operator is expected to process + the user inputs and prepare them for the engine. 
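Combined with the pipeline and post-processing operators added earlier in this diff, a hypothetical end-to-end call could look like the following (the model and image paths are placeholders):

```python
from deepsparse.v2.image_classification import ImageClassificationPipeline

pipeline = ImageClassificationPipeline(model_path="model.onnx", top_k=5)
output = pipeline(images=["sample.jpg"])
print(output.labels, output.scores)
```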
Inputs to this Operator are + expected to follow the ImageClassificationInput schema. + """ + + input_schema = ImageClassificationInput + output_schema = None + + def __init__(self, model_path: str, image_size: Optional[Tuple[int]] = None): + self.model_path = model_path + self._image_size = image_size or self._infer_image_size() + non_rand_resize_scale = 256.0 / 224.0 # standard used + self._pre_normalization_transforms = transforms.Compose( + [ + transforms.Resize( + tuple( + [ + round(non_rand_resize_scale * size) + for size in self._image_size + ] + ) + ), + transforms.CenterCrop(self._image_size), + ] + ) + + def run(self, inp: ImageClassificationInput, **kwargs) -> Dict: + """ + Pre-Process the Inputs for DeepSparse Engine + + :param inputs: input model + :return: list of preprocessed numpy arrays + """ + + if isinstance(inp.images, numpy.ndarray): + image_batch = inp.images + else: + if isinstance(inp.images, str): + inp.images = [inp.images] + + image_batch = list(map(self._preprocess_image, inp.images)) + + # build batch + image_batch = numpy.stack(image_batch, axis=0) + + original_dtype = image_batch.dtype + image_batch = numpy.ascontiguousarray(image_batch, dtype=numpy.float32) + + if original_dtype == numpy.uint8: + image_batch /= 255 + # normalize entire batch + image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1)) + image_batch /= numpy.asarray(IMAGENET_RGB_STDS).reshape((-1, 3, 1, 1)) + + return {"engine_inputs": [image_batch]} + + def _preprocess_image(self, image) -> numpy.ndarray: + if isinstance(image, List): + # image given as raw list + image = numpy.asarray(image) + if image.dtype == numpy.float32: + # image is already processed, append and continue + return image + # assume raw image input + # put image in PIL format for torchvision processing + image = image.astype(numpy.uint8) + if image.shape[0] < image.shape[-1]: + # put channel last + image = numpy.einsum("cwh->whc", image) + image = Image.fromarray(image) + elif isinstance(image, str): + # load image from string filepath + image = Image.open(image).convert("RGB") + elif isinstance(image, numpy.ndarray): + image = image.astype(numpy.uint8) + if image.shape[0] < image.shape[-1]: + # put channel last + image = numpy.einsum("cwh->whc", image) + image = Image.fromarray(image) + + if not isinstance(image, Image.Image): + raise ValueError( + f"inputs to {self.__class__.__name__} must be a string image " + "file path(s), a list representing a raw image, " + "PIL.Image.Image object(s), or a numpy array representing" + f"the entire pre-processed batch. 
Found {type(image)}" + ) + + # apply resize and center crop + image = self._pre_normalization_transforms(image) + image_numpy = numpy.array(image) + image.close() + + # make channel first dimension + image_numpy = image_numpy.transpose(2, 0, 1) + return image_numpy + + def _infer_image_size(self) -> Tuple[int, ...]: + """ + Infer and return the expected shape of the input tensor + + :return: The expected shape of the input tensor from onnx graph + """ + onnx_model = onnx.load(self.model_path) + input_tensor = onnx_model.graph.input[0] + return ( + input_tensor.type.tensor_type.shape.dim[2].dim_value, + input_tensor.type.tensor_type.shape.dim[3].dim_value, + ) diff --git a/src/deepsparse/v2/operators/__init__.py b/src/deepsparse/v2/operators/__init__.py new file mode 100644 index 0000000000..ae14f2a373 --- /dev/null +++ b/src/deepsparse/v2/operators/__init__.py @@ -0,0 +1,19 @@ +# flake8: noqa +# isort: skip_file + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .operator import * +from .engine_operator import * +from .registry import * diff --git a/src/deepsparse/v2/operators/engine_operator.py b/src/deepsparse/v2/operators/engine_operator.py new file mode 100644 index 0000000000..630de2d5bd --- /dev/null +++ b/src/deepsparse/v2/operators/engine_operator.py @@ -0,0 +1,184 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from copy import deepcopy +from typing import Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +from deepsparse import Context as EngineContext +from deepsparse import Engine, MultiModelEngine, Scheduler +from deepsparse.benchmark import ORTEngine +from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs +from deepsparse.v2.operators import Operator + + +DEEPSPARSE_ENGINE = "deepsparse" +ORT_ENGINE = "onnxruntime" + +SUPPORTED_PIPELINE_ENGINES = [DEEPSPARSE_ENGINE, ORT_ENGINE] + +__all__ = ["EngineOperator", "EngineOperatorInputs", "EngineOperatorOutputs"] + + +class EngineOperatorInputs(BaseModel): + engine_inputs: List = Field(description="engine_inputs") + engine: Optional[Union[ORTEngine, Engine]] = Field( + description="override the engine to run forward pass with", + default=None, + ) + + @classmethod + def join(cls, inputs: List["EngineOperatorInputs"]) -> "EngineOperatorInputs": + """ + :param inputs: list of separate EngineOperatorInputs, batch size must be 1 + :return: list of inputs joined into a single input with a multi batch size + """ + all_engine_inputs = [engine_input.engine_inputs for engine_input in inputs] + + for engine_inputs in all_engine_inputs: + if engine_inputs[0].shape[0] != 1: + raise RuntimeError( + "join requires all inputs to have batch size 1, found input with " + f"batch size {engine_inputs[0].shape[0]}" + ) + + # use join_engine_outputs since dtype is the same + joined_engine_inputs = join_engine_outputs( + all_engine_inputs, len(all_engine_inputs) + ) + + return cls(engine_inputs=joined_engine_inputs) + + class Config: + arbitrary_types_allowed = True + + +class EngineOperatorOutputs(BaseModel): + engine_outputs: List = Field(description="engine outputs") + + def split(self) -> List["EngineOperatorOutputs"]: + """ + :return: list of the current outputs split to a batch size of 1 each + """ + # using split_engine_inputs since input/output dtypes + # are the same (List[ndarray]) + split_outputs, _ = split_engine_inputs(self.engine_outputs, batch_size=1) + + return [self.__class__(engine_outputs=outputs) for outputs in split_outputs] + + +class EngineOperator(Operator): + input_schema = EngineOperatorInputs + output_schema = EngineOperatorOutputs + + def __init__( + self, + model_path: str, + engine_type: str = DEEPSPARSE_ENGINE, + num_cores: int = None, + num_streams: int = None, + scheduler: Scheduler = None, + input_shapes: List[List[int]] = None, + engine_context: Optional[EngineContext] = None, + engine_kwargs: Dict = None, + ): + self.model_path = model_to_path(model_path) + self.engine_context = engine_context + self._batch_size = 1 + + if self.engine_context is not None: + num_cores = num_cores or self.engine_context.num_cores + if self.engine_context.num_cores != num_cores: + raise ValueError( + f"num_cores mismatch. 
Expected {self.engine_context.num_cores} " + f"from passed context, but got {num_cores} while " + f"instantiating Pipeline" + ) + + engine_args = dict( + batch_size=self._batch_size, + num_cores=num_cores, + input_shapes=input_shapes, + ) + if engine_type.lower() == DEEPSPARSE_ENGINE: + engine_args["scheduler"] = scheduler + engine_args["num_streams"] = num_streams + + self._engine_args = engine_args + self._engine_type = engine_type + + if not engine_kwargs: + engine_kwargs = {} + + self.engine = self.create_engine(**engine_kwargs) + + @property + def batch_size(self) -> int: + """ + :return: the batch size this engine operator is compiled at + """ + return self._batch_size + + # TODO: maybe add a few args to make this less opaque? + def create_engine( + self, + **kwargs, + ) -> Union[Engine, MultiModelEngine, ORTEngine]: + """ + Create an inference engine for a given ONNX model + + :param kwargs: overrides to engine_args used as kwargs for engine + constructor/compilation + :return: inference engine + """ + + onnx_file_path = kwargs.pop("model_path", self.model_path) + engine_args = deepcopy(self._engine_args) + engine_args.update(kwargs) + engine_type = self._engine_type.lower() + + if engine_type == DEEPSPARSE_ENGINE: + if self.engine_context is not None and isinstance( + self.engine_context, EngineContext + ): + engine_args.pop("num_cores", None) + engine_args.pop("scheduler", None) + engine_args.pop("num_streams", None) + engine_args["context"] = self.engine_context + return MultiModelEngine( + model=onnx_file_path, + **engine_args, + ) + engine_args.pop("cache_output_bools", None) + return Engine(onnx_file_path, **engine_args) + + if engine_type == ORT_ENGINE: + return ORTEngine(onnx_file_path, **engine_args) + + raise ValueError( + f"Unknown engine_type {engine_type}. Supported values include: " + f"{SUPPORTED_PIPELINE_ENGINES}" + ) + + def run(self, inp: EngineOperatorInputs, **kwargs) -> Dict: + if inp.engine: + # run with custom engine, do not split/join since custom engine + # may run at any batch size, returning here as code below has a + # planned refactor + engine_outputs = inp.engine(inp.engine_inputs) + return {"engine_outputs": engine_outputs} + + engine_outputs = self.engine(inp.engine_inputs) + return {"engine_outputs": engine_outputs} diff --git a/src/deepsparse/v2/operators/operator.py b/src/deepsparse/v2/operators/operator.py new file mode 100644 index 0000000000..e775056f8f --- /dev/null +++ b/src/deepsparse/v2/operators/operator.py @@ -0,0 +1,135 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Optional, Type + +from pydantic import BaseModel + +from deepsparse.v2.operators.registry import OperatorRegistry +from deepsparse.v2.utils import InferenceState + + +__all__ = ["Operator"] + + +class Operator(ABC): + """ + Base operator class - an operator should be defined for each atomic, functional + part of the pipeline. 
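To make the contract concrete, a minimal hypothetical operator built on this base class might look as follows (the schemas and names are illustrative):

```python
from typing import Any, Dict

from pydantic import BaseModel

from deepsparse.v2.operators import Operator


class EchoInput(BaseModel):
    text: str


class EchoOutput(BaseModel):
    text: str


class EchoOperator(Operator):
    input_schema = EchoInput
    output_schema = EchoOutput

    def run(self, inp: EchoInput, **kwargs) -> Dict[str, Any]:
        # the returned dict is validated into output_schema by __call__
        return {"text": inp.text.upper()}
```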
+ """ + + # expected structured input and output types, to be defined by child classes + input_schema: Optional[Type[BaseModel]] = None + output_schema: Optional[Type[BaseModel]] = None + + @classmethod + def has_input_schema(cls) -> bool: + """ + :return: True if this class has a defined pydantic input schema + """ + if not cls.input_schema: + return False + + return issubclass(cls.input_schema, BaseModel) + + @classmethod + def has_output_schema(cls) -> bool: + """ + :return: True if this class has a defined pydantic input schema + """ + if not cls.output_schema: + return False + + return issubclass(cls.output_schema, BaseModel) + + def __call__( + self, + *args, + inference_state: InferenceState, + **kwargs, + ) -> Any: + """ + Parses inputs to this Operator and runs the run() method of this operator + + :param args: an unnamed arg may only be provided if it is of the type of the + input_schema + :param inference_state: inference_state for the pipeline. + :param pipeline_state: pipeline_state for the pipeline. The values in the state + are created during pipeline creation and are read-only during inference. + :param kwargs: kwargs when not initializing from an instantiated schema + :return: operator output + """ + if self.has_input_schema(): + if len(args) > 1: + raise ValueError( + f"The operator requires an {self.input_schema}. Too many arguments" + "provided." + ) + elif args and isinstance(args[0], self.input_schema): + inference_input = args[0] + elif kwargs: + inference_input = self.input_schema(**kwargs) + else: + raise ValueError( + "Can't resolve inputs. The values for the schema must be provided" + "in the form of a dictionary or an instance of the input_schema" + "object" + ) + run_output = self.run( + inference_input, + inference_state=inference_state, + ) + else: + run_output = self.run( + *args, + inference_state=inference_state, + **kwargs, + ) + if self.has_output_schema(): + return self.output_schema(**run_output) + return run_output + + @staticmethod + def create( + task: str, + **kwargs, + ) -> "Operator": + """ + :param task: Operator task + :param kwargs: extra task specific kwargs to be passed to task Operator + implementation + :return: operator object initialized for the given task + """ + operator_constructor = OperatorRegistry.get_task_constructor(task) + return operator_constructor(**kwargs) + + @abstractmethod + def run(self, *args, **kwargs) -> Any: + """ + :return: result of this operator as the defined output schema if applicable + """ + raise NotImplementedError + + def can_operate(self, inp: Any) -> bool: + """ + Whether or not the given operator can run, based on input + """ + return True + + def yaml(self): + pass + + def json(self): + pass diff --git a/src/deepsparse/v2/operators/registry.py b/src/deepsparse/v2/operators/registry.py new file mode 100644 index 0000000000..1b83b20728 --- /dev/null +++ b/src/deepsparse/v2/operators/registry.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Type + +from deepsparse.v2.task import SupportedTasks, dynamic_import_task +from sparsezoo.utils.registry import ( + RegistryMixin, + get_from_registry, + register, + registered_names, +) + + +__all__ = ["OperatorRegistry"] + + +class OperatorRegistry(RegistryMixin): + """ + Register operators with given task name(s). Leverages the RegistryMixin + functionality. + """ + + @classmethod + def register_value(cls, operator, name): + from deepsparse.v2.operators import Operator + + if not isinstance(name, list): + name = [name] + + for task_name in name: + register(Operator, operator, task_name, require_subclass=True) + + return operator + + @classmethod + def get_task_constructor(cls, task: str) -> Type["Operator"]: # noqa: F821 + """ + This function retrieves the class previously registered via + `OperatorRegistry.register` for `task`. + + If `task` starts with "import:", it is treated as a module to be imported, + and retrieves the task via the `TASK` attribute of the imported module. + + If `task` starts with "custom", then it is mapped to the "custom" task. + + :param task: The task name to get the constructor for + :return: The class registered to `task` + :raises ValueError: if `task` was not registered via `OperatorRegistry.register` + """ + from deepsparse.v2.operators import Operator + + if task.startswith("import:"): + # dynamically import the task from a file + task = dynamic_import_task(module_or_path=task.replace("import:", "")) + elif task.startswith("custom"): + # support any task that has "custom" at the beginning via the "custom" task + task = "custom" + else: + task = task.lower().replace("-", "_") + + tasks = registered_names(Operator) + # step needed to import relevant files required to load the operator + SupportedTasks.check_register_task(task, tasks) + return get_from_registry(Operator, task) diff --git a/src/deepsparse/v2/pipeline.py b/src/deepsparse/v2/pipeline.py new file mode 100644 index 0000000000..59970b2820 --- /dev/null +++ b/src/deepsparse/v2/pipeline.py @@ -0,0 +1,299 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +from concurrent.futures import Future +from typing import Any, Dict, List, Optional, Union + +from deepsparse.v2.operators import EngineOperator, Operator +from deepsparse.v2.routers import Router +from deepsparse.v2.schedulers import ( + ContinuousBatchingScheduler, + OperatorScheduler, + SchedulerGroup, +) +from deepsparse.v2.utils import InferenceState, PipelineState +from deepsparse.v2.utils.data import SubGraph +from deepsparse.v2.utils.helpers import run_func + + +__all__ = ["Pipeline"] + + +class Pipeline(Operator): + """ + Pipeline accepts a series of operators, schedulers, and a router. Calling a pipeline + will use the router to run through all the defined operators. 
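A rough wiring sketch, assuming `preprocess_op` and `postprocess_op` are Operator instances defined elsewhere and the model path is a placeholder:

```python
from deepsparse.v2.operators import EngineOperator
from deepsparse.v2.pipeline import Pipeline
from deepsparse.v2.routers import LinearRouter
from deepsparse.v2.schedulers import OperatorScheduler

ops = [preprocess_op, EngineOperator(model_path="model.onnx"), postprocess_op]
pipeline = Pipeline(
    ops=ops,
    router=LinearRouter(end_route=len(ops)),
    schedulers=[OperatorScheduler()],
)
```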
The operators should + be implemented using the Operator class and each implemented operator should be + responsible for a functional component of the pipelines. The flow of inputs/outputs + between the operators and the steps in the pipeline should be defined by the router, + (based off of the Router class), which dicates the next operator in the pipeline. + Execution of the operators will be handled by the provided schedulers. + + :param ops: Operators to run within the pipeline. Can either be a list of operators + or dictionary of operators. + :param router: A Router which dictates the next operator to call. + :param schedulers: A list of schedulers to run operators. + :param pipeline_state: pipeline_state created during pipeline initialization + + """ + + def __init__( + self, + ops: Union[Dict[str, Operator], List[Operator]], + router: Router, + schedulers: List[OperatorScheduler], + continuous_batching_scheduler: Optional[ContinuousBatchingScheduler] = None, + pipeline_state: PipelineState = None, + ): + + self.ops = ops + self.router = router + self.schedulers = schedulers + self.pipeline_state = pipeline_state + self._continuous_batching_scheduler = continuous_batching_scheduler + self.validate() + + self._scheduler_group = SchedulerGroup(self.schedulers) + + def _run_next( + self, + inp: Any, + inference_state: InferenceState, + next_step: str, + ): + if ( + isinstance(self.ops[next_step], EngineOperator) + and self._continuous_batching_scheduler + ): + func = self._continuous_batching_scheduler.submit + inp = self.ops[next_step].input_schema(**inp) + else: + func = self._scheduler_group.submit + + return run_func( + func=func, + operator=self.ops[next_step], + inp=inp, + pipeline_state=self.pipeline_state, + inference_state=inference_state, + ) + + def _run_sub_graphs( + self, sub_graph_inputs: List[Any], sub_graphs: List[SubGraph] + ) -> List[Any]: + """ + Run a list of sub_graphs asynchronously. Polls to identify the sub graph that is + still running but has completed its current step. Schedules the next step + subgraph step. This is repeated until all subgraphs have finished running and + have reached their end step (stored in the Subgraph.end attribute). + + :param sub_graph_inputs: A list of inputs that should be passed to each + subgraph. Each subgraph is given an element of the list as input to its + first node. + :param sub_graphs: A list of Subgraph objects. Each stores the relevant + execution information for the particular subgraph, such as its current step + in the sub graph, inference state, output, and end step. + + :returns: a list of outputs for all the completed Subgraph objects. Returned + in the same order that the subgraphs were passed to the function. + """ + for i in range(len(sub_graphs)): + sub_graphs[i].output = self._run_next( + sub_graph_inputs[i], sub_graphs[i].inf, sub_graphs[i].step + ) + + # Execute all sub graphs until all graphs have been completed. + while True: + for sub_graph in sub_graphs: + if isinstance(sub_graph.output, Future) and sub_graph.output.done(): + # get the result for the completed operator; resolve its output + operator_output = sub_graph.output.result() + operator_output = sub_graph.parse_output(operator_output) + + # determine the next step for the particular operator, using + # its previous output and previously stored step + next_step = self.router.next( + sub_graph.step, self.ops, operator_output + ) + # update the step + sub_graph.step = next_step + + # store the output for the next step. 
If the next step is + # end step, this particular route has completed. Simply + # update the output value + if next_step in sub_graph.end: + sub_graph.output = operator_output + else: + sub_graph.output = self._run_next( + inp=operator_output, + inference_state=sub_graph.inf, + next_step=next_step, + ) + break + + # keep running until all sub graphs have completed. + if not any(isinstance(x.output, Future) for x in sub_graphs): + break + + return [x.output for x in sub_graphs] + + def _apply_split(self, inp: Any, inference_state: InferenceState): + """ + Split inputs using the pipeline's expand_inputs function. Inputs are split + into a batch size of one when a SPLIT_ROUTE node is found in a given pipeline's + provided router. The split batches are run asynchronously and then joined when + a JOIN_ROUTE node is found, using the pipeline's condense_inputs function. + """ + + batches, orig_batch_size = self.expand_inputs(inp, 1) + + # Create a list of SplitRoutes, per batch size 1 + # Each SplitRoute object holds information about the particular path it + # follows. All start at the same step defined by SPLIT_ROUTE and start + # with the same inference_state. + split_graphs = [ + SubGraph( + inf=copy.deepcopy(inference_state), + step=self.router.route[self.router.SPLIT_ROUTE], + end=[self.router.JOIN_ROUTE], + ) + for i in range(len(batches)) + ] + + outputs = self._run_sub_graphs( + sub_graph_inputs=batches, sub_graphs=split_graphs + ) + return self.condense_inputs(outputs) + + def run( + self, + *args, + inference_state: InferenceState, + **kwargs, + ): + """ + Run through the operators using the provided router and scheduler. + The input to a given operator is the output of the previous operator. + + :param inference_state: inference_state for the pipeline. + :param pipeline_state: pipeline_state for the pipeline. The values in the state + are created during pipeline creation and are read-only during inference. 
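For reference, `__call__` below creates this state automatically when none is supplied; passing one explicitly would look roughly like this (`pipeline` and `pipeline_input` are placeholders):

```python
from deepsparse.v2.utils import InferenceState

state = InferenceState()
state.create_state({})  # start from an empty state dict
output = pipeline(pipeline_input, inference_state=state)
```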
+ """ + next_step = self.router.START_ROUTE + operator_output = None + while next_step != self.router.END_ROUTE: + + # Split Grap Execution (i.e multiple subgraphs) + # NOTE: split_route should only appear after the start route node + if next_step == self.router.SPLIT_ROUTE: + if operator_output is None: + raise ValueError( + f"{self.router.SPLIT_ROUTE} should appear after " + f"{self.ROUTER.START_ROUTE}" + ) + + operator_output = self._apply_split(operator_output, inference_state) + next_step = self.router.route[self.router.JOIN_ROUTE] + if next_step == self.router.END_ROUTE: + return operator_output + + if next_step == self.router.START_ROUTE: + operator_output = run_func( + *args, + func=self._scheduler_group.submit, + operator=self.ops[next_step], + inference_state=inference_state, + pipeline_state=self.pipeline_state, + **kwargs, + ).result() + + if isinstance(operator_output, tuple): + operator_output, state_update = ( + operator_output[0], + operator_output[-1], + ) + inference_state.update_state(state_update) + + next_step = self.router.next(next_step, self.ops, operator_output) + + else: + # Single graph execution + graph = SubGraph( + inf=copy.deepcopy(inference_state), + step=next_step, + end=[self.router.SPLIT_ROUTE, self.router.END_ROUTE], + ) + + operator_output = self._run_sub_graphs( + sub_graph_inputs=[operator_output], sub_graphs=[graph] + )[0] + + inference_state = graph.inf + next_step = graph.step + + return operator_output + + def __call__(self, *args, **kwargs): + """ + Consolidate any provided inference_state or pipeline_state objects and pass + any other operator inputs to run(). + + :return: output of the pipeline operators ran with the router for the given + input + """ + if kwargs.get("inference_state"): + inference_state = kwargs.pop("inference_state") + else: + inference_state = InferenceState() + inference_state.create_state({}) + + kwargs["inference_state"] = inference_state + + return self.run(*args, **kwargs) + + def expand_inputs(self, *args, **kwargs): + """ + Generic function to handle expanding values. + """ + raise NotImplementedError( + "This function should be implemented for any router with split or join" + "nodes. expand_inputs will be called prior to the split node (stored in " + "the router's SPLIT_ROUTE attribute), expanding outputs for each output " + "such that there is a batch size of one per thread." + ) + + def condense_inputs(self, *args, **kwargs): + """ + Generic function to handle condensing values. + """ + raise NotImplementedError( + "This function should be implemented for any router with split or join " + "nodes. condense_inputs will be called after the join node (stored in the " + "router's JOIN_ROUTE attribute), condensing outputs from multiple threads." + ) + + def validate(self): + """ + Validate that compatability of the router and operators provided. + """ + router_validation = self.router.validate(self.ops) + + if router_validation is False: + # default error message + op_types = [type(op) for op in self.ops] + raise ValueError(f"Invalid Router: {type(self.router)} for ops: {op_types}") + elif isinstance(router_validation, str): + raise ValueError(f"Invalid Router for operators: {router_validation}") diff --git a/src/deepsparse/v2/routers/__init__.py b/src/deepsparse/v2/routers/__init__.py new file mode 100644 index 0000000000..8718bedeb4 --- /dev/null +++ b/src/deepsparse/v2/routers/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .router import * diff --git a/src/deepsparse/v2/routers/router.py b/src/deepsparse/v2/routers/router.py new file mode 100644 index 0000000000..5d0365fda9 --- /dev/null +++ b/src/deepsparse/v2/routers/router.py @@ -0,0 +1,168 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Union + +from deepsparse.v2.operators import Operator + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["Router", "LinearRouter", "GraphRouter"] + + +class Router: + """ + Routers dictate the next operator to run. Each Router must implement a next function, + which dictates the index or key of the next operator to run. + + :param start_route: the start index or key of the router + :param end_route: the end index or key of the router + :param route: the route that the router has to traverse through + + """ + + def __init__( + self, + end_route: Union[str, int], + start_route: Union[str, int], + route: Optional[Dict] = None, + split_route: str = "SPLIT", + join_route: str = "JOIN", + ): + self.START_ROUTE = start_route + self.END_ROUTE = end_route + self.SPLIT_ROUTE = split_route + self.JOIN_ROUTE = join_route + self.route = route + + @abstractmethod + def next( + self, + past: Union[str, int], + ops: Optional[Union[List[Operator], Dict[str, Operator]]], + inp: Optional[Any], + ) -> Union[str, int]: + """ + Determines the index or dictionary key for the next operator which should run. + + :param past: the previous index or key. This should uniquely determine the next + operator to run + :param ops: list or dictionary of operators + :param inp: operator input + :returns: the next index or dictionary key for the next operator to run + """ + raise NotImplementedError + + def yaml(self): + pass + + def json(self): + pass + + +class LinearRouter(Router): + """ + LinearRouter runs a list of Operators in sequential order. end_route should + be the length of the list and the start_route should be the start index. 
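A small sanity-check sketch of the traversal this implies:

```python
router = LinearRouter(end_route=3)  # three sequential operators
assert router.START_ROUTE == 0
assert router.next(0) == 1
assert router.next(2) == 3  # 3 == router.END_ROUTE, i.e. traversal is done
```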
+ """ + + def __init__(self, route: Optional[List[str]] = None, end_route: Optional[int] = None, start_route: int = 0): + if end_route is None: + if route is None: + raise ValueError("To define the number of steps in the LinearRouter " + "either `route` or `end_route` must be provided" + ) + + end_route = len(route) + super().__init__(end_route=end_route, start_route=start_route) + _LOGGER.warn("SPLIT and JOIN are not yet supported for the LinearRouter.") + + def next( + self, past: int, ops: Optional[List[Operator]] = None, inp: Optional[Any] = None + ) -> int: + new_index = past + 1 + if new_index < self.END_ROUTE: + return new_index + return self.END_ROUTE + + @staticmethod + def validate(operators: List[Operator]) -> bool: + """ + :param operators: operators that this Router could potentially run over + :return: True if this Router can run this series of operators. Base Router + runs any series of operators that is non-empty and whose input and output + schemas align. If not valid, either False or an error string will be + returned + """ + # Commented out - operators are dicts not lists + # if len(operators) < 1: + # _LOGGER.info("No operators provided") + # return False + # + # for idx in range(len(operators) - 1): + # current_output_schema = operators[idx].output_schema + # next_input_schema = operators[idx + 1].input_schema + # + # if current_output_schema is None or next_input_schema is None: + # # if no input/output schema defined, assume operator can run + # # without schema + # continue + # + # if current_output_schema != next_input_schema: + # _LOGGER.info( + # f"Operator at idx {idx}: {type(operators[idx])} has invalid " + # f"output schema {current_output_schema} for next operator " + # f"{type(operators[idx + 1])} which requires {next_input_schema}" + # ) + # return False + return True + + +class GraphRouter(Router): + """ + Router for a DAG. Expects graphs be presented in the form of a dictionary, where + keys are the nodes of the graph and the values are the connected nodes. For + nodes with multiple ouput edges, all the nodes will be visited and the first node + where `can_operate` returns True will run. Paths should be deterministic. + """ + + def __init__(self, end_route: str, start_route: str, route: Dict, **kwargs): + super().__init__( + end_route=end_route, start_route=start_route, route=route, **kwargs + ) + + def next( + self, + past: str, + ops: Dict[str, Operator], + inp: Any, + ) -> int: + node = past + if isinstance(self.route[node], str): + return self.route[node] + else: + for neighbour_node in self.route[node]: + neighbour_node_op = ops[neighbour_node] + if neighbour_node_op.can_operate(inp): + return neighbour_node + raise ValueError("Cannot operate on any of the nodes") + + @staticmethod + def validate(ops) -> bool: + # TODO: still needs to be implemented for the GraphRouter + pass diff --git a/src/deepsparse/v2/schedulers/__init__.py b/src/deepsparse/v2/schedulers/__init__.py new file mode 100644 index 0000000000..b4d78521ab --- /dev/null +++ b/src/deepsparse/v2/schedulers/__init__.py @@ -0,0 +1,20 @@ +# flake8: noqa +# isort: skip_file + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .scheduler import *
+from .scheduler_group import *
+from .continuous_batching_scheduler import *
diff --git a/src/deepsparse/v2/schedulers/continuous_batching_scheduler.py b/src/deepsparse/v2/schedulers/continuous_batching_scheduler.py
new file mode 100644
index 0000000000..cc74ac0996
--- /dev/null
+++ b/src/deepsparse/v2/schedulers/continuous_batching_scheduler.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from concurrent.futures import Future
+from threading import Lock
+from typing import List
+
+from deepsparse.v2.operators import EngineOperator, Operator
+from deepsparse.v2.schedulers.scheduler import OperatorScheduler
+from deepsparse.v2.schedulers.utils import (
+    ContinuousBatchingExecutorThread,
+    ContinuousBatchingQueues,
+)
+
+
+__all__ = ["ContinuousBatchingScheduler"]
+
+
+_GLOBAL_SCHEDULER = None
+
+
+class ContinuousBatchingScheduler(OperatorScheduler):
+    """
+    Manages EngineOperator jobs that should be run with continuous batching.
+    Groups requests for the same engine into larger batches and returns
+    the results to the respective request threads after scheduled completion.
+
+    Example code for getting or creating a shared instance for scheduling
+    between pipelines and adding an engine operator to the scheduler
+    within a pipeline:
+
+    ```python
+
+    class MyPipeline(Pipeline):
+
+        def __init__(self):
+            ...
+            engine_operator = EngineOperator(...)
+            ...
+            continuous_batching_scheduler = ContinuousBatchingScheduler.get_instance()
+            continuous_batching_scheduler.add_engine_operator(engine_operator, [1])
+
+            super().__init__(...)
+    ```
+
+    :param max_workers: maximum number of threads to execute at once, default 1
+    """
+
+    # TODO: If the singleton always returns max_workers 1, should we remove this
+    # arg/not give the user a choice?
+    def __init__(self, max_workers: int = 1):
+        self._max_workers = max_workers
+
+        self._mutex = Lock()
+
+        # Dict[EngineOperator, Dict[batch_size, Engine]]
+        self._operators_to_engines = {}  # EngineOperator -> Dict[batch_size, Engine]
+        self._queues = ContinuousBatchingQueues()
+
+        # create and start max number of worker threads
+        self._threads = [
+            ContinuousBatchingExecutorThread(self._queues, self._operators_to_engines)
+            for _ in range(self.max_workers)
+        ]
+        for worker_thread in self._threads:
+            worker_thread.start()
+
+    @classmethod
+    def get_instance(cls) -> "ContinuousBatchingScheduler":
+        """
+        :return: global instance of the continuous batching scheduler. If one
+            does not exist yet, a scheduler with a single worker thread to
+            schedule all jobs is created and started
+        """
+        global _GLOBAL_SCHEDULER
+
+        if _GLOBAL_SCHEDULER is not None:
+            return _GLOBAL_SCHEDULER  # noqa: F823
+
+        _GLOBAL_SCHEDULER = cls(max_workers=1)
+        return _GLOBAL_SCHEDULER
+
+    @property
+    def max_workers(self) -> int:
+        """
+        :return: maximum number of threads to execute at once
+        """
+        return self._max_workers
+
+    def submit(self, *args, operator: Operator, **kwargs) -> Future:
+        """
+        :param operator: operator to run
+        :param args: the first positional argument must be an instance of the
+            operator's input_schema
+        :return: future referencing the asynchronously run output of the operator
+        """
+        inputs = args[0]
+        if not isinstance(inputs, operator.input_schema):
+            raise ValueError(
+                "Inputs to ContinuousBatchingScheduler must be the specific "
+                f"input schema to the given operator. Expected {operator.input_schema}, "
+                f"found {type(inputs)}"
+            )
+
+        future = Future()
+        self._queues.add_queue_item(key=operator, item=inputs, future=future)
+
+        return future
+
+    def can_process(self, *args, operator: Operator, **kwargs) -> bool:
+        """
+        :param operator: operator to check
+        :param operator_input: operator_input to check
+        :return: True if this scheduler can process the given operator and input,
+            i.e. the operator has been registered via `add_engine_operator`
+        """
+        return operator in self._operators_to_engines and operator in self._queues
+
+    def add_engine_operator(
+        self, engine_operator: EngineOperator, batch_sizes: List[int]
+    ):
+        """
+        Adds tracking for an engine operator to this scheduler
+        with continuous batching for the given sizes
+
+        :param engine_operator: an EngineOperator, must be compiled with
+            batch_size=1
+        :param batch_sizes: batch sizes to use for continuous batching
+        """
+        # lock updates to _operators_to_engines while updating
+        self._mutex.acquire()
+
+        # validation
+        if engine_operator in self._operators_to_engines:
+            # operator already added; release the lock before returning
+            self._mutex.release()
+            return
+
+        if not isinstance(engine_operator, EngineOperator):
+            self._mutex.release()
+            raise ValueError(
+                f"Expected an EngineOperator instance, found {type(engine_operator)}"
+            )
+        if engine_operator.batch_size != 1:
+            self._mutex.release()
+            raise ValueError(
+                "For continuous batching, EngineOperator must have batch_size=1. "
" + f"found batch_size={engine_operator.batch_size}" + ) + + # build EngineOperator -> List[batch_size] dict + operator_engines = {} + # base engine, expected batch size is 1 + operator_engines[engine_operator.batch_size] = engine_operator.engine + + # compile auxillary engines for continuous batching + for batch_size in batch_sizes: + if batch_size == 1: + continue # already added + + override_model_path = None + # text generation/NLEngineOperator specific; could add generic method + # for all engine_operators, if desired + if hasattr(engine_operator, "override_model_inputs"): + override_model_path = engine_operator.override_model_inputs( + model_path=engine_operator.model_path, batch_size=batch_size + ) + + # will break for internal kv_cache; needs additional argument + operator_engines[batch_size] = engine_operator.create_engine( + batch_size=batch_size, model_path=override_model_path + ) + + self._operators_to_engines[engine_operator] = operator_engines + self._queues.add_queue( + key=engine_operator, + batch_sizes=list(operator_engines.keys()), + ) + + # release lock + self._mutex.release() diff --git a/src/deepsparse/v2/schedulers/scheduler.py b/src/deepsparse/v2/schedulers/scheduler.py new file mode 100644 index 0000000000..5313683107 --- /dev/null +++ b/src/deepsparse/v2/schedulers/scheduler.py @@ -0,0 +1,77 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Callable + +from deepsparse.v2.operators import Operator + + +__all__ = ["OperatorScheduler"] + + +class OperatorScheduler: + """ + OperatorSchedulers should implement a `submit` function that asynchronously + runs an operator and its input and returns a Future. Priority of operators + to run and resources they are run on are deferred to specific OperatorScheduler + implementations + + Base OperatorScheduler behaves as a simple queue deferring to ThreadPoolExecutor + + :param max_workers: maximum number of threads to execute at once + """ + + def __init__(self, max_workers: int = 1): + self._threadpool = ThreadPoolExecutor(max_workers=max_workers) + + def submit( + self, + *args, + operator: Operator, + **kwargs, + ) -> Future: + """ + :param operator: operator to run + :return: future referencing the asynchronously run output of the operator + """ + return self._threadpool.submit( + operator, + *args, + **kwargs, + ) + + def can_process( + self, + *args, + operator: Operator, + **kwargs, + ) -> bool: + """ + :param operator: operator to check + :return: True if this Operator can process the given operator and input. 
+ Base OperatorScheduler always returns True + """ + return True + + def map(self, *args, func: Callable): + """ + :param func: generic callable run for each arg + :return: list of futures for each submit + """ + futures = [] + for _, values in enumerate(zip(*args)): + futures.append(self.submit(*values, operator=func)) + return futures diff --git a/src/deepsparse/v2/schedulers/scheduler_group.py b/src/deepsparse/v2/schedulers/scheduler_group.py new file mode 100644 index 0000000000..14d869a0f2 --- /dev/null +++ b/src/deepsparse/v2/schedulers/scheduler_group.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from concurrent.futures import Future +from typing import List + +from deepsparse.v2.operators import Operator +from deepsparse.v2.schedulers.scheduler import OperatorScheduler + + +__all__ = ["SchedulerGroup"] + + +class SchedulerGroup(OperatorScheduler): + """ + Wrapper for a series of schedulers. Runs submitted operators on the first + scheduler that can process a given input + + :param schedulers: list of schedulers to pass operators to + """ + + def __init__(self, schedulers: List[OperatorScheduler]): + self.schedulers = schedulers + + def submit( + self, + *args, + operator: Operator, + **kwargs, + ) -> Future: + """ + :param operator: operator to run + :return: future referencing the asynchronously run output of the operator + """ + for scheduler in self.schedulers: + if scheduler.can_process( + *args, + operator=operator, + **kwargs, + ): + return scheduler.submit( + *args, + operator=operator, + **kwargs, + ) diff --git a/src/deepsparse/v2/schedulers/utils/__init__.py b/src/deepsparse/v2/schedulers/utils/__init__.py new file mode 100644 index 0000000000..521341a7fc --- /dev/null +++ b/src/deepsparse/v2/schedulers/utils/__init__.py @@ -0,0 +1,19 @@ +# flake8: noqa +# isort: skip_file + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .continuous_batching_queues import * +from .continuous_batching_executor import * diff --git a/src/deepsparse/v2/schedulers/utils/continuous_batching_executor.py b/src/deepsparse/v2/schedulers/utils/continuous_batching_executor.py new file mode 100644 index 0000000000..40ff00ca4f --- /dev/null +++ b/src/deepsparse/v2/schedulers/utils/continuous_batching_executor.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from threading import Thread +from typing import Dict + +from deepsparse import Engine +from deepsparse.v2.operators import EngineOperator +from deepsparse.v2.schedulers.utils.continuous_batching_queues import ( + ContinuousBatchingQueues, +) + + +__all__ = [ + "ContinuousBatchingExecutorThread", +] + + +class ContinuousBatchingExecutorThread(Thread): + """ + Thread that when started runs indefinitely, grabbing a valid batch from + the queues when possible and running them in the correct engine + + :param queues: ContinuousBatchingQueues object containing a queue for + each valid engine + :param operators_to_engines: dictionary mapping valid engine operators + to a dictionary of its valid batch sizes mapped to an engine compiled + for that batch size + """ + + def __init__( + self, + queues: ContinuousBatchingQueues, + operators_to_engines: Dict[EngineOperator, Dict[int, Engine]], + ): + self._queues = queues + self._operators_to_engines = operators_to_engines + self._should_stop = False + + super().__init__(target=self._working_loop) + self.daemon = True # worker thread should exit when main thread exits + + def _working_loop(self): + # indefinitely wait for batch, run batch, split and resolve futures + while True: + # wait for next batch to be available + engine_operator, batch = self._queues.pop_batch(block=True) + + # unpack batch of QueueEntry objects + engine_inputs, futures, _ = list(zip(*batch)) + batch_size = len(engine_inputs) + + # type is EngineOperatorInputs + joined_inputs = engine_operator.input_schema.join(engine_inputs) + + # get engine for this operator compiled to the popped batch size + # and set the inputs to execute with it + joined_inputs.engine = self._operators_to_engines[engine_operator][ + batch_size + ] + + # run the engine operator with the given engine at the joined batch size + joined_outputs = engine_operator(joined_inputs, inference_state=None) + + # split outputs and return the results to their respective futures + split_outputs = joined_outputs.split() + for output, future in zip(split_outputs, futures): + future.set_result(output) diff --git a/src/deepsparse/v2/schedulers/utils/continuous_batching_queues.py b/src/deepsparse/v2/schedulers/utils/continuous_batching_queues.py new file mode 100644 index 0000000000..84d4f38e3d --- /dev/null +++ b/src/deepsparse/v2/schedulers/utils/continuous_batching_queues.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import Future +from queue import Queue +from threading import Condition, Lock +from time import time +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple + + +__all__ = [ + "ContinuousBatchingQueue", + "ContinuousBatchingQueues", + "QueueEntry", +] + + +# maximum wait time of longest item in queue before it is prioritized +_MAX_WAIT_MS = 100 + + +class QueueEntry(NamedTuple): + value: Any + future: Optional[Future] + entry_time_ms: float + + def time_elapsed(self) -> float: + return _current_time_ms() - self.entry_time_ms + + +class ContinuousBatchingQueue(Queue): + """ + Extension of queue.Queue with helper functions for dequeueing valid + batch sizes for continuous batching + + :param batch_sizes: valid batch sizes that can be grouped for continuous + batching + """ + + def __init__(self, batch_sizes: List[int], *args, **kwargs): + super().__init__(*args, **kwargs) + + self._batch_sizes = batch_sizes + self._min_batch_size = min(self.batch_sizes) + + @property + def batch_sizes(self) -> List[int]: + """ + :return: valid batch sizes that this queue can return + """ + return self._batch_sizes + + def pop_batch(self) -> List[Any]: + """ + :return: + """ + batch_size = self.max_queued_batch_size() + if batch_size == 0: + raise RuntimeError( + f"Cannot create a batch with {self.qsize()} entries and valid " + f"batch sizes: {self.batch_sizes}" + ) + + return [self.get() for _ in range(batch_size)] + + def has_batch(self) -> bool: + """ + :return: True if a batch of valid size can be filled with the current qsize + """ + return self.qsize() >= self._min_batch_size + + def max_queued_batch_size(self) -> int: + """ + :return: the maximum batch size that can be filled by members of this queue + """ + num_entries = self.qsize() + max_size = 0 + + for batch_size in self.batch_sizes: + if num_entries >= batch_size > max_size: + # current batch size can be satisfied and is the largest so far + max_size = batch_size + + return max_size + + def peek(self): + """ + :return: threadsafe peek of the first item in the queue + """ + with self.mutex: + return self.queue[0] + + +class ContinuousBatchingQueues: + """ + Threadsafe collection of Queues designed to support continuous batching. + Each Queue should be keyed by an operator where possible, however keys + are kept generic. + + On request for next - a job will be returned with an operator key and + a batch of inputs. 
The default heuristic for the next job will be + a combination of wait time and largest batch that can be run + """ + + def __init__(self): + self._queues = {} # Dict[Any, ContinuousBatchingQueue] + self._mutex = Lock() + + # add condition for wait/notify when an item is added to any queue + self._item_added = Condition(self._mutex) + + def __contains__(self, key: Any) -> bool: + """ + :param key: key to look up + :return: True if the given key has a queue in this group + """ + with self._mutex: + return key in self._queues + + def add_queue(self, key: Any, batch_sizes: List[int]): + """ + Adds a queue for a single operator that can be run at multiple batch sizes + + :param key: key to identify queue with, preferably the engine operator + :param batch_sizes: batch sizes that the operator can be run at + """ + with self._mutex: + self._queues[key] = ContinuousBatchingQueue(batch_sizes=batch_sizes) + + def add_queue_item(self, key: Any, item: Any, future: Optional[Future] = None): + """ + Adds an item to the given queue + + :param key: key for queue to add to + :param item: item to add in queue + :param future: optional future that should be used for resolution of value + """ + if key not in self: + raise KeyError(f"Cannot add item to queue for unregistered key {key}") + + entry = QueueEntry(value=item, future=future, entry_time_ms=_current_time_ms()) + + with self._mutex: + self._queues[key].put(entry) + self._item_added.notify() + + def has_next_batch(self) -> bool: + """ + :return: true if any Queue has enough entries to fill a valid batch size + """ + with self._mutex: + return any(queue.has_batch() for queue in self._queues.values()) + + def pop_batch( + self, + select_fn: Callable[[Dict[Any, ContinuousBatchingQueue]], Any] = None, + block: bool = True, + ) -> Tuple[Any, List[QueueEntry]]: + """ + :param select_fn: function that takes in a dictionary of queue key + (i.e. EngineOperator) to its ContinuousBatchingQueue of QueueItem + objects and returns the key of the queue that should be returned. + Only keys with queues large enough to fill a batch will be given. + If not provided, the default select_fn will return the queue that + can fill the largest batch size, or the queue that has the first item + with the longest wait time if that time is over 100ms. + :param block: if True, will wait for a valid batch to be in a queue before + popping and returning, if False, will raise an error if a full batch + cannot be popped. 
Default True + :return: Tuple of the queue key (EngineOperator) and + batch of QueueEntry objects as a list that have been popped and should + be run as a batch + """ + with self._mutex: + while not (valid_queues := self._filter_empty_queues()): + if block: + # wait to search for a valid queue again until a new item is added + self._item_added.wait() + else: + raise RuntimeError( + "Cannot pop_batch when no queues have enough items to fill " + "a valid batch size, check with has_next_batch before calling " + "pop_batch" + ) + + select_fn = select_fn or _default_select_fn + selected_key = select_fn(valid_queues) + + return selected_key, self._queues[selected_key].pop_batch() + + def _filter_empty_queues(self) -> Dict[Any, ContinuousBatchingQueue]: + return {key: queue for key, queue in self._queues.items() if queue.has_batch()} + + +def _default_select_fn(queues: Dict[Any, ContinuousBatchingQueue]) -> Any: + # find the maximum wait time of a queue + wait_times = [(key, queue.peek().time_elapsed()) for key, queue in queues.items()] + max_wait_key, max_wait = max(wait_times, key=lambda x: x[1]) # key on time + + if max_wait >= _MAX_WAIT_MS: + # if max time is greater than the threshold return that queue + return max_wait_key + + # default to the largest batch size that can be satisfied + return max(queues.keys(), key=lambda key: queues[key].max_queued_batch_size()) + + +def _current_time_ms(): + return time() * 1000 diff --git a/src/deepsparse/v2/task.py b/src/deepsparse/v2/task.py new file mode 100644 index 0000000000..f1f4fc6d66 --- /dev/null +++ b/src/deepsparse/v2/task.py @@ -0,0 +1,204 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Classes and implementations for supported tasks in the DeepSparse pipeline and system +""" + +import importlib +import logging +import os +import sys +from collections import namedtuple +from typing import Iterable, List, Optional, Tuple + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["SupportedTasks", "AliasedTask"] + + +class AliasedTask: + """ + A task that can have multiple aliases to match to. + For example, question_answering which can alias to qa as well + + :param name: the name of the task such as question_answering or text_classification + :param aliases: the aliases the task can go by in addition to the name such as + qa, glue, sentiment_analysis, etc + """ + + def __init__(self, name: str, aliases: List[str]): + self._name = name + self._aliases = aliases + + @property + def name(self) -> str: + """ + :return: the name of the task such as question_answering + """ + return self._name + + @property + def aliases(self) -> List[str]: + """ + :return: the aliases the task can go by such as qa, glue, sentiment_analysis + """ + return self._aliases + + def matches(self, task: str) -> bool: + """ + :param task: the name of the task to check whether the given instance matches. + Checks the current name as well as any aliases. 
+ Everything is compared at lower case and "-" and whitespace + are replaced with "_". + :return: True if task does match the current instance, False otherwise + """ + task = task.lower().replace("-", "_") + + # replace whitespace with "_" + task = "_".join(task.split()) + + return task == self.name or task in self.aliases + + +class SupportedTasks: + """ + The supported tasks in the DeepSparse pipeline and system + """ + + text_generation = namedtuple( + "text_generation", ["text_generation", "opt", "bloom"] + )( + text_generation=AliasedTask("text_generation", []), + opt=AliasedTask("opt", []), + bloom=AliasedTask("bloom", []), + ) + + all_task_categories = [text_generation] + + @classmethod + def check_register_task( + cls, task: str, extra_tasks: Optional[Iterable[str]] = None + ): + """ + :param task: task name to validate and import dependencies for + :param extra_tasks: valid task names that are not included in supported tasks. + i.e. tasks registered to Pipeline at runtime + """ + if cls.is_text_generation(task): + import deepsparse.v2.text_generation.pipeline # noqa: F401 + + all_tasks = set(cls.task_names() + (list(extra_tasks or []))) + if task not in all_tasks: + raise ValueError( + f"Unknown Pipeline task {task}. Currently supported tasks are " + f"{list(all_tasks)}" + ) + + @classmethod + def is_text_generation(cls, task: str) -> bool: + """ + :param task: the name of the task to check whether it is a text generation task + such as codegen + :return: True if it is a text generation task, False otherwise + """ + return any( + text_generation_task.matches(task) + for text_generation_task in cls.text_generation + ) + + @classmethod + def task_names(cls): + task_names = ["custom"] + for task_category in cls.all_task_categories: + for task in task_category: + unique_aliases = ( + alias for alias in task._aliases if alias != task._name + ) + task_names += (task._name, *unique_aliases) + return task_names + + +def dynamic_import_task(module_or_path: str) -> str: + """ + Dynamically imports `module` with importlib, and returns the `TASK` + attribute on the module (something like `importlib.import_module(module).TASK`). + + Example contents of `module`: + ```python + from deepsparse.pipeline import Pipeline + from deepsparse.transformers.pipelines.question_answering import ( + QuestionAnsweringPipeline, + ) + + TASK = "my_qa_task" + Pipeline.register(TASK)(QuestionAnsweringPipeline) + ``` + + NOTE: this modifies `sys.path`. + + :raises FileNotFoundError: if path does not exist + :raises RuntimeError: if the imported module does not contain `TASK` + :raises RuntimeError: if the module doesn't register the task + :return: The task from the imported module. + """ + parent_dir, module_name = _split_dir_and_name(module_or_path) + if not os.path.exists(os.path.join(parent_dir, module_name + ".py")): + raise FileNotFoundError( + f"Unable to find file for {module_or_path}. " + f"Looked for {module_name}.py under {parent_dir if parent_dir else '.'}" + ) + + # add parent_dir to sys.path so we can import the file as a module + sys.path.append(os.curdir) + if parent_dir: + _LOGGER.info(f"Adding {parent_dir} to sys.path") + sys.path.append(parent_dir) + + # do the import + _LOGGER.info(f"Importing '{module_name}'") + module_or_path = importlib.import_module(module_name) + + if not hasattr(module_or_path, "TASK"): + raise RuntimeError( + "When using --task import:, " + "module must set the `TASK` attribute." 
+ ) + + task = getattr(module_or_path, "TASK") + _LOGGER.info(f"Using task={repr(task)}") + + return task + + +def _split_dir_and_name(module_or_path: str) -> Tuple[str, str]: + """ + Examples: + - `a` -> `("", "a")` + - `a.b` -> `("a", "b")` + - `a.b.c` -> `("a/b", "c")` + + :return: module split into directory & name + """ + if module_or_path.endswith(".py"): + # assume path + split_char = os.sep + module_or_path = module_or_path.replace(".py", "") + else: + # assume module + split_char = "." + *dirs, module_name = module_or_path.split(split_char) + parent_dir = os.sep if dirs == [""] else os.sep.join(dirs) + return parent_dir, module_name diff --git a/src/deepsparse/v2/text_generation/__init__.py b/src/deepsparse/v2/text_generation/__init__.py new file mode 100644 index 0000000000..6f1323de50 --- /dev/null +++ b/src/deepsparse/v2/text_generation/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa +from .autoregressive_preprocess_operator import * +from .compile_generated_tokens import * +from .compile_generations import * +from .compile_logits import * +from .generate_new_token import * +from .join_output import * +from .kv_cache_operator import * +from .multi_engine_prefill_operator import * +from .nl_engine_operator import * +from .nl_engine_operator_no_kv_cache import * +from .prep_for_prefill import * +from .process_inputs import * +from .process_outputs import * + + +from .token_generator import * # isort:skip +from .prep_for_generation import * # isort:skip + +from .pipeline import * # isort:skip +from .pipeline_no_kv_cache import * # isort:skip diff --git a/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py new file mode 100644 index 0000000000..17d8dd662c --- /dev/null +++ b/src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
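+# Illustrative sketch (hypothetical file and task name): how dynamic_import_task
+# from task.py above is expected to be used. Given a file my_task.py that registers
+# a custom Pipeline task and sets TASK = "my_qa_task" (as in its docstring example):
+#
+#   task = dynamic_import_task("path/to/my_task.py")
+#   assert task == "my_qa_task"
+#
+# The parent directory is appended to sys.path so the file can be imported as a
+# module before the TASK attribute is read back.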
+
+import logging
+from typing import Any
+
+from deepsparse.transformers.utils.helpers import compute_engine_inputs
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.utils import PipelineState
+
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["AutoRegressiveOperatorPreprocess"]
+
+
+class AutoRegressiveOperatorPreprocess(Operator):
+    def __init__(self, sequence_length: int, prompt_sequence_length: int):
+        """
+        Prepare the tokens for the single-token engine. This requires creating the
+        attention mask, positions, and causal mask. The output contains these three
+        arrays to be passed into the single-token engine.
+        """
+        self.sequence_length = sequence_length
+        self.prompt_sequence_length = prompt_sequence_length
+
+        _LOGGER.warning(
+            "This operator requires the PipelineState to be set up with the "
+            "onnx_input_names_no_cache attribute set from the NLEngineOperator."
+        )
+
+    def can_operate(self, inp: Any) -> bool:
+        """
+        Can run this Operator if the number of tokens left to process is greater than
+        0 but less than self.prompt_sequence_length.
+        """
+        tokens = inp.get("tokens")
+        kv_cache = inp.get("kv_cache")
+
+        if inp.get("in_generation"):
+            return True
+
+        remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
+        can_process = (
+            remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length
+        )
+        if can_process and inp.get("in_generation") is None:
+            return True
+        return False
+
+    def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):
+        kv_cache.set_capacity(self.sequence_length - 1)
+
+        num_total_processed_tokens = kv_cache.total_num_processed_tokens
+        new_token = tokens[num_total_processed_tokens]
+
+        engine_inputs = compute_engine_inputs(
+            onnx_input_names=pipeline_state.current_state.get(
+                "onnx_input_names_no_cache"
+            ),
+            token_batch=[new_token],
+            prompt_sequence_length=1,
+            sequence_length=self.sequence_length,
+            num_total_processed_tokens=num_total_processed_tokens,
+        )
+        return {
+            "engine_inputs": engine_inputs,
+            "kv_cache": kv_cache,
+            "tokens": tokens,
+            "in_generation": kwargs.get("in_generation"),
+        }
diff --git a/src/deepsparse/v2/text_generation/compile_generated_tokens.py b/src/deepsparse/v2/text_generation/compile_generated_tokens.py
new file mode 100644
index 0000000000..630067f8c3
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/compile_generated_tokens.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
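+# Illustrative sketch (hypothetical values): the can_operate outcomes for the
+# single-token preprocess operator above, given prompt_sequence_length = 4:
+#
+#   in_generation is True                      -> True  (autoregressive decode step)
+#   2 prompt tokens remain, in_generation None -> True  (prompt remainder, 0 < 2 < 4)
+#   6 prompt tokens remain                     -> False (multi-token prefill instead)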
+from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["CompileGeneratedTokens"] + + +class CompileGeneratedTokens(Operator): + def run( + self, + new_token, + logits, + finish_reason, + kv_cache, + tokens, + inference_state: InferenceState, + **kwargs, + ): + in_generation = True + + generated_tokens = inference_state.current_state.get("generated_tokens") + generated_logits = inference_state.current_state.get("generated_logits") + finished_reason = inference_state.current_state.get("finished_reason") + + generated_tokens.append(new_token) + generated_logits.append(logits) + finished_reason.append(finish_reason) + + if finish_reason is not None: + in_generation = False + + state_update = { + "finished_reason": finished_reason, + "generated_tokens": generated_tokens, + "generated_logits": generated_logits, + } + + output = { + "tokens": tokens, + "kv_cache": kv_cache, + "in_generation": in_generation, + } + return output, state_update diff --git a/src/deepsparse/v2/text_generation/compile_generations.py b/src/deepsparse/v2/text_generation/compile_generations.py new file mode 100644 index 0000000000..ed8297ac01 --- /dev/null +++ b/src/deepsparse/v2/text_generation/compile_generations.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
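+# Illustrative sketch (hypothetical values): CompileGeneratedTokens above returns a
+# (output, state_update) tuple; the state_update is assumed to be merged back into
+# the shared InferenceState, so the generation history grows by one entry per step:
+#
+#   before: current_state["generated_tokens"] == [t0, t1]
+#   run(new_token=t2, logits=l2, finish_reason=None, ...)
+#       -> output["in_generation"] is True
+#       -> state_update["generated_tokens"] == [t0, t1, t2]
+#   once finish_reason is not None, output["in_generation"] flips to False.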
+from typing import Any + +import numpy +from pydantic import BaseModel, Field + +from deepsparse.transformers.pipelines.text_generation import FinishReason +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["CompileGenerations", "CompileGenerationsOutput"] + + +class CompileGenerationsOutput(BaseModel): + generated_tokens: Any = Field(description="generated_tokens") + generated_logits: Any = Field(description="generated_logits") + finished_reason: Any = Field(description="finished_reason") + + +class CompileGenerations(Operator): + output_schema = CompileGenerationsOutput + + def can_operate(self, inp: Any): + if inp.get("in_generation") is False: + return True + return False + + def run(self, inference_state: InferenceState, **kwargs): + generated_tokens = inference_state.current_state.get("generated_tokens") + generated_logits = inference_state.current_state.get("generated_logits") + finished_reason = inference_state.current_state.get("finished_reason") + + if len(finished_reason) == 0: + finished_reason.append(FinishReason.LENGTH) + + generated_tokens = numpy.array([generated_tokens]) + generated_logits = numpy.concatenate(generated_logits, axis=1) + return { + "generated_tokens": generated_tokens, + "generated_logits": generated_logits, + "finished_reason": finished_reason, + } diff --git a/src/deepsparse/v2/text_generation/compile_logits.py b/src/deepsparse/v2/text_generation/compile_logits.py new file mode 100644 index 0000000000..48a7158f66 --- /dev/null +++ b/src/deepsparse/v2/text_generation/compile_logits.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from deepsparse.v2.operators import Operator +from deepsparse.v2.text_generation.nl_engine_operator import NLEngineOutputs +from deepsparse.v2.utils import InferenceState + + +__all__ = ["CompilePromptLogits"] + + +class CompilePromptLogits(Operator): + """ + Combine the prompt logits. Currently relying on the inference state to store the + prompt logits for each token or multi-token batch processed. This operator will + take prompt logits from each iteration run and update the inference state. 
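+# Illustrative sketch (shapes are assumptions, with vocab_size standing in for the
+# model's vocabulary dimension): if three decode steps each stored logits of shape
+# (1, 1, vocab_size), CompileGenerations above yields
+#
+#   generated_tokens: numpy array of shape (1, 3)
+#   generated_logits: numpy.concatenate(..., axis=1) of shape (1, 3, vocab_size)
+#   finished_reason:  [FinishReason.LENGTH] if no other reason was recorded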
+ """ + + def can_operate(self, inp: NLEngineOutputs): + if inp.in_generation is None: + return True + return False + + def run(self, inp: NLEngineOutputs, inference_state: InferenceState, **kwargs): + logits = inp.engine_outputs + logit_type = "prompt_logits" + + if inference_state.current_state.get(logit_type) is not None: + current_logits = inference_state.current_state.get(logit_type).copy() + current_logits.append(logits) + else: + current_logits = [logits] + + state_update = {logit_type: current_logits} + return { + "kv_cache": inp.kv_cache, + "tokens": inp.tokens, + }, state_update diff --git a/src/deepsparse/v2/text_generation/generate_new_token.py b/src/deepsparse/v2/text_generation/generate_new_token.py new file mode 100644 index 0000000000..ba3fb445aa --- /dev/null +++ b/src/deepsparse/v2/text_generation/generate_new_token.py @@ -0,0 +1,94 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence, Union + +import transformers + +from deepsparse.transformers.pipelines.text_generation import FinishReason +from deepsparse.v2.operators import Operator +from deepsparse.v2.text_generation.nl_engine_operator import NLEngineOutputs +from deepsparse.v2.utils import InferenceState + + +__all__ = ["GenerateNewTokenOperator"] + + +class GenerateNewTokenOperator(Operator): + def __init__( + self, tokenizer: transformers.PreTrainedTokenizerBase, force_max_tokens: bool + ): + self.force_max_tokens = force_max_tokens + self.tokenizer = tokenizer + + def can_operate(self, inp: NLEngineOutputs): + if inp.in_generation: + return True + return False + + def run(self, *args, inference_state: InferenceState, **kwargs): + logits = args[0].engine_outputs if args else kwargs.get("logits") + kv_cache = args[0].kv_cache if args else kwargs.get("kv_cache") + + token_generator = inference_state.current_state.get("token_generator") + token = token_generator.generate(logits=logits[0, -1, :]) + finish_reason = None + + callback = inference_state.current_state.get("callback") + stop = inference_state.current_state.get("stop") + + if token == self.tokenizer.eos_token_id and not self.force_max_tokens: + finish_reason = FinishReason.STOP + + if self._stop_token_generated(token, stop_tokens=stop): + print( + "Stop token %s generated. Stopping generation." + % self.tokenizer.decode(token) + ) + finish_reason = FinishReason.STOP + + if callback is not None and callback(token) is False: + print( + "callback %s returned False, stopping generation." 
+ % callback.__qualname__ + ) + finish_reason = FinishReason.CALLBACK + + max_tokens = inference_state.current_state.get("max_tokens") + if len(inference_state.current_state.get("generated_tokens")) + 1 == max_tokens: + finish_reason = inference_state.current_state.get("length_finish_reason") + + state_update = { + "token_generator": token_generator, + } + + new_generation = { + "logits": logits, + "new_token": token, + "finish_reason": finish_reason, + } + output = {"tokens": token_generator.tokens, "kv_cache": kv_cache} + output.update(new_generation) + return output, state_update + + def _stop_token_generated( + self, token, stop_tokens: Union[None, str, Sequence[str]] + ) -> bool: + if stop_tokens is None: + return False + + decoded_token = self.tokenizer.decode(token) + decoded_token = ( + decoded_token if decoded_token.isspace() else decoded_token.strip() + ) + return decoded_token in stop_tokens diff --git a/src/deepsparse/v2/text_generation/join_output.py b/src/deepsparse/v2/text_generation/join_output.py new file mode 100644 index 0000000000..7479ee7493 --- /dev/null +++ b/src/deepsparse/v2/text_generation/join_output.py @@ -0,0 +1,71 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Tuple + +import numpy + +from deepsparse.transformers.utils.helpers import pad_to_fixed_length +from deepsparse.v2.operators import Operator +from deepsparse.v2.text_generation.compile_generations import CompileGenerationsOutput + + +__all__ = ["JoinOutput"] + + +class JoinOutput(Operator): + """ + Run this operator to combine the results from multiple prompts. 
+ """ + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def run(self, inp: Tuple[List[CompileGenerationsOutput], Dict], **kwargs): + + batch_outputs = [x for x in inp[0]] + generated_tokens = [x.generated_tokens for x in batch_outputs] + generated_logits = [x.generated_logits for x in batch_outputs] + finished_reason = [x.finished_reason for x in batch_outputs] + + max_len = max(token.shape[1] for token in generated_tokens) + + # pad all tokens to the same length + tokens = [ + pad_to_fixed_length( + array=prediction, + max_len=max_len, + value=self.tokenizer.pad_token_id, + axis=1, + ) + for prediction in generated_tokens + ] + + # find the longest sequence in the batch of logits + max_len = max(logits.shape[1] for logits in generated_logits) + + # pad all logits to the same length + logits = [ + pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1) + for single_logits in generated_logits + ] + + tokens = numpy.concatenate(tokens) + logits = numpy.concatenate(logits) + + return { + "generated_tokens": tokens, + "generated_logits": logits, + "finished_reason": finished_reason, + } diff --git a/src/deepsparse/v2/text_generation/kv_cache_operator.py b/src/deepsparse/v2/text_generation/kv_cache_operator.py new file mode 100644 index 0000000000..3c15d0ff5a --- /dev/null +++ b/src/deepsparse/v2/text_generation/kv_cache_operator.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from deepsparse.transformers.utils import DecoderKVCache
+from deepsparse.transformers.utils.helpers import (
+    initialize_kv_cache_state,
+    prepends_bos_token,
+)
+from deepsparse.v2.operators import Operator
+
+
+__all__ = ["KVCacheCreator", "KVCacheCreatorInput"]
+
+
+class KVCacheCreatorOutput(BaseModel):
+    kv_cache: Any = Field(description="KV Cache Created")  # DecoderKVCache
+
+
+class KVCacheCreatorInput(BaseModel):
+    cache_shape: Any = Field(description="shape")
+    kv_cache_data_type: Any = Field(description="data type")
+    output_names: Any = Field(description="output names")
+
+
+class KVCacheCreator(Operator):
+    input_schema = KVCacheCreatorInput
+    output_schema = KVCacheCreatorOutput
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length: int,
+        prompt_sequence_length: int,
+        internal_kv_cache: bool,
+    ):
+        self.tokenizer = tokenizer
+        self.prompt_sequence_length = prompt_sequence_length
+        self.internal_kv_cache = internal_kv_cache
+        self.sequence_length = sequence_length
+
+    def run(self, cache_shape, kv_cache_data_type: str, output_names: list, **kwargs):
+        kv_cache_state = initialize_kv_cache_state(
+            cache_shape=cache_shape,
+            kv_cache_data_type=kv_cache_data_type,
+            output_names=output_names,
+            length=self.sequence_length - self.prompt_sequence_length,
+            empty=bool(self.internal_kv_cache),
+        )
+
+        kv_cache = DecoderKVCache(self.internal_kv_cache)
+        kv_cache.setup(
+            state=kv_cache_state,
+            freeze_first_position=prepends_bos_token(self.tokenizer),
+        )
+        return {"kv_cache": kv_cache}
diff --git a/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py
new file mode 100644
index 0000000000..513c34dfc2
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any
+
+from deepsparse.transformers.utils.helpers import compute_engine_inputs
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.utils import PipelineState
+
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["MultiEnginePrefill"]
+
+
+class MultiEnginePrefill(Operator):
+    def __init__(self, prompt_sequence_length, sequence_length):
+        """
+        Prepare the tokens for the multi-token engine. This requires creating the
+        appropriate engine_inputs to be passed into the multi-token engine.
+        """
+        self.prompt_sequence_length = prompt_sequence_length
+        self.sequence_length = sequence_length
+
+    def can_operate(self, inp: Any):
+        """
+        Can only run if the number of prompt tokens left to process is greater than
+        or equal to self.prompt_sequence_length.
+ """ + kv_cache = inp.get("kv_cache") + tokens = inp.get("tokens") + + if len(tokens) < self.prompt_sequence_length: + return False + + if ( + len(tokens) - kv_cache.total_num_processed_tokens + >= self.prompt_sequence_length + ): + return True + return False + + def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs): + kv_cache.set_capacity(self.sequence_length - self.prompt_sequence_length) + + num_total_processed_tokens = kv_cache.total_num_processed_tokens + start = num_total_processed_tokens + end = start + self.prompt_sequence_length + token_batch = tokens[start:end] + + engine_inputs = compute_engine_inputs( + onnx_input_names=pipeline_state.current_state.get( + "onnx_input_names_no_cache" + ), + token_batch=token_batch, + prompt_sequence_length=self.prompt_sequence_length, + sequence_length=self.sequence_length, + num_total_processed_tokens=num_total_processed_tokens, + ) + + return { + "engine_inputs": engine_inputs, + "kv_cache": kv_cache, + "tokens": tokens, + } diff --git a/src/deepsparse/v2/text_generation/nl_engine_operator.py b/src/deepsparse/v2/text_generation/nl_engine_operator.py new file mode 100644 index 0000000000..c6583e37cf --- /dev/null +++ b/src/deepsparse/v2/text_generation/nl_engine_operator.py @@ -0,0 +1,311 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import os +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import numpy +from pydantic import BaseModel, Field + +from deepsparse.utils import join_engine_outputs, split_engine_inputs +from deepsparse.utils.onnx import ( + CACHE_INPUT_PREFIX, + overwrite_onnx_model_inputs_for_kv_cache_models, +) +from deepsparse.v2.operators.engine_operator import ( + DEEPSPARSE_ENGINE, + EngineOperator, + EngineOperatorInputs, +) + + +__all__ = ["NLEngineOperator", "NLEngineInputs", "NLEngineOutputs"] + + +class NLEngineInputs(BaseModel): + engine_inputs: List = Field(description="engine_inputs") + kv_cache: Any = Field(description="kv_cache object") + tokens: List = Field(description="tokens") + in_generation: Any = Field(description="in_generation", default=None) + engine: Optional[Any] = Field( + description="override the engine to run forward pass with", + default=None, + ) + + @classmethod + def join(cls, inputs: List["NLEngineInputs"]) -> "NLEngineInputs": + """ + :param inputs: list of separate EngineOperatorInputs, batch size must be 1 + :return: list of inputs joined into a single input with a multi batch size + """ + all_engine_inputs = [] + all_kv_cache = [] + all_tokens = [] + all_generation = [] + + for engine_input in inputs: + all_engine_inputs.append(engine_input.engine_inputs) + all_kv_cache.append(engine_input.kv_cache) + all_tokens.append(engine_input.tokens) + all_generation.append(engine_input.in_generation) + + for engine_inputs in all_engine_inputs: + if engine_inputs[0].shape[0] != 1: + raise RuntimeError( + "join requires all inputs to have batch size 1, found input with " + f"batch size {engine_inputs[0].shape[0]}" + ) + return cls( + engine_inputs=all_engine_inputs, + tokens=all_tokens, + in_generation=all_generation, + kv_cache=all_kv_cache, + ) + + class Config: + arbitrary_types_allowed = True + + +class NLEngineOutputs(BaseModel): + engine_outputs: Any = Field(description="engine_outputs") + kv_cache: Any = Field(description="kv_cache object") + tokens: List = Field(description="tokens") + in_generation: Any = Field(description="in_generation", default=None) + + def split(self) -> List["NLEngineOutputs"]: + """ + :return: list of the current outputs split to a batch size of 1 each + """ + split_outputs = [ + numpy.expand_dims(self.engine_outputs[i], 0) + for i in range(len(self.engine_outputs)) + ] + return [ + self.__class__( + engine_outputs=split_outputs[i], + kv_cache=self.kv_cache[i], + tokens=self.tokens[i], + in_generation=self.in_generation[i], + ) + for i in range(len(split_outputs)) + ] + + +class NLEngineOperator(EngineOperator): + + """ + Operator for the NL Decoder Engine. This Operator inherits from the EngineOperator. + Specific updates to engine attributes are made through this operator, as well + as updating the kv_cache. This Operator is used for both the single-token and + multi-token case. 
+ """ + + input_schema = NLEngineInputs + output_schema = NLEngineOutputs + + def __init__( + self, + sequence_length: int, + input_ids_length: int, + internal_kv_cache: bool = False, + **kwargs, + ): + + self.sequence_length = sequence_length + self.input_ids_length = input_ids_length + self.kv_cache_data_type = None + self.internal_kv_cache = internal_kv_cache + self.model_path = kwargs.get("model_path") + (onnx_file_path, additional_outputs) = self.override_model_inputs( + self.model_path, batch_size=1, return_additional_outputs=True + ) + output_indices_to_be_cached, kv_cache_data_type, = additional_outputs.get( + "output_indices_to_be_cached" + ), additional_outputs.get("kv_cache_data_type") + + engine_kwargs = kwargs.get("engine_kwargs", {}) + if kwargs.get("engine_type", DEEPSPARSE_ENGINE) == DEEPSPARSE_ENGINE: + if "WAND_OPT_FLAGS" not in os.environ: + os.environ["WAND_OPT_FLAGS"] = "default,~pyramids" + + if any(output_indices_to_be_cached): + self.kv_cache_data_type = kv_cache_data_type + if ( + internal_kv_cache + and kwargs.get("engine_type", DEEPSPARSE_ENGINE) == DEEPSPARSE_ENGINE + ): + engine_kwargs["cached_outputs"] = output_indices_to_be_cached + + kwargs["engine_kwargs"] = engine_kwargs + kwargs["model_path"] = onnx_file_path + + super().__init__(**kwargs) + + def override_model_inputs( + self, + model_path: Union[str, Path], + batch_size: int, + return_additional_outputs=False, + ): + """ + Override the model based on the provided batch_size, sequence_length, + and input_ids_length. + + :param model_path: Path to the model + :param batch_size: The batch size to be used for the model + :return: new overwritten model file path. Optionally returns additional outputs + specific to the NLDecoder engine + """ + ( + onnx_file_path, + output_indices_to_be_cached, + kv_cache_data_type, + ) = overwrite_onnx_model_inputs_for_kv_cache_models( + onnx_file_path=model_path, + batch_size=batch_size, + sequence_length=self.sequence_length, + input_ids_length=self.input_ids_length, + ) + if return_additional_outputs: + return onnx_file_path, { + "output_indices_to_be_cached": output_indices_to_be_cached, + "kv_cache_data_type": kv_cache_data_type, + } + return onnx_file_path + + def run(self, inp: NLEngineInputs, **kwargs) -> NLEngineOutputs: + engine_input = inp.engine_inputs + kv_cache = inp.kv_cache + + split = True + if not isinstance(kv_cache, list): + split = False + kv_cache = [kv_cache] + engine_input = [engine_input] + + inputs = list(map(self._add_kv_cache_to_input, engine_input, kv_cache)) + + if bool(kv_cache[0].engine_internal_cache): + # conventionally, before dispatching + # inputs to the engine, we validate them + # if val_inp=True. However, in this case + # we want to pass the empty kv cache inputs + # (batch_size=0) to the engine. 
Therefore, + # we skip the validation + + # Internal kv_cache works for batch_size of 1 atm + out = self.engine._eng_net.execute_list_out( + inputs[0], kv_cache[0].engine_internal_cache + ) + else: + # run the engine without the LIB.kv_cache object + # stack multiple batch inputs along the batch dimension + inputs = join_engine_outputs(inputs, len(inputs)) + out = ( + super() + .run( + EngineOperatorInputs(engine_inputs=inputs, engine=inp.engine), + **kwargs, + ) + .get("engine_outputs") + ) + + # logits should be stacked along batch dim + # kv_cache_state should be a list where each dim 0 is batch_size + logits, *kv_cache_state = out + kv_cache_state, _ = split_engine_inputs(kv_cache_state, 1) + + if len(kv_cache_state) > 0: + for i in range(len(kv_cache)): + self._update_kv_cache( + kv_cache_state=kv_cache_state[i], kv_cache=kv_cache[i] + ) + else: + # internal kv cache case + self._update_kv_cache(kv_cache=kv_cache[0]) + + output = { + "engine_outputs": logits, + "kv_cache": kv_cache if split else kv_cache[0], + "tokens": inp.tokens, + "in_generation": inp.in_generation, + } + return output + + def _add_kv_cache_to_input(self, engine_input, kv_cache): + kv_cache_state = copy.copy(kv_cache.cached_inputs) + + for idx, input_name in enumerate(self.onnx_input_names_no_cache): + kv_cache_state[input_name] = engine_input[idx] + + new_inp = [kv_cache_state[name] for name in self.engine.input_names] + return new_inp + + def _update_kv_cache(self, kv_cache, kv_cache_state=None): + if bool(kv_cache.engine_internal_cache): + kv_cache.total_num_processed_tokens += self.input_ids_length + return + + kv_cache_state = { + name: array + for name, array in zip(self.onnx_input_names_cached, kv_cache_state) + } + + kv_cache.update(state=kv_cache_state, input_ids_len=self.input_ids_length) + + @property + def onnx_input_names_no_cache(self) -> List[str]: + """ + :return: The input names for the onnx model, excluding + the potential kv cache inputs + """ + return [ + name + for name in self.engine.input_names + if not name.startswith(CACHE_INPUT_PREFIX) + ] + + @property + def onnx_input_names_cached(self) -> List[str]: + """ + :return: The cached input names for the onnx model + """ + return [ + name + for name in self.engine.input_names + if name.startswith(CACHE_INPUT_PREFIX) + ] + + @property + def cache_shape(self) -> Tuple[int, int, int, int]: + """ + :return: The shape of the kv cache inputs + for the onnx model. The shape is + (batch_size, num_heads, sequence_length, hidden_size) + """ + cache_engine_input_index = next( + i + for i, name in enumerate(self.engine.input_names) + if CACHE_INPUT_PREFIX in name + ) + return self.engine.input_shapes[cache_engine_input_index] + + @property + def output_names(self) -> List[str]: + """ + :return: The output names for the onnx model + """ + return self.engine.output_names diff --git a/src/deepsparse/v2/text_generation/nl_engine_operator_no_kv_cache.py b/src/deepsparse/v2/text_generation/nl_engine_operator_no_kv_cache.py new file mode 100644 index 0000000000..746010560f --- /dev/null +++ b/src/deepsparse/v2/text_generation/nl_engine_operator_no_kv_cache.py @@ -0,0 +1,67 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import numpy
+from pydantic import BaseModel
+
+from deepsparse.transformers.helpers import overwrite_transformer_onnx_model_inputs
+from deepsparse.v2.operators.engine_operator import EngineOperator, EngineOperatorInputs
+
+
+__all__ = [
+    "NLEngineOperatorNoCache",
+    "NLEngineInputsNoCache",
+]
+
+
+class NLEngineInputsNoCache(BaseModel):
+    input_ids: Any
+    attention_mask: Any
+
+
+class NLEngineOperatorNoCache(EngineOperator):
+    """
+    Operator for the Natural Language Engine that operates without
+    a KV Cache. This operator simply maps input_ids and
+    attention_mask to logits.
+    """
+
+    input_schema = NLEngineInputsNoCache
+    output_schema = None
+
+    def __init__(self, sequence_length: int, **kwargs):
+        overwrite_transformer_onnx_model_inputs(
+            path=kwargs.get("model_path"),
+            batch_size=kwargs.get("batch_size", 1),
+            max_length=sequence_length,
+        )
+        super().__init__(**kwargs)
+
+    def run(self, inp: NLEngineInputsNoCache, **kwargs) -> Any:
+        engine_inputs = [inp.input_ids, inp.attention_mask]
+        logits = (
+            super()
+            .run(EngineOperatorInputs(engine_inputs=engine_inputs), **kwargs)
+            .get("engine_outputs")
+        )
+
+        # By default, the engine outputs logits for all tokens in the sequence.
+        # Let's filter out the logits for the padding tokens.
+        logits = numpy.compress(inp.attention_mask.flatten(), logits[0], axis=1)
+
+        return {"logits": [logits], "kv_cache": None, "tokens": None}, {
+            "prompt_logits": [logits]
+        }
diff --git a/src/deepsparse/v2/text_generation/pipeline.py b/src/deepsparse/v2/text_generation/pipeline.py
new file mode 100644
index 0000000000..344980dc3f
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/pipeline.py
@@ -0,0 +1,240 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
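+
+# Example usage of the pipeline defined in this module (illustrative sketch only;
+# the model path and prompt are placeholder values, and the call pattern mirrors
+# the integration tests added alongside this pipeline):
+#
+#   from deepsparse.v2.text_generation import TextGenerationPipeline
+#
+#   pipeline = TextGenerationPipeline(model_path="./deployment", sequence_length=1024)
+#   output = pipeline(
+#       prompt="def fibonacci(n):",
+#       generation_kwargs=dict(max_new_tokens=32),
+#   )
+#   print(output.generations[0].text)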
+ +import logging +from typing import Dict, List, Optional + +from deepsparse.transformers.helpers import setup_transformers_pipeline +from deepsparse.transformers.utils.helpers import process_generation_config +from deepsparse.utils import split_engine_inputs +from deepsparse.v2.operators import EngineOperator +from deepsparse.v2.operators.registry import OperatorRegistry +from deepsparse.v2.pipeline import Pipeline +from deepsparse.v2.routers import GraphRouter +from deepsparse.v2.schedulers import ContinuousBatchingScheduler, OperatorScheduler +from deepsparse.v2.text_generation import ( + AutoRegressiveOperatorPreprocess, + CompileGeneratedTokens, + CompileGenerations, + CompilePromptLogits, + GenerateNewTokenOperator, + JoinOutput, + KVCacheCreator, + MultiEnginePrefill, + NLEngineOperator, + PrepareforPrefill, + PrepareGeneration, + ProcessInputsTextGeneration, + ProcessOutputs, + TokenGeneratorOperator, +) +from deepsparse.v2.utils import PipelineState + + +_LOGGER = logging.getLogger(__name__) + + +@OperatorRegistry.register(name="text_generation") +class TextGenerationPipeline(Pipeline): + def __init__( + self, + model_path: str, + prompt_sequence_length: int = 16, + sequence_length: int = 1024, + internal_kv_cache: bool = True, + force_max_tokens: bool = False, + generation_config=None, + continuous_batch_sizes: Optional[List[int]] = None, + engine_kwargs: Optional[Dict] = None, + ): + ( + self.model_path, + self.config, + self.tokenizer, + engine_kwargs, + ) = setup_transformers_pipeline( + model_path, sequence_length, engine_kwargs=engine_kwargs + ) + + pipeline_state = PipelineState() + pipeline_state_vals = {} + + if internal_kv_cache and engine_kwargs.get("engine_type") == "onnxruntime": + internal_kv_cache = False + + single_engine_operator = NLEngineOperator( + sequence_length=sequence_length, + internal_kv_cache=internal_kv_cache, + input_ids_length=1, + **engine_kwargs, + ) + + multi_engine_operator = NLEngineOperator( + sequence_length=sequence_length, + internal_kv_cache=internal_kv_cache, + input_ids_length=prompt_sequence_length, + **engine_kwargs, + ) + + # NOTE: Currently using pipeline state. Can swap to simply pass in the + # attributes to the specific Operator that need them, as class attributes. + pipeline_state_vals[ + "onnx_input_names_no_cache" + ] = single_engine_operator.onnx_input_names_no_cache + pipeline_state_vals["cache_shape"] = single_engine_operator.cache_shape + pipeline_state_vals["output_names"] = single_engine_operator.output_names + pipeline_state_vals[ + "kv_cache_data_type" + ] = single_engine_operator.kv_cache_data_type + pipeline_state.create_state(pipeline_state_vals) + + process_inputs = ProcessInputsTextGeneration( + generation_config=process_generation_config(generation_config), + sequence_length=sequence_length, + tokenizer=self.tokenizer, + ) + + kv_cache_creator = KVCacheCreator( + sequence_length=sequence_length, + tokenizer=self.tokenizer, + prompt_sequence_length=prompt_sequence_length, + internal_kv_cache=internal_kv_cache, + ) + + # NOTE: Can also have the KVCacheCreator be initialized inside this Operator. + # Relies on pipeline state variables set-up above (can be swapped to be class + # attributes instead of using the state. 
+        engine_inputs_for_prefill = PrepareforPrefill(kv_cache_creator=kv_cache_creator)
+
+        multi_engine_prefill = MultiEnginePrefill(
+            prompt_sequence_length=prompt_sequence_length,
+            sequence_length=sequence_length,
+        )
+        compile_prompt_logits = CompilePromptLogits()
+
+        autoregressive_preprocess = AutoRegressiveOperatorPreprocess(
+            sequence_length=sequence_length,
+            prompt_sequence_length=prompt_sequence_length,
+        )
+        token_generator = TokenGeneratorOperator()
+        prep_for_generation = PrepareGeneration(
+            sequence_length=sequence_length,
+            prompt_sequence_length=prompt_sequence_length,
+            token_generator=token_generator,
+        )
+        generate_new_token = GenerateNewTokenOperator(
+            tokenizer=self.tokenizer, force_max_tokens=force_max_tokens
+        )
+        process_output = ProcessOutputs(tokenizer=self.tokenizer)
+        compile_generations = CompileGenerations()
+        compile_generated_tokens = CompileGeneratedTokens()
+        join_output = JoinOutput(tokenizer=self.tokenizer)
+
+        # TODO: do we want to support lists for different engines?
+        continuous_batching_scheduler = None
+        if continuous_batch_sizes:
+            if internal_kv_cache:
+                _LOGGER.warning(
+                    "internal kv_cache is currently not supported with "
+                    "continuous batching"
+                )
+            else:
+                continuous_batching_scheduler = self._get_continuous_batching_scheduler(
+                    batch_sizes=continuous_batch_sizes,
+                    engines=[single_engine_operator, multi_engine_operator],
+                )
+
+        ops = {
+            "process_input": process_inputs,
+            "single_engine": single_engine_operator,
+            "multi_engine": multi_engine_operator,
+            "kv_cache_creator": kv_cache_creator,
+            "prepare_prefill": engine_inputs_for_prefill,
+            "multi_engine_prefill": multi_engine_prefill,
+            "compile_logits": compile_prompt_logits,
+            "autoregressive_preprocess": autoregressive_preprocess,
+            "prep_for_generation": prep_for_generation,
+            "generate_new_token": generate_new_token,
+            "process_outputs": process_output,
+            "compile_generations": compile_generations,
+            "compile_generated_tokens": compile_generated_tokens,
+            "join_output": join_output,
+        }
+
+        routes = {
+            "process_input": "SPLIT",
+            "SPLIT": "prepare_prefill",
+            "prepare_prefill": ["multi_engine_prefill", "autoregressive_preprocess"],
+            "multi_engine_prefill": "multi_engine",
+            "multi_engine": "compile_logits",
+            "compile_logits": [
+                "multi_engine_prefill",
+                "prep_for_generation",
+                "autoregressive_preprocess",
+            ],
+            "autoregressive_preprocess": "single_engine",
+            "single_engine": [
+                "compile_logits",
+                "generate_new_token",
+            ],
+            "prep_for_generation": "autoregressive_preprocess",
+            "generate_new_token": "compile_generated_tokens",
+            "compile_generated_tokens": [
+                "autoregressive_preprocess",
+                "compile_generations",
+            ],
+            "compile_generations": "JOIN",
+            "JOIN": "join_output",
+            "join_output": "process_outputs",
+            "process_outputs": "STOP",
+        }
+
+        router = GraphRouter(
+            end_route="STOP", start_route="process_input", route=routes
+        )
+        scheduler = [OperatorScheduler()]
+        super().__init__(
+            ops=ops,
+            router=router,
+            schedulers=scheduler,
+            pipeline_state=pipeline_state,
+            continuous_batching_scheduler=continuous_batching_scheduler,
+        )
+
+    def expand_inputs(self, items, batch_size):
+        items = [items.get(key) for key in items.keys()]
+        out, orig_batch_size = split_engine_inputs(items, batch_size)
+        combined_batches = [{"input_ids": b[0], "attention_mask": b[1]} for b in out]
+        return combined_batches, orig_batch_size
+
+    def condense_inputs(self, *args, **kwargs):
+        return args[0], kwargs
+
+    def _get_continuous_batching_scheduler(
+        self, batch_sizes: List[int], engines: 
List[EngineOperator] + ) -> ContinuousBatchingScheduler: + """ + Fetch the continuous batching scheduler. Requires adding the EngineOperator + that will run through the scheduler. + + :param batch_sizes: List of batch sizes to be used by the models + :param engine: List of EngineOperators which should be scheduled using the + continuous batching scheduler + + :returns: ContinuousBatchingScheduler + """ + continuous_batching_scheduler = ContinuousBatchingScheduler.get_instance() + for op in engines: + continuous_batching_scheduler.add_engine_operator(op, batch_sizes) + return continuous_batching_scheduler diff --git a/src/deepsparse/v2/text_generation/pipeline_no_kv_cache.py b/src/deepsparse/v2/text_generation/pipeline_no_kv_cache.py new file mode 100644 index 0000000000..ffb149ff27 --- /dev/null +++ b/src/deepsparse/v2/text_generation/pipeline_no_kv_cache.py @@ -0,0 +1,146 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, Optional + +from deepsparse.transformers.helpers import setup_transformers_pipeline +from deepsparse.transformers.utils.helpers import process_generation_config +from deepsparse.utils import split_engine_inputs +from deepsparse.utils.onnx import default_cached_outputs +from deepsparse.v2.pipeline import Pipeline +from deepsparse.v2.routers import LinearRouter +from deepsparse.v2.schedulers import OperatorScheduler +from deepsparse.v2.text_generation import ( + CompileGenerations, + GenerateNewTokenOperator, + JoinOutput, + NLEngineOperatorNoCache, + PrepareGeneration, + ProcessInputsTextGeneration, + ProcessOutputs, + TokenGeneratorOperator, +) + + +_LOGGER = logging.getLogger(__name__) + + +class TextGenerationPipelineNoCache(Pipeline): + def __init__( + self, + model_path: str, + sequence_length: int = 1024, + onnx_model_name: Optional[str] = None, + generation_config=None, + engine_kwargs: Optional[Dict] = None, + **kwargs, + ): + + ( + self.model_path, + self.config, + self.tokenizer, + engine_kwargs, + ) = setup_transformers_pipeline( + model_path, + sequence_length, + tokenizer_padding_side="right", + onnx_model_name=onnx_model_name, + engine_kwargs=engine_kwargs, + ) + self.verify_no_kv_cache_present() + + token_generator = TokenGeneratorOperator() + + process_inputs = ProcessInputsTextGeneration( + generation_config=process_generation_config(generation_config), + sequence_length=sequence_length, + tokenizer=self.tokenizer, + ) + engine_operator = NLEngineOperatorNoCache( + sequence_length=sequence_length, + **engine_kwargs, + ) + prepare_generation = PrepareGeneration( + sequence_length=sequence_length, + prompt_sequence_length=1, + token_generator=token_generator, + ) + generate_new_token = GenerateNewTokenOperator( + tokenizer=self.tokenizer, force_max_tokens=True + ) + compile_generations = CompileGenerations() + join_output = JoinOutput(tokenizer=self.tokenizer) + process_outputs = ProcessOutputs(tokenizer=self.tokenizer) + + ops = { + 
"process_input": process_inputs, + "engine_operator": engine_operator, + "prepare_generation": prepare_generation, + "generate_new_token": generate_new_token, + "compile_generations": compile_generations, + "join_output": join_output, + "process_outputs": process_outputs, + } + route = [ + "process_input", + "SPLIT", + "engine_operator", + "prepare_generation", + "generate_new_token", + "compile_generations", + "JOIN", + "join_output", + "process_outputs" + ] + + # TODO: Using the GraphRouter, but should use + # LinearRouter with appropriate split/join support + router = LinearRouter(route=route) + scheduler = [OperatorScheduler()] + super().__init__( + ops=ops, + router=router, + schedulers=scheduler, + ) + + def run(self, *args, **kwargs): + # we need to set the fixed_sequences_length flag to True + # for the non-kv cache pipeline + kwargs.update(dict(fixed_sequences_length=True, max_new_tokens=1)) + return super().run(*args, **kwargs) + + def condense_inputs(self, *args, **kwargs): + return args[0], kwargs + + def expand_inputs(self, items, batch_size): + items = [items.get(key) for key in items.keys()] + out, orig_batch_size = split_engine_inputs(items, batch_size) + combined_batches = [{"input_ids": b[0], "attention_mask": b[1]} for b in out] + return combined_batches, orig_batch_size + + def verify_no_kv_cache_present(self) -> bool: + """ + Verifies that the ONNX model does not have + KV cache inputs/outputs present. + :return: True if compatible, False otherwise + """ + is_kv_cache_present = any(default_cached_outputs(self.model_path)) + if is_kv_cache_present: + raise ValueError( + f"The model: {self.model_path} has KV cache inputs/outputs present. " + "Please use the TextGenerationPipeline instead." + ) + return not is_kv_cache_present diff --git a/src/deepsparse/v2/text_generation/prep_for_generation.py b/src/deepsparse/v2/text_generation/prep_for_generation.py new file mode 100644 index 0000000000..9b63946c16 --- /dev/null +++ b/src/deepsparse/v2/text_generation/prep_for_generation.py @@ -0,0 +1,99 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +from typing import Any + +import numpy + +from deepsparse.transformers.pipelines.text_generation import FinishReason +from deepsparse.transformers.utils.helpers import set_generated_length +from deepsparse.v2.operators import Operator +from deepsparse.v2.text_generation import TokenGeneratorOperator +from deepsparse.v2.utils import InferenceState + + +__all__ = ["PrepareGeneration"] + + +class PrepareGeneration(Operator): + def __init__( + self, + token_generator: TokenGeneratorOperator, + prompt_sequence_length: int, + sequence_length: int, + ): + self.sequence_length = sequence_length + self.token_generator_creator = token_generator + self.prompt_sequence_length = prompt_sequence_length + + def can_operate(self, inp: Any): + kv_cache = inp.get("kv_cache") + tokens = inp.get("tokens") + + # If the number of prompt tokens is greater than what we've processed, + # don't start generation. Should be equal when started as all prompt logits + # should be accounted for and we should have updated the kv_cache for the single + # token engine. + if len(tokens) == kv_cache.total_num_processed_tokens: + return True + return False + + def run( + self, tokens: Any, kv_cache: Any, inference_state: InferenceState, **kwargs + ): + prompt_logits = inference_state.current_state.get("prompt_logits") + prompt_logits = numpy.concatenate(prompt_logits, axis=1) + # TODO: clean this up such that dont have to keep writing current_state + # everywhere + + generation_config = inference_state.current_state.get("generation_config") + include_prompt_logits = inference_state.current_state.get( + "include_prompt_logits" + ) + + token_generator_creator_output = self.token_generator_creator.run( + logits_shape=prompt_logits[0, -1, :].shape, + deterministic=not generation_config.do_sample, + sampling_temperature=generation_config.temperature, + tokens=copy.copy(tokens), + **inference_state.current_state, + ) + token_generator = token_generator_creator_output.get("token_generator") + token_generator.generate(prompt_logits[0, -1, :]) + + max_tokens, length_finish_reason = set_generated_length( + max_length=generation_config.max_length, + prompt_tokens_length=1, + max_new_tokens=generation_config.max_new_tokens, + sequence_length=self.sequence_length, + prompt_sequence_length=self.prompt_sequence_length, + finish_reason_choices=FinishReason, + ) + state_update = { + "max_tokens": max_tokens, + "length_finish_reason": length_finish_reason, + "generated_tokens": [token_generator.tokens[-1]], + "generated_logits": [prompt_logits] + if include_prompt_logits + else [numpy.expand_dims(prompt_logits[:, -1, :], 0)], + "finished_reason": [], + "token_generator": token_generator, + } + output = { + "logits": prompt_logits, + "tokens": token_generator.tokens, + "kv_cache": kv_cache, + "in_generation": True, + } + return output, state_update diff --git a/src/deepsparse/v2/text_generation/prep_for_prefill.py b/src/deepsparse/v2/text_generation/prep_for_prefill.py new file mode 100644 index 0000000000..2e5fecb3e8 --- /dev/null +++ b/src/deepsparse/v2/text_generation/prep_for_prefill.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any
+
+from deepsparse.v2.operators import Operator
+from deepsparse.v2.utils import PipelineState
+
+
+_LOGGER = logging.getLogger(__name__)
+
+__all__ = ["PrepareforPrefill"]
+
+
+class PrepareforPrefill(Operator):
+    def __init__(self, kv_cache_creator: Operator):
+        """
+        Operator before prefill. Responsible for creating the kv_cache based on engine
+        variables. Currently, this operator expects that the kv_cache_creator is
+        provided during initialization and then uses pipeline_state to run the
+        kv_cache_creator.
+        """
+        # NOTE: Alternatively, we can initialize the kv_cache_creator operator here,
+        # instead of at the pipeline level.
+        self.kv_cache_creator = kv_cache_creator
+
+        _LOGGER.warning(
+            "This operator requires the PipelineState to be set up with the "
+            "cache_shape, output_names, and kv_cache_data_type attributes "
+            "from the NLEngineOperator"
+        )
+
+    def run(
+        self,
+        input_ids: Any,
+        attention_mask: Any,
+        pipeline_state: PipelineState,
+        **kwargs,
+    ):
+        # NOTE: Can potentially just be class attributes instead of relying on
+        # pipeline state.
+        cache_shape = pipeline_state.current_state.get("cache_shape")
+        data_type = pipeline_state.current_state.get("kv_cache_data_type")
+        output_names = pipeline_state.current_state.get("output_names")
+
+        tokens = input_ids[attention_mask.nonzero()].tolist()
+        kv_cache = self.kv_cache_creator.run(
+            cache_shape=cache_shape,
+            kv_cache_data_type=data_type,
+            output_names=output_names,
+        ).get("kv_cache")
+        return {"tokens": tokens, "kv_cache": kv_cache}
diff --git a/src/deepsparse/v2/text_generation/process_inputs.py b/src/deepsparse/v2/text_generation/process_inputs.py
new file mode 100644
index 0000000000..85956416a1
--- /dev/null
+++ b/src/deepsparse/v2/text_generation/process_inputs.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+from typing import Dict, Union
+
+import transformers
+
+from deepsparse.transformers.pipelines.text_generation import (
+    GenerationDefaults,
+    TextGenerationInput,
+)
+from deepsparse.transformers.utils.helpers import (
+    check_and_return_generation_config,
+    override_config,
+    repeat_inputs,
+)
+from deepsparse.v2.operators import Operator
+
+
+__all__ = ["ProcessInputsTextGeneration"]
+
+
+class ProcessInputsTextGeneration(Operator):
+    """
+    Input processing operator. Responsible for tokenizing the input, handling the
+    generation_config (if provided), updating the inference_state for later use,
+    and returning the tokens for prompt inference. 
The expected input is defined by + the input_schema, which for this operator is TextGenerationInput. + """ + + input_schema = TextGenerationInput + + def __init__( + self, + tokenizer: transformers.PreTrainedTokenizerBase, + sequence_length: int, + generation_config: Union[ + str, pathlib.Path, Dict, transformers.GenerationConfig + ] = None, + ): + self.generation_config = generation_config + self.tokenizer = tokenizer + self.sequence_length = sequence_length + + def run(self, inp: TextGenerationInput, **kwargs): + generation_config = check_and_return_generation_config( + self.generation_config, inp.generation_config, GenerationDefaults() + ) + + generation_config = override_config(inp.generation_kwargs, generation_config) + + original_inputs = inp.sequences + if generation_config.num_return_sequences > 1: + if isinstance(inp.sequences, str): + inp.sequences = [inp.sequences] + inp.sequences = repeat_inputs( + inp.sequences, generation_config.num_return_sequences + ) + + if inp.fixed_sequences_length: + # to enforce a fixed sequence length, we need to + # truncate the input to the maximum sequence length + # or/and pad it to the maximum sequence length + truncate, padding = True, "max_length" + else: + # otherwise, we do not need to truncate the input + # and we shall can pad it to the longest sequence + # in the batch (so that the engine can process multiple inputs + # at once) + truncate, padding = False, "longest" + + input_tokens = self.tokenizer( + inp.sequences, + return_tensors="np", + max_length=self.sequence_length, + padding=padding, + truncation=truncate, + ) + + input_ids = input_tokens["input_ids"] + attention_mask = input_tokens["attention_mask"] + + inference_state_update = dict( + prompts=original_inputs, + streaming=inp.streaming, + generation_config=generation_config, + include_prompt_logits=inp.include_prompt_logits, + callback=inp.callback, + stop=inp.stop, + top_p=generation_config.top_p, + top_k=generation_config.top_k, + presence_penalty=inp.presence_penalty, + frequency_penalty=generation_config.repetition_penalty, + ) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, inference_state_update diff --git a/src/deepsparse/v2/text_generation/process_outputs.py b/src/deepsparse/v2/text_generation/process_outputs.py new file mode 100644 index 0000000000..7173b8e256 --- /dev/null +++ b/src/deepsparse/v2/text_generation/process_outputs.py @@ -0,0 +1,97 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
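+
+# Worked example for the num_return_sequences grouping performed in this module
+# (the values are illustrative): with num_return_sequences=2 and generations
+# [g0, g1, g2, g3] produced for two prompts, the outputs are grouped per prompt
+# as [[g0, g1], [g2, g3]].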
+import datetime +from typing import Optional + +import numpy + +from deepsparse.transformers.pipelines.text_generation import ( + FinishReason, + GeneratedText, + TextGenerationOutput, +) +from deepsparse.v2.operators import Operator +from deepsparse.v2.utils import InferenceState + + +class ProcessOutputs(Operator): + output_schema = TextGenerationOutput + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def _create_generated_text_output( + self, + sequence: str, + finish_reason: Optional[FinishReason] = None, + logits: Optional[numpy.array] = None, + ): + if finish_reason: + return GeneratedText( + text=sequence, + score=logits, + finished=True, + finished_reason=finish_reason.value, + ) + return GeneratedText( + text=sequence, + score=logits, + finished=False, + ) + + def run( + self, + generated_tokens: numpy.ndarray, + generated_logits: numpy.ndarray, + finished_reason: list, + inference_state: InferenceState, + **kwargs, + ): + generation_config = inference_state.current_state.get("generation_config") + generated_logits = generated_logits if generation_config.output_scores else None + sequences = self.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + ) + + finished_reason = [f[-1] for f in finished_reason] + + if generated_logits is not None: + generations = list( + map( + self._create_generated_text_output, + sequences, + finished_reason, + generated_logits, + ) + ) + else: + generations = list( + map(self._create_generated_text_output, sequences, finished_reason) + ) + + num_preds = generation_config.num_return_sequences + if num_preds > 1: + grouped_generations = [ + generations[n : n + num_preds] + for n in range(0, len(generations), num_preds) + ] + generations = grouped_generations + + outputs = dict( + created=datetime.datetime.now(), + prompts=inference_state.current_state.get("prompts"), + generations=generations, + ) + + return outputs diff --git a/src/deepsparse/v2/text_generation/token_generator.py b/src/deepsparse/v2/text_generation/token_generator.py new file mode 100644 index 0000000000..9148d71cc8 --- /dev/null +++ b/src/deepsparse/v2/text_generation/token_generator.py @@ -0,0 +1,30 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
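+
+# Illustrative call pattern (mirrors how PrepareGeneration uses this operator;
+# prompt_logits and the values below are placeholders):
+#
+#   creator = TokenGeneratorOperator()
+#   token_generator = creator.run(
+#       logits_shape=prompt_logits[0, -1, :].shape,
+#       deterministic=True,
+#       sampling_temperature=1.0,
+#       tokens=[10, 20, 30],
+#   )["token_generator"]
+#   token_generator.generate(prompt_logits[0, -1, :])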
+from deepsparse.transformers.utils.token_generator import TokenGenerator +from deepsparse.v2.operators import Operator + + +__all__ = ["TokenGeneratorOperator"] + + +class TokenGeneratorOperator(Operator): + def run(self, logits_shape, deterministic, tokens, sampling_temperature, **kwargs): + token_generator = TokenGenerator( + logits_shape=logits_shape, + deterministic=deterministic, + tokens=tokens, + sampling_temperature=sampling_temperature, + **kwargs, + ) + return {"token_generator": token_generator} diff --git a/src/deepsparse/v2/utils/__init__.py b/src/deepsparse/v2/utils/__init__.py new file mode 100644 index 0000000000..75935a9729 --- /dev/null +++ b/src/deepsparse/v2/utils/__init__.py @@ -0,0 +1,21 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .helpers import * +from .state import * +from .types import * + + +from .data import * # isort:skip diff --git a/src/deepsparse/v2/utils/data.py b/src/deepsparse/v2/utils/data.py new file mode 100644 index 0000000000..40402734cf --- /dev/null +++ b/src/deepsparse/v2/utils/data.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, List + +from deepsparse.v2.utils import InferenceState + + +__all__ = ["SubGraph"] + + +@dataclass +class SubGraph: + """ + Helper dataclass to store information about each running sub graph. + """ + + step: int + inf: InferenceState + end: List[str] + output: Any = None + + def parse_output(self, operator_output: Any): + if isinstance(operator_output, tuple): + operator_output, state_update = operator_output[0], operator_output[-1] + self.inf.update_state(state_update) + return operator_output diff --git a/src/deepsparse/v2/utils/helpers.py b/src/deepsparse/v2/utils/helpers.py new file mode 100644 index 0000000000..1f4bedc6c9 --- /dev/null +++ b/src/deepsparse/v2/utils/helpers.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable
+
+
+__all__ = ["run_func"]
+
+
+def run_func(
+    *args,
+    func: Callable,
+    inp: Any = None,
+    **kwargs,
+):
+    """
+    Generic function to run a given Callable. If `inp` is provided, it is passed
+    to `func` either as keyword arguments (when it is a dict) or as the first
+    positional argument; otherwise `func` is called with `*args` and `**kwargs`.
+    """
+    if inp:
+        output = (
+            func(*args, **kwargs, **inp)
+            if isinstance(inp, dict)
+            else func(inp, *args, **kwargs)
+        )
+    else:
+        output = func(*args, **kwargs)
+    return output
diff --git a/src/deepsparse/v2/utils/state.py b/src/deepsparse/v2/utils/state.py
new file mode 100644
index 0000000000..b54b890acf
--- /dev/null
+++ b/src/deepsparse/v2/utils/state.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from abc import ABC
+from typing import Any, Union
+
+
+__all__ = ["State", "PipelineState", "InferenceState"]
+
+
+class State(ABC):
+    """
+    Abstract class to store pipeline-level and inference-level state variables which
+    are generated by some Operator, and required by some other Operator.
+    """
+
+    def __init__(self):
+        self._current_state = None
+
+    @property
+    def current_state(self):
+        return self._current_state
+
+
+class PipelineState(State):
+    """
+    Created during pipeline initialization. Pipeline state values are read-only
+    during inference.
+    """
+
+    def create_state(self, new_state: dict):
+        if self._current_state:
+            raise ValueError("State creation is only allowed during initialization.")
+        self._current_state = new_state
+
+
+class InferenceState(State):
+    """
+    Inference state, created during every inference run.
+    """
+
+    def create_state(self, new_state: dict):
+        if self._current_state:
+            warnings.warn("Current state already exists, overriding.")
+        self._current_state = new_state
+
+    def update_value(self, attribute: str, value: Union[str, int, list]):
+        if attribute not in self._current_state:
+            raise ValueError(f"{attribute} is not a valid state attribute")
+        self._current_state[attribute] = value
+
+    def update_state(self, value: Any):
+        self._current_state.update(value)
diff --git a/src/deepsparse/v2/utils/types.py b/src/deepsparse/v2/utils/types.py
new file mode 100644
index 0000000000..3e4b974453
--- /dev/null
+++ b/src/deepsparse/v2/utils/types.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Types to support deepsparse pipelines +""" + +from typing import Any, Dict, Union + +from pydantic import BaseModel + + +__all__ = ["OperatorSchema"] + + +# Operator inputs and outputs may either be a pydantic base model or a dict of kwargs +OperatorSchema = Union[BaseModel, Dict[str, Any]] diff --git a/tests/deepsparse/transformers/utils/test_helpers.py b/tests/deepsparse/transformers/utils/test_helpers.py index 7fcadcbf9c..95e4ee7fa7 100644 --- a/tests/deepsparse/transformers/utils/test_helpers.py +++ b/tests/deepsparse/transformers/utils/test_helpers.py @@ -16,12 +16,86 @@ import pytest from deepsparse.transformers.utils.helpers import ( + compute_engine_inputs, create_causal_mask, initialize_kv_cache_state, validate_session_ids, ) +@pytest.mark.parametrize( + "onnx_input_names, " + "token_batch, " + "prompt_sequence_length, " + "sequence_length, " + "num_total_processed_tokens, " + "expected_engine_inputs", + [ + ( + ["input_ids", "attention_mask", "positions"], + [1, 2, 3], + 3, + 6, + 2, + [ + numpy.array([[1, 2, 3]]), + numpy.array([[0, 1, 1, 1, 1, 1]]), + numpy.array([[2, 3, 4]]), + ], + ), + ( + ["input_ids", "attention_mask", "positions", "causal_mask"], + [1, 2, 3], + 3, + 6, + 2, + [ + numpy.array([[1, 2, 3]]), + numpy.array([[0, 1, 1, 1, 1, 1]]), + numpy.array([[2, 3, 4]]), + create_causal_mask( + input_ids=numpy.array([[1, 2, 3]]), + attention_mask=numpy.array([[0, 1, 1, 1, 1, 1]]), + ), + ], + ), + ( + ["input_ids", "attention_mask", "positions", "causal_mask"], + [15], + 1, + 5, + 3, + [ + numpy.array([[15]]), + numpy.array([[0, 1, 1, 1, 1]]), + numpy.array([[3]]), + create_causal_mask( + input_ids=numpy.array([[15]]), + attention_mask=numpy.array([[0, 1, 1, 1, 1]]), + ), + ], + ), + ], +) +def test_compute_engine_inputs( + onnx_input_names, + token_batch, + prompt_sequence_length, + sequence_length, + num_total_processed_tokens, + expected_engine_inputs, +): + engine_inputs = compute_engine_inputs( + onnx_input_names=onnx_input_names, + token_batch=token_batch, + prompt_sequence_length=prompt_sequence_length, + sequence_length=sequence_length, + num_total_processed_tokens=num_total_processed_tokens, + ) + for x, y in zip(engine_inputs, expected_engine_inputs): + assert numpy.array_equal(x, y) + + @pytest.mark.parametrize( "input_ids, attention_mask, expected_causal_mask", [ diff --git a/tests/deepsparse/v2/__init__.py b/tests/deepsparse/v2/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/v2/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/deepsparse/v2/integration_tests/__init__.py b/tests/deepsparse/v2/integration_tests/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/v2/integration_tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/deepsparse/v2/integration_tests/configs/codegen.yaml b/tests/deepsparse/v2/integration_tests/configs/codegen.yaml new file mode 100644 index 0000000000..9ec212a6cc --- /dev/null +++ b/tests/deepsparse/v2/integration_tests/configs/codegen.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +model_path: "zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none" +torch_model_name: "salesforce/codegen-350m-mono" +model_name_no_kv_cache: None +prompt: "\ndef Fibonacci(n):\n # Check if input is 0 then it will\n # print incorrect input" +precision: 0.0001 +internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/v2/integration_tests/configs/gpt_neo.yaml b/tests/deepsparse/v2/integration_tests/configs/gpt_neo.yaml new file mode 100644 index 0000000000..71c57e1f97 --- /dev/null +++ b/tests/deepsparse/v2/integration_tests/configs/gpt_neo.yaml @@ -0,0 +1,7 @@ +cadence: "commit" +model_path: "hf:mgoin/TinyStories-1M-ds" +torch_model_name: "roneneldan/TinyStories-1M" +model_name_no_kv_cache: "model-orig.onnx" +prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio" +precision: 0.001 +internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/v2/integration_tests/configs/opt.yaml b/tests/deepsparse/v2/integration_tests/configs/opt.yaml new file mode 100644 index 0000000000..216d4c03ca --- /dev/null +++ b/tests/deepsparse/v2/integration_tests/configs/opt.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +model_path: "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/opt_pretrain/base-none" +torch_model_name: "facebook/opt-1.3b" +model_name_no_kv_cache: None +prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio" +precision: 0.0001 +internal_kv_cache: [True, False] \ No newline at end of file diff --git a/tests/deepsparse/v2/integration_tests/helpers.py b/tests/deepsparse/v2/integration_tests/helpers.py new file mode 100644 index 0000000000..8d7f3d58d2 --- /dev/null +++ b/tests/deepsparse/v2/integration_tests/helpers.py @@ -0,0 +1,137 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Dict, List, Tuple, Union + +import numpy +import yaml +from transformers import AutoModelForCausalLM, AutoTokenizer + +import pytest + + +class TorchGroundTruthSource: + """ + An object that generates ground truth logits and + cache states from a prompt. This object can + generate tokens in an autoregressive manner, and thus + will output: + - prompt logits, + - generated logits, + - prompt cache state, + - generated sequence + """ + + def __init__(self, num_tokens_to_generate: int, model_name: str): + + self.model = AutoModelForCausalLM.from_pretrained(model_name) + self.tokenizer = self._create_tokenizer(model_name) + + self.num_tokens_to_generate = num_tokens_to_generate + + def tokenize(self, prompt: str): + return self.tokenizer(prompt, return_tensors="pt") + + def __call__( + self, prompt: str + ) -> Tuple[numpy.ndarray, numpy.ndarray, List[numpy.ndarray], str]: + # afaik it is not possible to get 'past_key_values' from + # the generate method, so we have to run the model twice + out = self.model.generate( + self.tokenize(prompt).input_ids, + max_new_tokens=self.num_tokens_to_generate, + output_scores=True, + return_dict_in_generate=True, + use_cache=True, + ) + generated_text = self.tokenizer.decode( + out.sequences[0], skip_special_tokens=True + ) + generated_logits = numpy.concatenate( + [[score.numpy() for score in out.scores]] + ).transpose( + 1, 0, 2 + ) # (1, num_tokens_to_generate, vocab_size) + + out = self.model(**self.tokenize(prompt)) + prompt_logits = out.logits.detach().numpy()[ + :, :-1, : + ] # (1, prompt_length, vocab_size) + prompt_cache = [ + entry.detach().numpy() + for key_value_tuple in out.past_key_values + for entry in key_value_tuple + ] # List[(1, num_heads, past_length, head_dim)] + + return generated_logits, prompt_logits, prompt_cache, generated_text + + @staticmethod + def _create_tokenizer(model_name): + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer + + +def parse_params(configs_directory: str) -> List[Dict[str, Any]]: + # parses the config file provided + assert os.path.isdir( + configs_directory + ), f"Config_directory {configs_directory} is not a directory" + + config_dicts = [] + for file in os.listdir(configs_directory): + if file.endswith(".yaml"): + config_path = os.path.join(configs_directory, file) + # reads the yaml file + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + cadence = os.environ.get("CADENCE", "commit") + expected_cadence = config["cadence"] + + if not isinstance(expected_cadence, list): + expected_cadence = [expected_cadence] + if cadence in expected_cadence: + config_dicts.append(config) + else: + logging.info( + f"Skipping testing model: {config['model_path']} " + f"for cadence: {config['cadence']}" + ) + else: + raise FileNotFoundError( + f"Could not find a yaml file in {configs_directory}" + ) + return config_dicts + + +def validate_internal_kv_cache( + internal_kv_cache, available_kv_cache_types: Union[str, 
List[str]]
+) -> bool:
+    if internal_kv_cache and True not in available_kv_cache_types:
+        pytest.skip(
+            "The tests for running the pipeline with "
+            "internal kv cache management are disabled."
+        )
+    if not internal_kv_cache and False not in available_kv_cache_types:
+        pytest.skip(
+            "The tests for running the pipeline with "
+            "external kv cache management are disabled."
+        )
+    return internal_kv_cache
diff --git a/tests/deepsparse/v2/integration_tests/test_llms.py b/tests/deepsparse/v2/integration_tests/test_llms.py
new file mode 100644
index 0000000000..3485658dda
--- /dev/null
+++ b/tests/deepsparse/v2/integration_tests/test_llms.py
@@ -0,0 +1,321 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This test suite consumes config files to test the text generation pipeline
+for various scenarios.
+
+A sample config file is a yaml that requires the following fields:
+    cadence: The cadence of the tests. The available options are:
+        "nightly", "weekly" and "commit". By default, only
+        the tests that have cadence "commit" will be run
+        in GHA. This parameter can be either a string or a
+        list of strings.
+    model_path: The path to the model to be tested
+        (sparsezoo stub/hf model path/local_path)
+    model_name_no_kv_cache: The name of the onnx model without
+        the KV cache support
+    torch_model_name: The name of the torch model
+        (to generate ground truth info)
+    prompt: The prompt to use for testing
+    precision: The precision for the logits/kv_cache entries
+        comparison
+    internal_kv_cache: The type of the internal KV cache
+        management. A list that can contain the following
+        values: [True], [False] or [True, False] (to test both
+        external and internal KV cache management)
+"""
+from typing import List, Tuple
+
+import numpy
+
+import pytest
+from deepsparse.transformers.pipelines.text_generation import TextGenerationOutput
+from deepsparse.v2.pipeline import Pipeline
+from deepsparse.v2.text_generation import (
+    TextGenerationPipeline,
+    TextGenerationPipelineNoCache,
+)
+from tests.deepsparse.transformers.pipelines.integration_tests.helpers import (
+    TorchGroundTruthSource,
+    parse_params,
+    validate_internal_kv_cache,
+)
+
+
+CONFIGS_DIRECTORY = "tests/deepsparse/v2/integration_tests/configs"
+
+
+@pytest.fixture()
+def max_new_tokens() -> int:
+    return 64
+
+
+@pytest.mark.parametrize("params_dict", parse_params(CONFIGS_DIRECTORY))
+@pytest.mark.parametrize(
+    "internal_kv_cache",
+    [True, False],
+)
+class TestsIntegrationLLMsPipelines:
+    """
+    This test suite is meant to test the main scenarios of
+    the text generation pipeline.
+    """
+
+    def get_pipeline(self, kv_cache_support=True, **kwargs) -> Pipeline:
+        """
+        If no kwargs provided, returns the cached "default"
+        pipeline that is used for most of the tests. 
+ Otherwise, returns a pipeline with the given kwargs + (the default pipeline kwargs are updated with the + user-provided kwargs) + + :param kwargs: the optional kwargs to be used to + create the pipeline (if not provided, the cached + "default" pipeline is returned) + :return: the appropriate pipeline + """ + # TODO: This if statement should disappear once + # the TextGenerationPipeline contains the + # non-kv-cache version of the pipeline + text_generation_pipeline_class = ( + TextGenerationPipeline + if kv_cache_support + else TextGenerationPipelineNoCache + ) + if not kwargs: + if self.default_pipeline is None: + self.default_pipeline = text_generation_pipeline_class( + **self.default_pipeline_kwargs + ) + return self.default_pipeline + + # return a pipeline with the updated default kwargs + updated_kwargs = self.default_pipeline_kwargs.copy() + updated_kwargs.update(kwargs) + return text_generation_pipeline_class(**updated_kwargs) + + @pytest.fixture + def setup(self, params_dict, max_new_tokens, internal_kv_cache): + # set the params_dict as the class attributes + for key, value in params_dict.items(): + setattr(self, key, value) + # check whether the specified cache management type + # is supported for testing (skip if not supported) + self.internal_kv_cache: bool = validate_internal_kv_cache( + internal_kv_cache, self.internal_kv_cache + ) + # create torch ground source + torch_source = TorchGroundTruthSource( + num_tokens_to_generate=max_new_tokens + 1, + model_name=self.torch_model_name, + ) + # create torch ground truth + self.torch_ground_truth = torch_source(self.prompt) + + # specify the default pipeline kwargs + self.default_pipeline_kwargs = dict( + model_path=self.model_path, + internal_kv_cache=self.internal_kv_cache, + force_max_tokens=True, + ) + self.default_pipeline = None + self.max_new_tokens = max_new_tokens + + def test_ort_single_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally + + if self.internal_kv_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." + ) + + pipeline = self.get_pipeline( + prompt_sequence_length=1, + engine_kwargs=dict(engine_type="onnxruntime"), + ) + output = pipeline( + prompt=self.prompt, + include_prompt_logits=True, + generation_kwargs=dict( + max_new_tokens=self.max_new_tokens, + output_scores=True, + ), + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + ) + + def test_ort_multi_token_prefill(self, setup): + # Test the pipeline that uses ORT engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally + + if self.internal_kv_cache: + pytest.skip( + "Cannot run ORT pipeline with the internal deepsparse cache enabled." + ) + pipeline = self.get_pipeline( + engine_kwargs=dict(engine_type="onnxruntime"), + ) + output = pipeline( + prompt=self.prompt, + include_prompt_logits=True, + generation_kwargs=dict( + max_new_tokens=self.max_new_tokens, output_scores=True + ), + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + ) + + def test_deepsparse_single_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. 
Prompt preprocessing is performed by single-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed externally or internally + + pipeline = self.get_pipeline( + prompt_sequence_length=1, + ) + + output = pipeline( + prompt=self.prompt, + include_prompt_logits=True, + generation_kwargs=dict( + max_new_tokens=self.max_new_tokens, output_scores=True + ), + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + # disable kv cache validation if using internal kv cache + run_kv_cache_validation=not self.internal_kv_cache, + ) + + def test_deepsparse_multi_token_prefill(self, setup): + # Test the pipeline that uses deepsparse engine. The test covers the + # following scenario: + # 1. Prompt preprocessing is performed by multi-token engine + # 2. The KV Cache is never filled up + # 3. KV Cache managed internally or externally + + pipeline = self.get_pipeline() + output = pipeline( + prompt=self.prompt, + include_prompt_logits=True, + generation_kwargs=dict( + max_new_tokens=self.max_new_tokens, output_scores=True + ), + ) + + self._test_output( + output=output, + torch_ground_truth=self.torch_ground_truth, + # disable kv cache validation if using internal kv cache + run_kv_cache_validation=not self.internal_kv_cache, + ) + + def test_inference_no_kv_cache_deepsparse(self, setup): + self._test_inference_no_kv_cache(engine_type="deepsparse") + + def test_inference_no_kv_cache_ort(self, setup): + self._test_inference_no_kv_cache(engine_type="onnxruntime") + + def _test_inference_no_kv_cache(self, engine_type): + pipeline = self.get_pipeline( + onnx_model_name=self.model_name_no_kv_cache, + kv_cache_support=False, + engine_kwargs=dict(engine_type=engine_type), + ) + + output = pipeline( + prompt=[self.prompt, self.prompt], + include_prompt_logits=True, + generation_kwargs=dict(output_scores=True), + ) + + # logits -> prompt logits + one logit for the new generated token + generated_logits, prompt_logits, *_ = self.torch_ground_truth + logits_gt = numpy.concatenate( + [prompt_logits[0], generated_logits[0, :1, :]], axis=0 + ) + for gen in output.generations: + assert numpy.allclose(gen.score, logits_gt, atol=self.precision) + + def _test_output( + self, + output: TextGenerationOutput, + torch_ground_truth: Tuple[numpy.ndarray, ...], + run_kv_cache_validation: bool = True, + ): + + ( + generated_logits, + prompt_logits, + prompt_kv_cache, + generated_text, + ) = torch_ground_truth + + # concatenate target prompt_logits and generated_logits + target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) + # get the logits of the generated sequence + score = output.generations[0].score + + # we expect the logits to be exactly the same + # as the target logits; the generated sequence should + # also be the same as the target sequence + assert numpy.allclose(score, target_logits[0], atol=self.precision) + assert self.prompt + output.generations[0].text == generated_text + + if hasattr(output, "kv_cache_state") and run_kv_cache_validation: + # (if applicable) the kv cache should be the same as the + # target kv cache + expected_cache = list(output.kv_cache_state[0].values()) + total_num_processed_tokens = output.total_num_processed_tokens[0] + self._test_kv_cache_state( + expected_cache=expected_cache, + target_cache=prompt_kv_cache, + total_num_processed_tokens=total_num_processed_tokens, + ) + + def _test_kv_cache_state( + self, + expected_cache: List[numpy.ndarray], + target_cache: List[numpy.ndarray], + total_num_processed_tokens: int, 
+ ): + for x, y in zip(expected_cache, target_cache): + start_index = total_num_processed_tokens + end_index = total_num_processed_tokens - y.shape[2] + # x is (in general) composed of three arrays: + # - padding cache entries (from 0 to -start_index) + # - prompt cache entries (from -start_index to -end_index) + # - generated cache entries (from -end_index to -1) + # as target_cache only pertains to prompt cache entries, we need to + # compare only the prompt cache entries in x with y + assert numpy.allclose( + x[:, :, -start_index:-end_index, :], y, atol=self.precision + ) diff --git a/tests/deepsparse/v2/schedulers/__init__.py b/tests/deepsparse/v2/schedulers/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/v2/schedulers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/deepsparse/v2/schedulers/test_continuous_batching_scheduler.py b/tests/deepsparse/v2/schedulers/test_continuous_batching_scheduler.py new file mode 100644 index 0000000000..7ed49de004 --- /dev/null +++ b/tests/deepsparse/v2/schedulers/test_continuous_batching_scheduler.py @@ -0,0 +1,48 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
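+
+"""
+Unit test for the ContinuousBatchingScheduler.
+
+This only checks that a single submitted request resolves to an engine output;
+driving true multi-batch execution (enough concurrent requests to form a batch)
+is intentionally out of scope here. Rough usage, mirroring the test body below:
+
+    scheduler = ContinuousBatchingScheduler()
+    scheduler.add_engine_operator(engine_operator, [1])
+    future = scheduler.submit(engine_input, operator=engine_operator)
+    output = future.result()
+"""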
+ +from concurrent.futures import Future + +import numpy + +from deepsparse.v2.operators import EngineOperator +from deepsparse.v2.schedulers import ContinuousBatchingScheduler + + +def test_continuous_batching_scheduler(): + # simple test that ContinuousBatchingScheduler can be instantiated and return + # a result from a request; making enough concurrent requests to guarantee batched + # execution is out of scope for this test + scheduler = ContinuousBatchingScheduler() + + # mobilenet model with batch_size=1 + engine_operator = EngineOperator( + "zoo:mobilenet_v2-1.0-imagenet-base", + batch_size=1, + ) + + scheduler.add_engine_operator(engine_operator, [1]) + + # submit job to scheduler and expect future to be returned + engine_input = engine_operator.input_schema( + engine_inputs=[numpy.random.randn(1, 3, 224, 224).astype(numpy.float32)] + ) + future = scheduler.submit(engine_input, operator=engine_operator) + assert isinstance(future, Future) + assert not future.done() # assume this runs before engine has a chance to complete + + # assert that output resolves and contains a numpy array + engine_output = future.result() + assert isinstance(engine_output, engine_operator.output_schema) + assert isinstance(engine_output.engine_outputs[0], numpy.ndarray) diff --git a/tests/deepsparse/v2/schedulers/utils/__init__.py b/tests/deepsparse/v2/schedulers/utils/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/v2/schedulers/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/deepsparse/v2/schedulers/utils/test_continuous_batching_executor.py b/tests/deepsparse/v2/schedulers/utils/test_continuous_batching_executor.py new file mode 100644 index 0000000000..1d5ed9d92b --- /dev/null +++ b/tests/deepsparse/v2/schedulers/utils/test_continuous_batching_executor.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
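+
+"""
+Unit test for ContinuousBatchingExecutorThread working off ContinuousBatchingQueues:
+two single-item requests queued against a batch_size=2 engine should be executed
+together, and each Future should receive only its own batch item back. The 1 second
+sleep below is a loose upper bound that assumes the mobilenet forward pass finishes
+well within that time.
+"""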
+ +import time +from concurrent.futures import Future + +import numpy + +from deepsparse.v2.operators import EngineOperator +from deepsparse.v2.schedulers.utils import ( + ContinuousBatchingExecutorThread, + ContinuousBatchingQueues, +) + + +def test_continuous_batching_executor_thread(): + # mobilenet model with batch_size=2 + engine_operator = EngineOperator("zoo:mobilenet_v2-1.0-imagenet-base", batch_size=2) + + # create queues object and add operator + queues = ContinuousBatchingQueues() + queues.add_queue(engine_operator, batch_sizes=[2]) + + # create engine map + operators_to_engines = {engine_operator: {2: engine_operator.engine}} + + worker_thread = ContinuousBatchingExecutorThread(queues, operators_to_engines) + + # thread not started yet + assert not worker_thread.is_alive() + + # start and assert thread is alive + worker_thread.start() + assert worker_thread.is_alive() + + # create first input and add it to queue + input_1 = engine_operator.input_schema( + engine_inputs=[numpy.random.randn(1, 3, 224, 224).astype(numpy.float32)] + ) + future_1 = Future() + queues.add_queue_item(engine_operator, input_1, future=future_1) + + # assert that future is not yet resolved + assert not future_1.done() + + # create second input and add it to queue + input_2 = engine_operator.input_schema( + engine_inputs=[numpy.random.randn(1, 3, 224, 224).astype(numpy.float32)] + ) + future_2 = Future() + queues.add_queue_item(engine_operator, input_2, future=future_2) + + # wait 1 second to give engine time to complete + time.sleep(1) + + assert future_1.done() + assert future_2.done() + + result_1 = future_1.result() + result_2 = future_2.result() + + assert isinstance(result_1, engine_operator.output_schema) + assert isinstance(result_2, engine_operator.output_schema) + + def assert_batch_size_one(arrays): + for array in arrays: + assert array.shape[0] == 1 + + # make sure only a single batch item was returned to each future + # TODO: test that the correct bs1 item is returned (can test against bs1 engine) + assert_batch_size_one(result_1.engine_outputs) + assert_batch_size_one(result_2.engine_outputs) diff --git a/tests/deepsparse/v2/schedulers/utils/test_continuous_batching_queues.py b/tests/deepsparse/v2/schedulers/utils/test_continuous_batching_queues.py new file mode 100644 index 0000000000..1713d54f82 --- /dev/null +++ b/tests/deepsparse/v2/schedulers/utils/test_continuous_batching_queues.py @@ -0,0 +1,177 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
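+
+"""
+Unit tests for ContinuousBatchingQueue and ContinuousBatchingQueues.
+A minimal sketch of the queue API exercised below (names taken from the tests):
+
+    queue = ContinuousBatchingQueue(batch_sizes=[2, 4, 8])
+    queue.put(item)
+    if queue.has_batch():
+        batch = queue.pop_batch()  # pops max_queued_batch_size() items
+"""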
+ +import time +from threading import Thread + +import pytest +from deepsparse.v2.schedulers.utils import ( + ContinuousBatchingQueue, + ContinuousBatchingQueues, + QueueEntry, +) + + +@pytest.mark.parametrize( + "batch_sizes,num_entries,expected_batch_size", + [ + ([1, 4, 8], 20, 8), + ([1, 4, 8], 6, 4), + ([1, 4, 8], 4, 4), + ([1, 4, 8], 3, 1), + ([4], 5, 4), + ], +) +def test_queue_single_pop(batch_sizes, num_entries, expected_batch_size): + queue = ContinuousBatchingQueue(batch_sizes=batch_sizes) + assert not queue.has_batch() + for i in range(num_entries): + queue.put(i) + + assert queue.has_batch() + assert queue.max_queued_batch_size() == expected_batch_size + + batch = queue.pop_batch() + assert len(batch) == expected_batch_size + assert batch == list(range(expected_batch_size)) + + +def test_queue_multi_pop(): + queue = ContinuousBatchingQueue(batch_sizes=[2, 4, 8]) + + for i in range(23): + if i < 2: + assert not queue.has_batch() + else: + assert queue.has_batch() + queue.put(i) + + def pop_and_assert_queue_size_and_pop(expected_qsize, expected_batch_size): + assert queue.qsize() == expected_qsize + assert queue.has_batch() + assert queue.max_queued_batch_size() == expected_batch_size + assert len(queue.pop_batch()) == expected_batch_size + + # pop items from queue, checking remaining qsize and that the correct batch size is popped + pop_and_assert_queue_size_and_pop(23, 8) + pop_and_assert_queue_size_and_pop(15, 8) + pop_and_assert_queue_size_and_pop(7, 4) + pop_and_assert_queue_size_and_pop(3, 2) + + assert not queue.has_batch() + queue.put(23) + pop_and_assert_queue_size_and_pop(2, 2) + + assert queue.empty() + + +def test_queue_invalid_pop(): + queue = ContinuousBatchingQueue(batch_sizes=[4, 8]) + for i in range(3): + queue.put(i) + + with pytest.raises(RuntimeError): + # queue size 3, min batch size 4 + queue.pop_batch() + + +def test_queues_pop_batch_max_valid_batch(): + queues = ContinuousBatchingQueues() + + queues.add_queue("key_1", [2, 4]) + queues.add_queue("key_2", [3]) + + assert not queues.has_next_batch() + + queues.add_queue_item("key_1", 1) + queues.add_queue_item("key_1", 2) + assert queues.has_next_batch() + + queues.add_queue_item("key_2", 1) + queues.add_queue_item("key_2", 2) + queues.add_queue_item("key_2", 3) + # NOTE - if this block takes more than 100ms, the test may fail + # as the timeout may lead key_1 to be popped first + + # key_2 should be popped first because it has a larger loaded batch size + first_popped_key, first_popped_batch = queues.pop_batch() + assert first_popped_key == "key_2" + assert len(first_popped_batch) == 3 + assert all(isinstance(item, QueueEntry) for item in first_popped_batch) + + assert queues.has_next_batch() + + second_popped_key, second_popped_batch = queues.pop_batch() + assert second_popped_key == "key_1" + assert len(second_popped_batch) == 2 + assert all(isinstance(item, QueueEntry) for item in second_popped_batch) + + +def test_queues_pop_batch_time_elapsed_priority(): + queues = ContinuousBatchingQueues() + + queues.add_queue("key_1", [2, 4]) + queues.add_queue("key_2", [3]) + + assert not queues.has_next_batch() + + queues.add_queue_item("key_1", 1) + queues.add_queue_item("key_1", 2) + assert queues.has_next_batch() + + # sleep 150ms (time threshold is 100ms) + time.sleep(0.15) + + queues.add_queue_item("key_2", 1) + queues.add_queue_item("key_2", 2) + queues.add_queue_item("key_2", 3) + + # key_1 should be popped first because its first item has been waiting longer + # than the time threshold and key_2 was just added + + 
popped_key, popped_batch = queues.pop_batch() + assert popped_key == "key_1" + assert len(popped_batch) == 2 + + +def test_queues_pop_batch_blocking(): + queues = ContinuousBatchingQueues() + queues.add_queue("key_1", [2]) + + def test_fn(): + # pop batch and block until true + key, batch = queues.pop_batch(block=True) + # compare to expected results + assert key == "key_1" + assert batch == [1, 2] + + # start a thread to pop batch + # it should hang indefinitely because block=True and there are no items yet in queue + thread = Thread(target=queues.pop_batch) + thread.start() + + # confirm thread is still running + assert thread.is_alive() + time.sleep(0.15) + # sleep and confirm thread is still hanging + assert thread.is_alive() + + # confirm thread still runs after a single insertion (min batch size is 2) + queues.add_queue_item("key_1", 1) + assert thread.is_alive() + + # add a second item and assert thread finishes + queues.add_queue_item("key_1", 2) + time.sleep(0.1) + assert not thread.is_alive() diff --git a/tests/deepsparse/v2/test_basic_pipeline.py b/tests/deepsparse/v2/test_basic_pipeline.py new file mode 100644 index 0000000000..bedddd537a --- /dev/null +++ b/tests/deepsparse/v2/test_basic_pipeline.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple example and test of a dummy pipeline +""" + +from typing import Dict + +from pydantic import BaseModel + +from deepsparse.v2 import Pipeline +from deepsparse.v2.operators import Operator +from deepsparse.v2.routers import LinearRouter +from deepsparse.v2.schedulers import OperatorScheduler + + +class IntSchema(BaseModel): + value: int + + +class AddOneOperator(Operator): + input_schema = IntSchema + output_schema = IntSchema + + def run(self, inp: IntSchema, **kwargs) -> Dict: + return {"value": inp.value + 1} + + +class AddTwoOperator(Operator): + input_schema = IntSchema + output_schema = IntSchema + + def run(self, inp: IntSchema, **kwargs) -> Dict: + return {"value": inp.value + 2} + + +AddThreePipeline = Pipeline( + ops=[AddOneOperator(), AddTwoOperator()], + router=LinearRouter(end_route=2), + schedulers=[OperatorScheduler()], +) + + +def test_run_simple_pipeline(): + pipeline_input = IntSchema(value=5) + pipeline_output = AddThreePipeline(pipeline_input) + + assert pipeline_output.value == 8 diff --git a/tests/deepsparse/v2/test_image_classification.py b/tests/deepsparse/v2/test_image_classification.py new file mode 100644 index 0000000000..03e2807454 --- /dev/null +++ b/tests/deepsparse/v2/test_image_classification.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy + +import pytest +from deepsparse.v2.image_classification import ImageClassificationPipeline +from deepsparse.v2.image_classification.preprocess_operator import ( + ImageClassificationInput, +) +from tests.deepsparse.pipelines.data_helpers import computer_vision + + +@pytest.fixture +def get_images(): + batch_size = 2 + images = computer_vision(batch_size=batch_size) + return images.get("images") + + +def test_image_classification(get_images): + model_path = ( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned95-none" + ) + pipeline = ImageClassificationPipeline(model_path=model_path) + output = pipeline(ImageClassificationInput(images=get_images)) + assert output.labels == [[207], [670]] + assert numpy.allclose(output.scores, [[21.85], [17.33]], atol=0.01) diff --git a/tests/deepsparse/v2/unit/text_generation/conftest.py b/tests/deepsparse/v2/unit/text_generation/conftest.py new file mode 100644 index 0000000000..3840a9bb0a --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/conftest.py @@ -0,0 +1,172 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
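+
+"""
+Shared fixtures for the text generation unit tests: tokenizer and ONNX model path
+for hf:mgoin/TinyStories-1M-deepsparse, a single-token NLEngineOperator without
+internal KV cache, pipeline/inference state, and mocked prompts, tokens, logits,
+and DecoderKVCache objects.
+"""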
+ +import copy + +import numpy +from transformers import AutoTokenizer + +import pytest +from deepsparse.transformers.helpers import get_deployment_path +from deepsparse.transformers.pipelines.text_generation import ( + GenerationDefaults, + TextGenerationInput, +) +from deepsparse.transformers.utils import DecoderKVCache +from deepsparse.transformers.utils.helpers import initialize_kv_cache_state +from deepsparse.v2 import InferenceState, PipelineState +from deepsparse.v2.text_generation import NLEngineOperator, TokenGeneratorOperator + + +@pytest.fixture(scope="module") +def text_generation_attributes(): + sequence_length = 5 + prompt_sequence_length = 1 + return sequence_length, prompt_sequence_length + + +@pytest.fixture(scope="module") +def model_attributes(text_generation_attributes): + model_path = "hf:mgoin/TinyStories-1M-deepsparse" + sequence_length, _ = text_generation_attributes + deployment_path, model_path = get_deployment_path(model_path) + + tokenizer = AutoTokenizer.from_pretrained( + deployment_path, + trust_remote_code=False, + model_max_length=sequence_length, + ) + + tokenizer.padding_side = "left" + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer, model_path + + +@pytest.fixture(scope="module") +def single_token_engine_no_internal_cache(text_generation_attributes, model_attributes): + seq_length, _ = text_generation_attributes + _, model_path = model_attributes + + nl_engine_operator = NLEngineOperator( + sequence_length=seq_length, input_ids_length=1, model_path=model_path + ) + return nl_engine_operator + + +@pytest.fixture(scope="module") +def pipeline_state(single_token_engine_no_internal_cache): + pipeline_state = PipelineState() + pipeline_state_vals = {} + pipeline_state_vals[ + "onnx_input_names_no_cache" + ] = single_token_engine_no_internal_cache.onnx_input_names_no_cache + pipeline_state_vals[ + "cache_shape" + ] = single_token_engine_no_internal_cache.cache_shape + pipeline_state_vals[ + "output_names" + ] = single_token_engine_no_internal_cache.output_names + pipeline_state_vals[ + "kv_cache_data_type" + ] = single_token_engine_no_internal_cache.kv_cache_data_type + pipeline_state.create_state(pipeline_state_vals) + return pipeline_state + + +@pytest.fixture(scope="module") +def large_prompt(): + prompt = "Hello, how are you doing today?" 
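+    # non-default generation config; test_process_inputs relies on these
+    # user-provided values (max_length, top_k) overriding GenerationDefaults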
+ generation_config = {"top_p": 0, "top_k": 0, "max_length": 10} + return TextGenerationInput(prompt=prompt, generation_config=generation_config) + + +@pytest.fixture(scope="module") +def small_prompt(): + prompt = "Hello" + return TextGenerationInput(prompt=prompt) + + +@pytest.fixture(scope="module") +def mock_kv_cache(): + kv_cache = DecoderKVCache() + kv_cache.setup( + state={"dummy_cache_name": numpy.array([[[[0], [0], [1], [2], [3]]]])}, + ) + return kv_cache + + +@pytest.fixture(scope="module") +def mock_kv_cache_three_tokens_processed(): + kv_cache = DecoderKVCache() + kv_cache.setup( + state={"dummy_cache_name": numpy.array([[[[0], [0], [1], [2], [3]]]])}, + num_processed_tokens=3, + ) + return kv_cache + + +@pytest.fixture(scope="module") +def mock_kv_cache_single_token_engine(pipeline_state, text_generation_attributes): + seq_len, prompt_seq_len = text_generation_attributes + kv_cache = DecoderKVCache() + kv_cache_state = initialize_kv_cache_state( + cache_shape=pipeline_state.current_state.get("cache_shape"), + kv_cache_data_type=pipeline_state.current_state.get("kv_cache_data_type"), + output_names=pipeline_state.current_state.get("output_names"), + length=seq_len - prompt_seq_len, + empty=False, + ) + kv_cache.setup(state=kv_cache_state) + return kv_cache + + +@pytest.fixture(scope="module") +def mock_tokens(): + return [15496] + + +@pytest.fixture(scope="module") +def mock_tokens_multiple(): + return [15496, 15496, 15496] + + +@pytest.fixture(scope="module") +def mock_inference_state(): + generation_config = GenerationDefaults() + inference_state = InferenceState() + inference_state.create_state({}) + inference_state.update_state({"generation_config": generation_config}) + return inference_state + + +@pytest.fixture(scope="module") +def mock_token_generator(model_attributes, mock_tokens_multiple): + tokenizer, _ = model_attributes + token_generator_creator = TokenGeneratorOperator() + prompt_logits = numpy.random.rand(1, len(mock_tokens_multiple), len(tokenizer)) + token_generator_creator_output = token_generator_creator.run( + logits_shape=prompt_logits[0, -1, :].shape, + deterministic=True, + sampling_temperature=1.0, + tokens=copy.copy(mock_tokens_multiple), + ) + return token_generator_creator_output.get("token_generator") + + +@pytest.fixture(scope="module") +def mock_logits(model_attributes): + tokenizer, _ = model_attributes + return numpy.random.rand(1, 1, len(tokenizer)) diff --git a/tests/deepsparse/v2/unit/text_generation/test_kv_cache.py b/tests/deepsparse/v2/unit/text_generation/test_kv_cache.py new file mode 100644 index 0000000000..0c6e42503a --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/test_kv_cache.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
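+
+"""
+Unit test for the KVCacheCreator operator, which builds the initial DecoderKVCache
+(zero processed tokens) from the cache shape, data type, and output names stored
+in the pipeline state.
+"""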
+ +from deepsparse.v2.text_generation import KVCacheCreator, KVCacheCreatorInput + + +def test_kv_cache_creation( + text_generation_attributes, model_attributes, pipeline_state +): + """ + Check if the KVCacheCreator successfully creates a kv_cache object, given the + single_token_engine attributes stored in the pipeline_state. + """ + seq_length, prompt_seq_len = text_generation_attributes + tokenizer, _ = model_attributes + kv_cache_creator = KVCacheCreator( + tokenizer=tokenizer, + prompt_sequence_length=prompt_seq_len, + sequence_length=seq_length, + internal_kv_cache=False, + ) + + assert kv_cache_creator.input_schema == KVCacheCreatorInput + kv_cache = kv_cache_creator.run( + cache_shape=pipeline_state.current_state.get("cache_shape"), + kv_cache_data_type=pipeline_state.current_state.get("kv_cache_data_type"), + output_names=pipeline_state.current_state.get("output_names"), + ) + assert kv_cache.get("kv_cache") + assert kv_cache.get("kv_cache").total_num_processed_tokens == 0 diff --git a/tests/deepsparse/v2/unit/text_generation/test_misc.py b/tests/deepsparse/v2/unit/text_generation/test_misc.py new file mode 100644 index 0000000000..f215e2aedb --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/test_misc.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from deepsparse.v2.text_generation import CompilePromptLogits +from deepsparse.v2.text_generation.nl_engine_operator import NLEngineOutputs + + +def test_compile_logits(mock_logits, mock_inference_state, mock_tokens, mock_kv_cache): + mock_inference_state.update_state({"prompt_logits": [mock_logits]}) + compile_prompt_logits = CompilePromptLogits() + # Can operate as long as we're not in generation but in prompt_inference. This + # can_operate() will check for the `in_generation` flag in the input. + inp = NLEngineOutputs( + engine_outputs=mock_logits, + tokens=mock_tokens, + kv_cache=mock_kv_cache, + in_generation=None, + ) + assert compile_prompt_logits.can_operate(inp=inp) + output, state = compile_prompt_logits.run( + inp=inp, inference_state=mock_inference_state + ) + # The CompilePromptLogits is responsible for updating a list of prompt logits + # calculated at each step during prompt inference. After one step of running this + # operator, the total number of prompt_logits in the inference state should be + # the current length of prompt logits + 1 + assert len(state.get("prompt_logits")) == len([mock_logits]) + 1 diff --git a/tests/deepsparse/v2/unit/text_generation/test_pipeline_no_kv_cache.py b/tests/deepsparse/v2/unit/text_generation/test_pipeline_no_kv_cache.py new file mode 100644 index 0000000000..a6fbfc4d11 --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/test_pipeline_no_kv_cache.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +from deepsparse.v2.text_generation import TextGenerationPipelineNoCache + + +@pytest.mark.parametrize( + "onnx_model_name, raise_error", + [("model.onnx", True), (None, True), ("model-orig.onnx", False)], +) +def test_verify_no_kv_cache_present(model_attributes, onnx_model_name, raise_error): + _, model_path = model_attributes + # model_path points to .../directory/model.onnx + # we need to go up one level to .../directory + model_path = os.path.dirname(model_path) + + if raise_error: + with pytest.raises(ValueError): + if onnx_model_name is None: + TextGenerationPipelineNoCache(model_path=model_path) + else: + TextGenerationPipelineNoCache( + model_path=model_path, onnx_model_name=onnx_model_name + ) + return + else: + TextGenerationPipelineNoCache( + model_path=model_path, onnx_model_name=onnx_model_name + ) diff --git a/tests/deepsparse/v2/unit/text_generation/test_process_inputs.py b/tests/deepsparse/v2/unit/text_generation/test_process_inputs.py new file mode 100644 index 0000000000..02f4540c44 --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/test_process_inputs.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from deepsparse.transformers.pipelines.text_generation import GenerationDefaults +from deepsparse.v2.text_generation import ProcessInputsTextGeneration + + +def test_process_inputs( + text_generation_attributes, model_attributes, small_prompt, large_prompt +): + """ + Check if the ProcessInputsTextGeneration Operator successfully processes the + inputs and generation config. 
+ """ + sequence_length, _ = text_generation_attributes + tokenizer, _ = model_attributes + process_inputs = ProcessInputsTextGeneration( + sequence_length=sequence_length, tokenizer=tokenizer + ) + + outputs, state_update = process_inputs.run(small_prompt) + assert len(outputs.get("input_ids")) == 1 + assert len(outputs.get("attention_mask")) == 1 + assert isinstance(state_update.get("generation_config"), GenerationDefaults) + assert state_update.get("prompts") == small_prompt.sequences + + outputs, state_update = process_inputs.run(large_prompt) + + assert not isinstance(state_update.get("generation_config"), GenerationDefaults) + assert state_update.get( + "generation_config" + ).max_length == large_prompt.generation_config.get("max_length") + assert outputs.get("input_ids") is not None + assert state_update.get("top_k") == large_prompt.generation_config.get("top_k") diff --git a/tests/deepsparse/v2/unit/text_generation/test_single_token_engine.py b/tests/deepsparse/v2/unit/text_generation/test_single_token_engine.py new file mode 100644 index 0000000000..19bb4d1c4a --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/test_single_token_engine.py @@ -0,0 +1,98 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy + +from deepsparse.v2.text_generation import ( + AutoRegressiveOperatorPreprocess, + NLEngineInputs, +) + + +def test_autoreg_preprocess_can_run( + text_generation_attributes, pipeline_state, mock_tokens, mock_kv_cache +): + """ + Check if the single-token engine preprocess operator can run based on the provided + tokens and prompt_sequence_length. + """ + + seq_len, _ = text_generation_attributes + autoreg_prep = AutoRegressiveOperatorPreprocess( + sequence_length=seq_len, prompt_sequence_length=len(mock_tokens) + 1 + ) + inputs = {"tokens": mock_tokens, "kv_cache": mock_kv_cache} + + # The prompt_sequence_length is greater than the number of tokens that are to be + # operated on. Therefore, use the single_token_engine and can_operate() should be + # True. + assert autoreg_prep.can_operate(inputs) + outputs = autoreg_prep.run( + tokens=mock_tokens, kv_cache=mock_kv_cache, pipeline_state=pipeline_state + ) + # Assert 4 engine inputs: tokens, attention mask, positions, causal mask + assert len(outputs.get("engine_inputs")) == 4 + tokens, attention_mask, positions, causal_mask = outputs.get("engine_inputs") + + assert tokens.shape[-1] == 1 + assert attention_mask.shape[-1] == seq_len + assert positions[0] == mock_kv_cache.total_num_processed_tokens + assert outputs.get("in_generation") is None + + +def test_autoreg_preprocess_cant_run( + text_generation_attributes, mock_kv_cache, mock_tokens_multiple +): + """ + Check that the single-token engine preprocess operator does not run when the + prompt_sequence_length is equal to the number of provided tokens. 
+ """ + + seq_len, _ = text_generation_attributes + autoreg_prep = AutoRegressiveOperatorPreprocess( + sequence_length=seq_len, prompt_sequence_length=len(mock_tokens_multiple) + ) + inputs = {"tokens": mock_tokens_multiple, "kv_cache": mock_kv_cache} + # can_operate() should be False as the prompt_sequence_length is equal to the + # number of tokens we want to operate on. Therefore, the multi-token engine + # should run instead. + assert not autoreg_prep.can_operate(inputs) + + +def test_nl_single_token_engine_no_internal(single_token_engine_no_internal_cache): + assert single_token_engine_no_internal_cache.input_ids_length == 1 + + +def test_run_single_token_engine_once( + single_token_engine_no_internal_cache, + mock_kv_cache_single_token_engine, +): + """ + This operator runs through the single-token NLEngine once, given engine_inputs and + kv_cache. + """ + + mock_engine_inputs = [ + numpy.array([[15496]]), + numpy.array([[0, 0, 0, 0, 1]]), + numpy.array([[0]]), + numpy.array([[[[0, 0, 0, 0, 1]]]]), + ] + inputs = NLEngineInputs( + engine_inputs=mock_engine_inputs, + kv_cache=mock_kv_cache_single_token_engine, + tokens=mock_engine_inputs[0].tolist(), + ) + output = single_token_engine_no_internal_cache.run(inputs) + assert output.get("engine_outputs") is not None diff --git a/tests/deepsparse/v2/unit/text_generation/test_token_generation.py b/tests/deepsparse/v2/unit/text_generation/test_token_generation.py new file mode 100644 index 0000000000..219b1048fd --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/test_token_generation.py @@ -0,0 +1,102 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy + +from deepsparse.v2.text_generation import ( + GenerateNewTokenOperator, + PrepareGeneration, + TokenGeneratorOperator, +) +from deepsparse.v2.text_generation.nl_engine_operator import NLEngineOutputs + + +def test_prep_for_generation( + text_generation_attributes, + model_attributes, + mock_tokens_multiple, + mock_kv_cache_three_tokens_processed, + mock_inference_state, +): + """ + This test will assess the PrepareGeneration, which runs after prompt_inference + and before generation. + """ + seq_len, prompt_seq_len = text_generation_attributes + tokenizer, _ = model_attributes + prep_for_generation = PrepareGeneration( + prompt_sequence_length=prompt_seq_len, + token_generator=TokenGeneratorOperator(), + sequence_length=seq_len, + ) + inputs = { + "tokens": mock_tokens_multiple, + "kv_cache": mock_kv_cache_three_tokens_processed, + } + # can_operate() if the total number of prompt tokens is equal to the + # number of processed tokens stored in the kv_cache, indicating prompt inference is + # complete and generation can begin. 
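+    # (here: mock_tokens_multiple holds 3 prompt tokens and the mocked kv cache
+    # was set up with num_processed_tokens=3)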
+ assert prep_for_generation.can_operate(inputs) + + prompt_logits = [numpy.random.rand(1, len(mock_tokens_multiple), len(tokenizer))] + mock_inference_state.update_state({"prompt_logits": prompt_logits}) + outputs, state = prep_for_generation.run( + tokens=mock_tokens_multiple, + kv_cache=mock_kv_cache_three_tokens_processed, + inference_state=mock_inference_state, + ) + assert len(outputs.get("tokens")) == len(mock_tokens_multiple) + 1 + assert outputs.get("in_generation") + assert numpy.array_equal( + state.get("generated_logits")[0], + numpy.expand_dims(prompt_logits[0][:, -1, :], 0), + ) + + +def test_generate_new_token( + model_attributes, + mock_token_generator, + mock_kv_cache, + mock_inference_state, + mock_logits, + mock_tokens, +): + """ + This test is responsible for testing the GenerateNewTokenOperator, which generates + one new token, given a token_generator (stored in the inference_state) and logits + from the engine. + """ + tokenizer, _ = model_attributes + generate_new_token = GenerateNewTokenOperator( + force_max_tokens=False, tokenizer=tokenizer + ) + mock_inference_state.update_state( + { + "token_generator": mock_token_generator, + "generated_tokens": [mock_token_generator.tokens], + } + ) + inp = NLEngineOutputs( + engine_outputs=mock_logits, + tokens=mock_tokens, + kv_cache=mock_kv_cache, + in_generation=True, + ) + outputs, state = generate_new_token.run( + logits=inp.engine_outputs, + kv_cache=inp.kv_cache, + inference_state=mock_inference_state, + ) + # The new_token generated/returned by this operator should match the last token in + # token_generator + assert outputs.get("new_token") == state.get("token_generator").tokens[-1] diff --git a/tests/deepsparse/v2/unit/text_generation/text_multi_token_engine.py b/tests/deepsparse/v2/unit/text_generation/text_multi_token_engine.py new file mode 100644 index 0000000000..d2c822af4c --- /dev/null +++ b/tests/deepsparse/v2/unit/text_generation/text_multi_token_engine.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from deepsparse.v2.text_generation import MultiEnginePrefill + + +def test_multi_engine_preprocess( + text_generation_attributes, pipeline_state, mock_kv_cache, mock_tokens_multiple +): + """ + Check if the multi-token engine preprocess operator can run based on the provided + tokens and prompt_sequence_length. + """ + + seq_len, _ = text_generation_attributes + multi_prep = MultiEnginePrefill( + sequence_length=seq_len, prompt_sequence_length=len(mock_tokens_multiple) + ) + inputs = {"tokens": mock_tokens_multiple, "kv_cache": mock_kv_cache} + # The number of tokens is equal to the prompt_sequence_length. + # Therefore, the multi_token_engine can run and can_operate() should be True. 
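+    # (here: prompt_sequence_length is set to len(mock_tokens_multiple) == 3)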
+ assert multi_prep.can_operate(inputs) + outputs = multi_prep.run( + tokens=mock_tokens_multiple, + kv_cache=mock_kv_cache, + pipeline_state=pipeline_state, + ) + # Expect 4 engine inputs: tokens, attention mask, positions, causal mask + assert len(outputs.get("engine_inputs")) == 4 + tokens, attention_mask, positions, causal_mask = outputs.get("engine_inputs") + # Assert proper shapes for all engine_inputs + assert tokens.shape[-1] == len(mock_tokens_multiple) + assert attention_mask.shape[-1] == seq_len + assert positions.shape[-1] == len(mock_tokens_multiple) + + +def test_multi_engine_preprocess_cant_operate( + text_generation_attributes, mock_kv_cache, mock_tokens +): + """ + Check that the multi-token engine preprocess operator does not run when the + prompt_sequence_length is greater than the number of provided tokens. + """ + seq_len, _ = text_generation_attributes + multi_prep = MultiEnginePrefill( + sequence_length=seq_len, prompt_sequence_length=len(mock_tokens) + 1 + ) + inputs = {"tokens": mock_tokens, "kv_cache": mock_kv_cache} + # The prompt_sequence_length is one greater than the total number of tokens we're + # processing. Therefore, this operator should not run and can_operate() should be + # False. + assert not multi_prep.can_operate(inputs)