From 1c4b6701c8f28a553a0fc0ab7dfaec715b61f65d Mon Sep 17 00:00:00 2001
From: Phat <109610375+phatvo9@users.noreply.github.com>
Date: Fri, 15 Sep 2023 11:14:02 +0700
Subject: [PATCH] [EAGLE-3454] [EAGLE-3447] - Organize model types and Fix text output dim (#171)

* init

* fix inp fields map

* addressed comment

* fix text output error

* addressed comment + update requirements

* update requirements

* add PyYaml in setup
---
 .../models/model_serving/cli/deploy_cli.py      |   5 +-
 .../models/model_serving/cli/repository.py      |   8 +-
 clarifai/models/model_serving/constants.py      |   6 +-
 .../model_serving/docs/custom_config.md         |  11 +-
 .../model_serving/model_config/__init__.py      |   2 +
 .../model_serving/model_config/config.py        | 298 ++++++++++++++++++
 .../model_serving/model_config/deploy.py        |  75 -----
 .../model_types_config/text-classifier.yaml     |  18 ++
 .../model_types_config/text-embedder.yaml       |  18 ++
 .../model_types_config/text-to-image.yaml       |  18 ++
 .../model_types_config/text-to-text.yaml        |  18 ++
 .../model_types_config/visual-classifier.yaml   |  18 ++
 .../model_types_config/visual-detector.yaml     |  28 ++
 .../model_types_config/visual-embedder.yaml     |  18 ++
 .../model_types_config/visual-segmenter.yaml    |  18 ++
 .../model_serving/model_config/serializer.py    |   2 +-
 .../model_config/triton_config.py               | 226 -------------
 .../model_serving/models/default_test.py        |  43 +--
 clarifai/models/model_serving/models/output.py  |   4 +-
 .../model_serving/pb_model_repository.py        |   3 +-
 requirements.txt                                |   1 +
 setup.py                                        |   3 +-
 22 files changed, 494 insertions(+), 347 deletions(-)
 create mode 100644 clarifai/models/model_serving/model_config/config.py
 delete mode 100644 clarifai/models/model_serving/model_config/deploy.py
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml
 create mode 100644 clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml
 delete mode 100644 clarifai/models/model_serving/model_config/triton_config.py

diff --git a/clarifai/models/model_serving/cli/deploy_cli.py b/clarifai/models/model_serving/cli/deploy_cli.py
index 14e64f18..058e126c 100644
--- a/clarifai/models/model_serving/cli/deploy_cli.py
+++ b/clarifai/models/model_serving/cli/deploy_cli.py
@@ -15,8 +15,7 @@
 from clarifai.auth.helper import ClarifaiAuthHelper
 from clarifai.models.api import Models
-from clarifai.models.model_serving.constants import MODEL_TYPES
-from clarifai.models.model_serving.model_config.deploy import ClarifaiFieldsMap
+from clarifai.models.model_serving.model_config import MODEL_TYPES, get_model_config
 
 
 def deploy(model_url,
@@ -50,7 +49,7 @@ def _parse_name(name):
   assert model_type, "Can not parse model_type from url, please input it directly"
   # key map
   assert model_type in MODEL_TYPES, f"model_type should be one of {MODEL_TYPES}"
-  clarifai_key_map = ClarifaiFieldsMap(model_type=model_type)
+  clarifai_key_map = get_model_config(model_type=model_type).field_maps
   # if updating new version of existing model
   if update_version:
     resp = model_api.post_model_version(
diff --git a/clarifai/models/model_serving/cli/repository.py b/clarifai/models/model_serving/cli/repository.py
index e474b794..178c1e98 100644
--- a/clarifai/models/model_serving/cli/repository.py
+++ b/clarifai/models/model_serving/cli/repository.py
@@ -14,8 +14,8 @@
 
 import argparse
 
-from ..constants import MODEL_TYPES
-from ..model_config.triton_config import TritonModelConfig
+from ..constants import MAX_HW_DIM
+from ..model_config import MODEL_TYPES, get_model_config
 from ..pb_model_repository import TritonModelRepository
 
 
@@ -63,7 +63,6 @@ def model_upload_init():
       help="Directory to create triton repository.")
   args = parser.parse_args()
 
-  MAX_HW_DIM = 1024
   if len(args.image_shape) != 2:
     raise ValueError(
@@ -75,10 +74,9 @@ def model_upload_init():
       f"H and W each have a maximum value of 1024. Got H: {args.image_shape[0]}, W: {args.image_shape[1]}"
     )
 
-  model_config = TritonModelConfig(
+  model_config = get_model_config(args.model_type).make_triton_model_config(
       model_name=args.model_name,
      model_version="1",
-      model_type=args.model_type,
      image_shape=args.image_shape,
  )
diff --git a/clarifai/models/model_serving/constants.py b/clarifai/models/model_serving/constants.py
index be60b734..633fbe94 100644
--- a/clarifai/models/model_serving/constants.py
+++ b/clarifai/models/model_serving/constants.py
@@ -1,5 +1 @@
-# Clarifai model types
-MODEL_TYPES = [
-    "visual-detector", "visual-classifier", "text-classifier", "text-to-text", "text-embedder",
-    "text-to-image", "visual-embedder", "visual-segmenter"
-]
+MAX_HW_DIM = 1024
diff --git a/clarifai/models/model_serving/docs/custom_config.md b/clarifai/models/model_serving/docs/custom_config.md
index 09049c18..281586e7 100644
--- a/clarifai/models/model_serving/docs/custom_config.md
+++ b/clarifai/models/model_serving/docs/custom_config.md
@@ -15,17 +15,16 @@ $ clarifai-model-upload-init --model_name \
 
 ## Generating the triton model repository without the commandline
 
-The triton model repository can be generated via a python script specifying the same values as required in the commandline. Below is a sample of how the code would be structured.
+The triton model repository can be generated via a python script specifying the same values as required in the commandline. Below is a sample of how the code would be structured, using the `visual_classifier` model type.
 
 ```python
-from clarifai.models.model_serving.model_config.triton_config import TritonModelConfig
-from clarifai.models.model_serving pb_model_repository import TritonModelRepository
+from clarifai.models.model_serving.model_config import get_model_config, ModelTypes, TritonModelConfig
+from clarifai.models.model_serving.pb_model_repository import TritonModelRepository
 
-
-model_config = TritonModelConfig(
+model_type = ModelTypes.visual_classifier
+model_config: TritonModelConfig = get_model_config(model_type).make_triton_model_config(
   model_name="",
   model_version="1",
-  model_type="",
   image_shape=<[H,W]>, # 0 < [H,W] <= 1024
 )
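To make the deploy-time change above concrete, here is a minimal sketch of the new field-map lookup that replaces `ClarifaiFieldsMap`. The values shown are taken from the `text-classifier.yaml` file introduced later in this patch, and the snippet assumes the package is installed with these changes applied:

```python
from clarifai.models.model_serving.model_config import MODEL_TYPES, get_model_config

# MODEL_TYPES is now derived from the ModelTypes dataclass in config.py.
assert "text-classifier" in MODEL_TYPES

# The Clarifai-to-Triton field maps are parsed from the per-type YAML file
# instead of being hard-coded in deploy.py.
field_maps = get_model_config("text-classifier").field_maps
print(field_maps.input_fields_map)   # {'text': 'text'}
print(field_maps.output_fields_map)  # {'concepts': 'softmax_predictions'}
```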
diff --git a/clarifai/models/model_serving/model_config/__init__.py b/clarifai/models/model_serving/model_config/__init__.py
index 453892ea..a4dded79 100644
--- a/clarifai/models/model_serving/model_config/__init__.py
+++ b/clarifai/models/model_serving/model_config/__init__.py
@@ -10,3 +10,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .config import *  # noqa # pylint: disable=unused-import
+from .serializer import Serializer  # noqa # pylint: disable=unused-import
diff --git a/clarifai/models/model_serving/model_config/config.py b/clarifai/models/model_serving/model_config/config.py
new file mode 100644
index 00000000..e772dc03
--- /dev/null
+++ b/clarifai/models/model_serving/model_config/config.py
@@ -0,0 +1,298 @@
+# Copyright 2023 Clarifai, Inc.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Model Config classes."""
+
+from dataclasses import asdict, dataclass, field
+from typing import List
+
+import yaml
+
+from ..models.model_types import *  # noqa # pylint: disable=unused-import
+from ..models.output import *  # noqa # pylint: disable=unused-import
+
+__all__ = ["get_model_config", "MODEL_TYPES", "TritonModelConfig", "ModelTypes"]
+
+### Triton Model Config classes.###
+
+
+@dataclass
+class DType:
+  """
+  Triton Model Config data types.
+  """
+  # https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
+  TYPE_UINT8: int = 2
+  TYPE_INT8: int = 6
+  TYPE_INT16: int = 7
+  TYPE_INT32: int = 8
+  TYPE_INT64: int = 9
+  TYPE_FP16: int = 10
+  TYPE_FP32: int = 11
+  TYPE_STRING: int = 13
+  KIND_GPU: int = 1
+  KIND_CPU: int = 2
+
+
+@dataclass
+class InputConfig:
+  """
+  Triton Input definition.
+  Params:
+  -------
+  name: input name
+  data_type: input data type
+  dims: Pre-defined input data shape(s).
+
+  Returns:
+  --------
+  InputConfig
+  """
+  name: str
+  data_type: int
+  dims: List = field(default_factory=list)
+
+
+@dataclass
+class OutputConfig:
+  """
+  Triton Output definition.
+  Params:
+  -------
+  name: output name
+  data_type: output data type
+  dims: Pre-defined output data shape(s).
+  labels (bool): If labels file is required for inference.
+
+  Returns:
+  --------
+  OutputConfig
+  """
+  name: str
+  data_type: int
+  dims: List = field(default_factory=list)
+  labels: bool = False
+
+  def __post_init__(self):
+    if self.labels:
+      self.label_filename = "labels.txt"
+    del self.labels
+
+
+@dataclass
+class Device:
+  """
+  Triton instance_group.
+  Define the type of inference device and number of devices to use.
+  Params:
+  -------
+  count: number of devices
+  use_gpu: whether to use cpu or gpu.
+
+  Returns:
+  --------
+  Device object
+  """
+  count: int = 1
+  use_gpu: bool = True
+
+  def __post_init__(self):
+    if self.use_gpu:
+      self.kind: str = DType.KIND_GPU
+    else:
+      self.kind: str = DType.KIND_CPU
+
+
+@dataclass
+class DynamicBatching:
+  """
+  Triton dynamic_batching config.
+  Params:
+  -------
+  preferred_batch_size: batch size
+  max_queue_delay_microseconds: max queue delay for a request batch
+
+  Returns:
+  --------
+  DynamicBatching object
+  """
+  #preferred_batch_size: List[int] = [1] # recommended not to set
+  max_queue_delay_microseconds: int = 500
+
+
+@dataclass
+class TritonModelConfig:
+  """
+  Triton Model Config base.
+  Params:
+  -------
+  name: triton inference model name
+  input: a list of InputConfig fields
+  output: a list of OutputConfig fields/dicts
+  instance_group: Device. See Device.
+  dynamic_batching: Triton dynamic batching settings.
+  max_batch_size: max request batch size
+  backend: Triton Python Backend. Constant
+
+  Returns:
+  --------
+  TritonModelConfig
+  """
+  model_type: str
+  model_name: str
+  model_version: str
+  image_shape: List  #(H, W)
+  input: List[InputConfig] = field(default_factory=list)
+  output: List[OutputConfig] = field(default_factory=list)
+  instance_group: Device = field(default_factory=Device)
+  dynamic_batching: DynamicBatching = field(default_factory=DynamicBatching)
+  max_batch_size: int = 1
+  backend: str = "python"
+
+  def __post_init__(self):
+    if "image" in [each.name for each in self.input]:
+      image_dims = self.image_shape
+      image_dims.append(3)  # add channel dim
+      self.input[0].dims = image_dims
+
+
+### General Model Config classes & functions ###
+
+
+# Clarifai model types
+@dataclass
+class ModelTypes:
+  visual_detector: str = "visual-detector"
+  visual_classifier: str = "visual-classifier"
+  text_classifier: str = "text-classifier"
+  text_to_text: str = "text-to-text"
+  text_embedder: str = "text-embedder"
+  text_to_image: str = "text-to-image"
+  visual_embedder: str = "visual-embedder"
+  visual_segmenter: str = "visual-segmenter"
+
+  def __post_init__(self):
+    self.all = list(asdict(self).values())
+
+
+@dataclass
+class InferenceConfig:
+  wrap_func: callable
+  return_type: dataclass
+
+
+@dataclass
+class FieldMapsConfig:
+  input_fields_map: dict
+  output_fields_map: dict
+
+
+@dataclass
+class DefaultTritonConfig:
+  input: List[InputConfig] = field(default_factory=list)
+  output: List[OutputConfig] = field(default_factory=list)
+
+
+@dataclass
+class ModelConfigClass:
+  type: str = field(init=False)
+  triton: DefaultTritonConfig
+  inference: InferenceConfig
+  field_maps: FieldMapsConfig
+
+  def make_triton_model_config(
+      self,
+      model_name: str,
+      model_version: str,
+      image_shape: List = None,
+      instance_group: Device = Device(),
+      dynamic_batching: DynamicBatching = DynamicBatching(),
+      max_batch_size: int = 1,
+      backend: str = "python",
+  ) -> TritonModelConfig:
+
+    return TritonModelConfig(
+        model_type=self.type,
+        model_name=model_name,
+        model_version=model_version,
+        image_shape=image_shape,
+        instance_group=instance_group,
+        dynamic_batching=dynamic_batching,
+        max_batch_size=max_batch_size,
+        backend=backend,
+        input=self.triton.input,
+        output=self.triton.output)
+
+
+def read_config(cfg: str):
+  with open(cfg, encoding="utf-8") as f:
+    config = yaml.safe_load(f)  # model dict
+
+  # parse default triton
+  input_triton_configs = config["triton"]["input"]
+  output_triton_configs = config["triton"]["output"]
+  triton = DefaultTritonConfig(
+      input=[
+          InputConfig(
+              name=input["name"],
+              data_type=eval(f"DType.{input['data_type']}"),
+              dims=input["dims"]) for input in input_triton_configs
+      ],
+      output=[
+          OutputConfig(
+              name=output["name"],
+              data_type=eval(f"DType.{output['data_type']}"),
+              dims=output["dims"],
+              labels=output["labels"],
+          ) for output in output_triton_configs
+      ])
+
+  # parse inference config
+  inference = InferenceConfig(
+      wrap_func=eval(config["inference"]["wrap_func"]),
+      return_type=eval(config["inference"]["return_type"]),
+  )
+
+  # parse field maps for deployment
+  field_maps = FieldMapsConfig(**config["field_maps"])
+
+  return ModelConfigClass(triton=triton, inference=inference, field_maps=field_maps)
+
+
+def get_model_config(model_type: str) -> ModelConfigClass:
+  """
+  Get model config by model type.
+
+  Args:
+
+    model_type (str): One of the field values of ModelTypes
+
+  Returns:
+    ModelConfigClass
+
+  ### Example:
+  >>> cfg = get_model_config(ModelTypes.text_classifier)
+  >>> custom_triton_config = cfg.make_triton_model_config(**kwargs)
+
+  """
+  import os
+  assert model_type in MODEL_TYPES, f"`model_type` must be in {MODEL_TYPES}"
+  cfg = read_config(
+      os.path.join(os.path.dirname(__file__), "model_types_config", f"{model_type}.yaml"))
+  cfg.type = model_type
+  return cfg
+
+
+_model_types = ModelTypes()
+MODEL_TYPES = _model_types.all
+del _model_types
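A short, hedged sketch of how the pieces of the new `config.py` fit together; the model name is a placeholder, and all other identifiers are defined above or in the YAML files added below:

```python
from clarifai.models.model_serving.model_config import (MODEL_TYPES, ModelTypes,
                                                        get_model_config)

print(MODEL_TYPES)  # ['visual-detector', 'visual-classifier', ..., 'visual-segmenter']

# get_model_config reads model_types_config/visual-detector.yaml into a
# ModelConfigClass and stamps cfg.type with the requested model type.
cfg = get_model_config(ModelTypes.visual_detector)
print([out.name for out in cfg.triton.output])
# ['predicted_bboxes', 'predicted_labels', 'predicted_scores']

# A full TritonModelConfig is then derived from the per-type defaults.
# "my-detector" is a placeholder; H/W bounds are now validated by the CLI
# against MAX_HW_DIM from constants.py rather than inside this class.
triton_cfg = cfg.make_triton_model_config(
    model_name="my-detector", model_version="1", image_shape=[512, 512])
print(triton_cfg.input[0].dims)  # [512, 512, 3]; __post_init__ appends the channel dim
```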
"regions[...].region_info.mask,regions[...].data.concepts": "predicted_mask" - } diff --git a/clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml b/clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml new file mode 100644 index 00000000..24430b25 --- /dev/null +++ b/clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml @@ -0,0 +1,18 @@ +triton: + input: + - name: text + data_type: TYPE_STRING + dims: [1] + output: + - name: softmax_predictions + data_type: TYPE_FP32 + dims: [-1] + labels: true +inference: + wrap_func: text_classifier + return_type: ClassifierOutput +field_maps: + input_fields_map: + text: text + output_fields_map: + concepts: softmax_predictions diff --git a/clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml b/clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml new file mode 100644 index 00000000..9b99285e --- /dev/null +++ b/clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml @@ -0,0 +1,18 @@ +triton: + input: + - name: text + data_type: TYPE_STRING + dims: [1] + output: + - name: embeddings + data_type: TYPE_FP32 + dims: [-1] + labels: false +inference: + wrap_func: text_embedder + return_type: EmbeddingOutput +field_maps: + input_fields_map: + text: text + output_fields_map: + embeddings: embeddings diff --git a/clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml b/clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml new file mode 100644 index 00000000..94b4de53 --- /dev/null +++ b/clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml @@ -0,0 +1,18 @@ +triton: + input: + - name: text + data_type: TYPE_STRING + dims: [1] + output: + - name: image + data_type: TYPE_UINT8 + dims: [-1, -1, 3] + labels: false +inference: + wrap_func: text_to_image + return_type: ImageOutput +field_maps: + input_fields_map: + text: text + output_fields_map: + image: image diff --git a/clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml b/clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml new file mode 100644 index 00000000..3d7050b2 --- /dev/null +++ b/clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml @@ -0,0 +1,18 @@ +triton: + input: + - name: text + data_type: TYPE_STRING + dims: [1] + output: + - name: text + data_type: TYPE_STRING + dims: [1] + labels: false +inference: + wrap_func: text_to_text + return_type: TextOutput +field_maps: + input_fields_map: + text: text + output_fields_map: + text: text diff --git a/clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml b/clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml new file mode 100644 index 00000000..f6c7d612 --- /dev/null +++ b/clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml @@ -0,0 +1,18 @@ +triton: + input: + - name: image + data_type: TYPE_UINT8 + dims: [-1, -1, 3] + output: + - name: softmax_predictions + data_type: TYPE_FP32 + dims: [-1] + labels: true +inference: + wrap_func: visual_classifier + return_type: ClassifierOutput +field_maps: + input_fields_map: + image: image + output_fields_map: + concepts: softmax_predictions diff --git a/clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml 
diff --git a/clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml b/clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml
new file mode 100644
index 00000000..a517aa70
--- /dev/null
+++ b/clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml
@@ -0,0 +1,28 @@
+triton:
+  input:
+  - name: image
+    data_type: TYPE_UINT8
+    dims: [-1, -1, 3]
+  output:
+  - name: predicted_bboxes
+    data_type: TYPE_FP32
+    dims: [-1, 4]
+    labels: false
+  - name: predicted_labels
+    data_type: TYPE_INT32
+    dims: [-1, 1]
+    labels: true
+  - name: predicted_scores
+    data_type: TYPE_FP32
+    dims: [-1, 1]
+    labels: false
+inference:
+  wrap_func: visual_detector
+  return_type: VisualDetectorOutput
+field_maps:
+  input_fields_map:
+    image: image
+  output_fields_map:
+    "regions[...].region_info.bounding_box": "predicted_bboxes"
+    "regions[...].data.concepts[...].id": "predicted_labels"
+    "regions[...].data.concepts[...].value": "predicted_scores"
diff --git a/clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml b/clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml
new file mode 100644
index 00000000..b075b534
--- /dev/null
+++ b/clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml
@@ -0,0 +1,18 @@
+triton:
+  input:
+  - name: image
+    data_type: TYPE_UINT8
+    dims: [-1, -1, 3]
+  output:
+  - name: embeddings
+    data_type: TYPE_FP32
+    dims: [-1]
+    labels: false
+inference:
+  wrap_func: visual_embedder
+  return_type: EmbeddingOutput
+field_maps:
+  input_fields_map:
+    image: image
+  output_fields_map:
+    embeddings: embeddings
diff --git a/clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml b/clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml
new file mode 100644
index 00000000..9489dc6e
--- /dev/null
+++ b/clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml
@@ -0,0 +1,18 @@
+triton:
+  input:
+  - name: image
+    data_type: TYPE_UINT8
+    dims: [-1, -1, 3]
+  output:
+  - name: predicted_mask
+    data_type: TYPE_INT64
+    dims: [-1, -1]
+    labels: true
+inference:
+  wrap_func: visual_segmenter
+  return_type: MasksOutput
+field_maps:
+  input_fields_map:
+    image: image
+  output_fields_map:
+    "regions[...].region_info.mask,regions[...].data.concepts": "predicted_mask"
diff --git a/clarifai/models/model_serving/model_config/serializer.py b/clarifai/models/model_serving/model_config/serializer.py
index 6bef6a78..80b8a392 100644
--- a/clarifai/models/model_serving/model_config/serializer.py
+++ b/clarifai/models/model_serving/model_config/serializer.py
@@ -21,7 +21,7 @@
 from google.protobuf.text_format import MessageToString
 from tritonclient.grpc import model_config_pb2
 
-from .triton_config import TritonModelConfig
+from .config import TritonModelConfig
 
 
 class Serializer:
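As a quick sanity check of the YAML schema above, `read_config` from the new `config.py` can load one of these files directly. This is a hedged sketch: the path is illustrative, and `cfg.type` is only filled in later by `get_model_config`:

```python
import os

from clarifai.models.model_serving.model_config.config import read_config

# Illustrative path; the YAML files ship inside the package next to config.py.
yaml_path = os.path.join("clarifai", "models", "model_serving", "model_config",
                         "model_types_config", "visual-detector.yaml")
cfg = read_config(yaml_path)

print(cfg.field_maps.output_fields_map["regions[...].region_info.bounding_box"])
# 'predicted_bboxes'
print(cfg.inference.return_type)  # the VisualDetectorOutput dataclass
```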
diff --git a/clarifai/models/model_serving/model_config/triton_config.py b/clarifai/models/model_serving/model_config/triton_config.py
deleted file mode 100644
index 0de52dd9..00000000
--- a/clarifai/models/model_serving/model_config/triton_config.py
+++ /dev/null
@@ -1,226 +0,0 @@
-# Copyright 2023 Clarifai, Inc.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Triton Model Config classes."""
-
-from dataclasses import dataclass, field
-from typing import List
-
-
-@dataclass
-class DType:
-  """
-  Triton Model Config data types.
-  """
-  # https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
-  TYPE_UINT8: int = 2
-  TYPE_INT8: int = 6
-  TYPE_INT16: int = 7
-  TYPE_INT32: int = 8
-  TYPE_INT64: int = 9
-  TYPE_FP16: int = 10
-  TYPE_FP32: int = 11
-  TYPE_STRING: int = 13
-  KIND_GPU: int = 1
-  KIND_CPU: int = 2
-
-
-@dataclass
-class InputConfig:
-  """
-  Triton Input definition.
-  Params:
-  -------
-  name: input name
-  data_type: input data type
-  dims: Pre-defined input data shape(s).
-
-  Returns:
-  --------
-  InputConfig
-  """
-  name: str
-  data_type: int
-  dims: List = field(default_factory=list)
-
-
-@dataclass
-class OutputConfig:
-  """
-  Triton Output definition.
-  Params:
-  -------
-  name: output name
-  data_type: output data type
-  dims: Pre-defined output data shape(s).
-  labels (bool): If labels file is required for inference.
-
-  Returns:
-  --------
-  OutputConfig
-  """
-  name: str
-  data_type: int
-  dims: List = field(default_factory=list)
-  labels: bool = False
-
-  def __post_init__(self):
-    if self.labels:
-      self.label_filename = "labels.txt"
-    else:
-      del self.labels
-
-
-@dataclass
-class Device:
-  """
-  Triton instance_group.
-  Define the type of inference device and number of devices to use.
-  Params:
-  -------
-  count: number of devices
-  use_gpu: whether to use cpu or gpu.
-
-  Returns:
-  --------
-  Device object
-  """
-  count: int = 1
-  use_gpu: bool = True
-
-  def __post_init__(self):
-    if self.use_gpu:
-      self.kind: str = DType.KIND_GPU
-    else:
-      self.kind: str = DType.KIND_CPU
-
-
-@dataclass
-class DynamicBatching:
-  """
-  Triton dynamic_batching config.
-  Params:
-  -------
-  preferred_batch_size: batch size
-  max_queue_delay_microseconds: max queue delay for a request batch
-
-  Returns:
-  --------
-  DynamicBatching object
-  """
-  #preferred_batch_size: List[int] = [1] # recommended not to set
-  max_queue_delay_microseconds: int = 500
-
-
-@dataclass
-class TritonModelConfig:
-  """
-  Triton Model Config base.
-  Params:
-  -------
-  name: triton inference model name
-  input: a list of an InputConfig field
-  output: a list of OutputConfig fields/dicts
-  instance_group: Device. see Device
-  dynamic_batching: Triton dynamic batching settings.
-  max_batch_size: max request batch size
-  backend: Triton Python Backend. Constant
-
-  Returns:
-  --------
-  TritonModelConfig
-  """
-  model_name: str
-  model_version: str
-  model_type: str
-  image_shape: List  #(H, W)
-  input: List[InputConfig] = field(default_factory=list)
-  output: List[OutputConfig] = field(default_factory=list)
-  instance_group: Device = field(default_factory=Device)
-  dynamic_batching: DynamicBatching = field(default_factory=DynamicBatching)
-  max_batch_size: int = 1
-  backend: str = "python"
-
-  def __post_init__(self):
-    """
-    Set supported input dims and data_types for
-    a given model_type.
-    """
-    MAX_HW_DIM = 1024
-    if len(self.image_shape) != 2:
-      raise ValueError(
-          f"image_shape takes 2 values, Height and Width. Got {len(self.image_shape)} instead.")
-    if self.image_shape[0] > MAX_HW_DIM or self.image_shape[1] > MAX_HW_DIM:
-      raise ValueError(
-          f"H and W each have a maximum value of 1024. Got H: {self.image_shape[0]}, W: {self.image_shape[1]}"
-      )
-    image_dims = self.image_shape
-    image_dims.append(3)  # add channel dim
-    image_input = InputConfig(name="image", data_type=DType.TYPE_UINT8, dims=image_dims)
-    text_input = InputConfig(name="text", data_type=DType.TYPE_STRING, dims=[1])
-    # del image_shape as it's a temporary config that's not used by triton
-    del self.image_shape
-
-    if self.model_type == "visual-detector":
-      self.input.append(image_input)
-      pred_bboxes = OutputConfig(name="predicted_bboxes", data_type=DType.TYPE_FP32, dims=[-1, 4])
-      pred_labels = OutputConfig(
-          name="predicted_labels", data_type=DType.TYPE_INT32, dims=[-1, 1], labels=True)
-      del pred_labels.labels
-      pred_scores = OutputConfig(name="predicted_scores", data_type=DType.TYPE_FP32, dims=[-1, 1])
-      self.output.extend([pred_bboxes, pred_labels, pred_scores])
-
-    elif self.model_type == "visual-classifier":
-      self.input.append(image_input)
-      pred_labels = OutputConfig(
-          name="softmax_predictions", data_type=DType.TYPE_FP32, dims=[-1], labels=True)
-      del pred_labels.labels
-      self.output.append(pred_labels)
-
-    elif self.model_type == "text-classifier":
-      self.input.append(text_input)
-      pred_labels = OutputConfig(
-          name="softmax_predictions", data_type=DType.TYPE_FP32, dims=[-1], labels=True)
-      #'Len of out list expected to be the number of concepts returned by the model,
-      # with each value being the confidence for the respective model output.
-      del pred_labels.labels
-      self.output.append(pred_labels)
-
-    elif self.model_type == "text-to-text":
-      self.input.append(text_input)
-      pred_text = OutputConfig(name="text", data_type=DType.TYPE_STRING, dims=[1], labels=False)
-      self.output.append(pred_text)
-
-    elif self.model_type == "text-embedder":
-      self.input.append(text_input)
-      embedding_vector = OutputConfig(
-          name="embeddings", data_type=DType.TYPE_FP32, dims=[-1], labels=False)
-      self.output.append(embedding_vector)
-
-    elif self.model_type == "text-to-image":
-      self.input.append(text_input)
-      gen_image = OutputConfig(
-          name="image", data_type=DType.TYPE_UINT8, dims=[-1, -1, 3], labels=False)
-      self.output.append(gen_image)
-
-    elif self.model_type == "visual-embedder":
-      self.input.append(image_input)
-      embedding_vector = OutputConfig(
-          name="embeddings", data_type=DType.TYPE_FP32, dims=[-1], labels=False)
-      self.output.append(embedding_vector)
-
-    elif self.model_type == "visual-segmenter":
-      self.input.append(image_input)
-      pred_masks = OutputConfig(
-          name="predicted_mask", data_type=DType.TYPE_INT64, dims=[-1, -1], labels=True)
-      del pred_masks.labels
-      self.output.append(pred_masks)
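The test change below skips the dims-length check for text models. The reason is the `output.py` fix at the end of this patch: `TextOutput.predicted_text` is now validated as a 0-dimensional array, while the Triton config for `text-to-text` still declares `dims: [1]`. An illustrative sketch with placeholder text:

```python
import numpy as np

from clarifai.models.model_serving.models.output import TextOutput

# A scalar (0-d) numpy string now passes TextOutput's assertion...
out = TextOutput(predicted_text=np.asarray("some generated text"))
print(out.predicted_text.ndim)  # 0

# ...while the previously expected 1-d form would now raise an AssertionError:
# TextOutput(predicted_text=np.asarray(["some generated text"]))
```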
diff --git a/clarifai/models/model_serving/models/default_test.py b/clarifai/models/model_serving/models/default_test.py
index c596d5dd..61325d94 100644
--- a/clarifai/models/model_serving/models/default_test.py
+++ b/clarifai/models/model_serving/models/default_test.py
@@ -6,7 +6,8 @@
 
 import numpy as np
 
-from ..model_config.triton_config import TritonModelConfig
+from ..model_config import ModelTypes
+from ..model_config.config import get_model_config
 from .output import (ClassifierOutput, EmbeddingOutput, ImageOutput, MasksOutput, TextOutput,
                      VisualDetectorOutput)
 
@@ -75,19 +76,18 @@ def intitialize(
             model_repository=os.path.join(repo_version_dir, ".."),
             model_instance_kind="GPU" if self.is_instance_kind_gpu else "cpu"))
     # Get default config of model and model_type
-    self.default_triton_model_config = TritonModelConfig(
-        model_name=self.model_type,
-        model_version="1",
-        model_type=self.model_type,
-        image_shape=[-1, -1])
+    self.default_triton_model_config = get_model_config(self.model_type).make_triton_model_config(
+        model_name=self.model_type, model_version="1", image_shape=[-1, -1])
     # Get current model config
     self.triton_model_config = self.triton_model.config_msg
     self.triton_model_input_name = self.triton_model.input_name
     self.preprocess = self._get_preprocess()
     # load labels
     self._required_label_model_types = [
-        "visual-detector", "visual-classifier", "text-classifier", "visual-segmenter"
+        ModelTypes.visual_detector, ModelTypes.visual_classifier, ModelTypes.text_classifier,
+        ModelTypes.visual_segmenter
     ]
+    self._output_text_models = [ModelTypes.text_to_text]
     self.labels = []
     if self.model_type in self._required_label_model_types:
       with open(os.path.join(repo_version_dir, "../labels.txt"), 'r') as fp:
@@ -144,14 +144,15 @@ def _is_integer(x):
 
     for inp, output in zip(inputs, outputs):
       field = dataclasses.fields(output)[0].name
-      self.assertEqual(
-          len(self.triton_model_config.output[0].dims),
-          len(getattr(output, field).shape),
-          "Length of 'dims' of config and output must be matched, but get "
-          f"Config {len(self.triton_model_config.output[0].dims)} != Output {len(getattr(output, field).shape)}"
-      )
+      if self.model_type not in self._output_text_models:
+        self.assertEqual(
+            len(self.triton_model_config.output[0].dims),
+            len(getattr(output, field).shape),
+            "Length of 'dims' of config and output must match, but got "
+            f"Config {len(self.triton_model_config.output[0].dims)} != Output {len(getattr(output, field).shape)}"
+        )
 
-      if self.model_type == "visual-detector":
+      if self.model_type == ModelTypes.visual_detector:
         logging.info(output.predicted_labels)
         self.assertEqual(
             type(output), VisualDetectorOutput,
@@ -166,7 +167,7 @@ def _is_integer(x):
             f"`predicted_labels` must be in [0, {len(self.labels) - 1}]")
         self.assertTrue(_is_integer(output.predicted_labels), "`predicted_labels` must be integer")
 
-      elif self.model_type == "visual-classifier":
+      elif self.model_type == ModelTypes.visual_classifier:
         self.assertEqual(
             type(output), ClassifierOutput,
             f"Output type must be `ClassifierOutput`, but got {type(output)}")
@@ -179,7 +180,7 @@ def _is_integer(x):
             f"`predicted_labels` must equal to {len(self.labels)}, however got {len(output.predicted_scores)}"
         )
 
-      elif self.model_type == "text-classifier":
+      elif self.model_type == ModelTypes.text_classifier:
         self.assertEqual(
             type(output), ClassifierOutput,
             f"Output type must be `ClassifierOutput`, but got {type(output)}")
@@ -192,29 +193,29 @@ def _is_integer(x):
             f"`predicted_labels` must equal to {len(self.labels)}, however got {len(output.predicted_scores)}"
         )
 
-      elif self.model_type == "text-embedder":
+      elif self.model_type == ModelTypes.text_embedder:
         self.assertEqual(
             type(output), EmbeddingOutput,
             f"Output type must be `EmbeddingOutput`, but got {type(output)}")
         self.assertNotEqual(output.embedding_vector.shape, [])
 
-      elif self.model_type == "text-to-text":
+      elif self.model_type == ModelTypes.text_to_text:
         self.assertEqual(
             type(output), TextOutput,
             f"Output type must be `TextOutput`, but got {type(output)}")
 
-      elif self.model_type == "text-to-image":
+      elif self.model_type == ModelTypes.text_to_image:
         self.assertEqual(
            type(output), ImageOutput,
            f"Output type must be `ImageOutput`, but got {type(output)}")
         self.assertTrue(_is_non_negative(output.image), "`image` elements must be >= 0")
 
-      elif self.model_type == "visual-embedder":
+      elif self.model_type == ModelTypes.visual_embedder:
         self.assertEqual(
             type(output), EmbeddingOutput,
             f"Output type must be `EmbeddingOutput`, but got {type(output)}")
         self.assertNotEqual(output.embedding_vector.shape, [])
 
-      elif self.model_type == "visual-segmenter":
+      elif self.model_type == ModelTypes.visual_segmenter:
         self.assertEqual(
             type(output), MasksOutput,
             f"Output type must be `MasksOutput`, but got {type(output)}")
diff --git a/clarifai/models/model_serving/models/output.py b/clarifai/models/model_serving/models/output.py
index aa445794..151d4327 100644
--- a/clarifai/models/model_serving/models/output.py
+++ b/clarifai/models/model_serving/models/output.py
@@ -72,8 +72,8 @@ def __post_init__(self):
     """
     Validate input upon initialization.
     """
-    assert self.predicted_text.ndim == 1, \
-        f"All predictions must be 1-dimensional, Got text-dims: {self.predicted_text.ndim} instead."
+    assert self.predicted_text.ndim == 0, \
+        f"All predictions must be 0-dimensional; got text-dims: {self.predicted_text.ndim} instead."
 
 
 @dataclass
diff --git a/clarifai/models/model_serving/pb_model_repository.py b/clarifai/models/model_serving/pb_model_repository.py
index 798d9ed5..86abbf61 100644
--- a/clarifai/models/model_serving/pb_model_repository.py
+++ b/clarifai/models/model_serving/pb_model_repository.py
@@ -19,8 +19,7 @@
 from pathlib import Path
 from typing import Callable, Type
 
-from .model_config.serializer import Serializer
-from .model_config.triton_config import TritonModelConfig
+from .model_config import Serializer, TritonModelConfig
 from .models import inference, pb_model, test
 
 
diff --git a/requirements.txt b/requirements.txt
index 36a25421..cf714021 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ opencv-python==4.7.0.68
 tritonclient==2.34.0
 rich==13.4.2
 pytest==7.4.1
+PyYAML==6.0.1
diff --git a/setup.py b/setup.py
index 92ab2b5e..e1101510 100644
--- a/setup.py
+++ b/setup.py
@@ -24,7 +24,8 @@
     license="Apache 2.0",
     python_requires='>=3.8',
     install_requires=[
-        "clarifai-grpc>=9.8.1", "tritonclient==2.34.0", "packaging", "tqdm==4.64.1", "rich==13.4.2"
+        "clarifai-grpc>=9.8.1", "tritonclient==2.34.0", "packaging", "tqdm==4.64.1",
+        "rich==13.4.2", "PyYAML==6.0.1"
     ],
     entry_points={
         "console_scripts": [