
Commit

[EAGLE-3454] [EAGLE-3447] - Organize model types and Fix text output dim (#171)

* init

* fix inp fields map

* addressed comment

* fix text output error

* addressed comment + update requirements

* update requirements

* add PyYaml in setup
phatvo9 authored Sep 15, 2023
1 parent 48df778 commit 1c4b670
Showing 22 changed files with 494 additions and 347 deletions.
5 changes: 2 additions & 3 deletions clarifai/models/model_serving/cli/deploy_cli.py
@@ -15,8 +15,7 @@

from clarifai.auth.helper import ClarifaiAuthHelper
from clarifai.models.api import Models
-from clarifai.models.model_serving.constants import MODEL_TYPES
-from clarifai.models.model_serving.model_config.deploy import ClarifaiFieldsMap
+from clarifai.models.model_serving.model_config import MODEL_TYPES, get_model_config


def deploy(model_url,
@@ -50,7 +49,7 @@ def _parse_name(name):
  assert model_type, "Can not parse model_type from url, please input it directly"
  # key map
  assert model_type in MODEL_TYPES, f"model_type should be one of {MODEL_TYPES}"
-  clarifai_key_map = ClarifaiFieldsMap(model_type=model_type)
+  clarifai_key_map = get_model_config(model_type=model_type).field_maps
  # if updating new version of existing model
  if update_version:
    resp = model_api.post_model_version(
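For context, the deploy CLI now pulls the Clarifai field maps from the shared model-type registry instead of the removed `ClarifaiFieldsMap` dataclass. A minimal sketch of the new lookup path (assuming an installed package; `text-classifier` is just one of the supported type strings):

```python
from clarifai.models.model_serving.model_config import MODEL_TYPES, get_model_config

model_type = "text-classifier"  # illustrative choice
assert model_type in MODEL_TYPES, f"model_type should be one of {MODEL_TYPES}"
field_maps = get_model_config(model_type=model_type).field_maps
# FieldMapsConfig (defined in config.py below) carries both directions:
print(field_maps.input_fields_map, field_maps.output_fields_map)
```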
8 changes: 3 additions & 5 deletions clarifai/models/model_serving/cli/repository.py
@@ -14,8 +14,8 @@

import argparse

-from ..constants import MODEL_TYPES
-from ..model_config.triton_config import TritonModelConfig
+from ..constants import MAX_HW_DIM
+from ..model_config import MODEL_TYPES, get_model_config
from ..pb_model_repository import TritonModelRepository


@@ -63,7 +63,6 @@ def model_upload_init():
      help="Directory to create triton repository.")

  args = parser.parse_args()
-  MAX_HW_DIM = 1024

  if len(args.image_shape) != 2:
    raise ValueError(
@@ -75,10 +74,9 @@
        f"H and W each have a maximum value of 1024. Got H: {args.image_shape[0]}, W: {args.image_shape[1]}"
    )

-  model_config = TritonModelConfig(
+  model_config = get_model_config(args.model_type).make_triton_model_config(
      model_name=args.model_name,
      model_version="1",
-      model_type=args.model_type,
      image_shape=args.image_shape,
  )

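The CLI path above now checks the image shape against the shared `MAX_HW_DIM` constant and builds the Triton config through the registry. A short programmatic equivalent (the model name and shape values are illustrative):

```python
from clarifai.models.model_serving.constants import MAX_HW_DIM
from clarifai.models.model_serving.model_config import get_model_config

image_shape = [320, 320]  # illustrative H, W
assert all(0 < dim <= MAX_HW_DIM for dim in image_shape)
model_config = get_model_config("visual-detector").make_triton_model_config(
    model_name="my_detector",  # illustrative name
    model_version="1",
    image_shape=image_shape,
)
```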
6 changes: 1 addition & 5 deletions clarifai/models/model_serving/constants.py
@@ -1,5 +1 @@
-# Clarifai model types
-MODEL_TYPES = [
-    "visual-detector", "visual-classifier", "text-classifier", "text-to-text", "text-embedder",
-    "text-to-image", "visual-embedder", "visual-segmenter"
-]
+MAX_HW_DIM = 1024
11 changes: 5 additions & 6 deletions clarifai/models/model_serving/docs/custom_config.md
@@ -15,17 +15,16 @@ $ clarifai-model-upload-init --model_name <Your model name> \

## Generating the triton model repository without the commandline

-The triton model repository can be generated via a python script specifying the same values as required in the commandline. Below is a sample of how the code would be structured.
+The triton model repository can be generated via a python script specifying the same values as required in the commandline. Below is a sample of how the code would be structured with `visual_classifier`.

```python
-from clarifai.models.model_serving.model_config.triton_config import TritonModelConfig
-from clarifai.models.model_serving pb_model_repository import TritonModelRepository
+from clarifai.models.model_serving.model_config import get_model_config, ModelTypes, TritonModelConfig
+from clarifai.models.model_serving.pb_model_repository import TritonModelRepository


-model_config = TritonModelConfig(
+model_type = ModelTypes.visual_classifier
+model_config: TritonModelConfig = get_model_config(model_type).make_triton_model_config(
  model_name="<model_name>",
  model_version="1",
-  model_type="<model_type>",
  image_shape=<[H,W]>, # 0 < [H,W] <= 1024
)

2 changes: 2 additions & 0 deletions clarifai/models/model_serving/model_config/__init__.py
@@ -10,3 +10,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from .config import * # noqa # pylint: disable=unused-import
+from .serializer import Serializer # noqa # pylint: disable=unused-import
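With these re-exports, callers can import the helpers straight from `model_config`, as the updated modules above do (a one-line sketch):

```python
from clarifai.models.model_serving.model_config import MODEL_TYPES, ModelTypes, TritonModelConfig, get_model_config
```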
298 changes: 298 additions & 0 deletions clarifai/models/model_serving/model_config/config.py
@@ -0,0 +1,298 @@
# Copyright 2023 Clarifai, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Model Config classes."""

from dataclasses import asdict, dataclass, field
from typing import List

import yaml

from ..models.model_types import * # noqa # pylint: disable=unused-import
from ..models.output import * # noqa # pylint: disable=unused-import

__all__ = ["get_model_config", "MODEL_TYPES", "TritonModelConfig", "ModelTypes"]

### Triton Model Config classes.###


@dataclass
class DType:
  """
  Triton Model Config data types.
  """
  # https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
  TYPE_UINT8: int = 2
  TYPE_INT8: int = 6
  TYPE_INT16: int = 7
  TYPE_INT32: int = 8
  TYPE_INT64: int = 9
  TYPE_FP16: int = 10
  TYPE_FP32: int = 11
  TYPE_STRING: int = 13
  KIND_GPU: int = 1
  KIND_CPU: int = 2


@dataclass
class InputConfig:
  """
  Triton Input definition.
  Params:
  -------
  name: input name
  data_type: input data type
  dims: Pre-defined input data shape(s).
  Returns:
  --------
  InputConfig
  """
  name: str
  data_type: int
  dims: List = field(default_factory=list)


@dataclass
class OutputConfig:
  """
  Triton Output definition.
  Params:
  -------
  name: output name
  data_type: output data type
  dims: Pre-defined output data shape(s).
  labels (bool): If a labels file is required for inference.
  Returns:
  --------
  OutputConfig
  """
  name: str
  data_type: int
  dims: List = field(default_factory=list)
  labels: bool = False

  def __post_init__(self):
    if self.labels:
      self.label_filename = "labels.txt"
    del self.labels


@dataclass
class Device:
  """
  Triton instance_group.
  Defines the type of inference device and the number of devices to use.
  Params:
  -------
  count: number of devices
  use_gpu: whether to use cpu or gpu.
  Returns:
  --------
  Device object
  """
  count: int = 1
  use_gpu: bool = True

  def __post_init__(self):
    if self.use_gpu:
      self.kind: int = DType.KIND_GPU
    else:
      self.kind: int = DType.KIND_CPU


@dataclass
class DynamicBatching:
  """
  Triton dynamic_batching config.
  Params:
  -------
  preferred_batch_size: batch size
  max_queue_delay_microseconds: max queue delay for a request batch
  Returns:
  --------
  DynamicBatching object
  """
  #preferred_batch_size: List[int] = [1] # recommended not to set
  max_queue_delay_microseconds: int = 500


@dataclass
class TritonModelConfig:
  """
  Triton Model Config base.
  Params:
  -------
  name: triton inference model name
  input: a list of InputConfig fields
  output: a list of OutputConfig fields/dicts
  instance_group: Device. See Device.
  dynamic_batching: Triton dynamic batching settings.
  max_batch_size: max request batch size
  backend: Triton Python Backend. Constant.
  Returns:
  --------
  TritonModelConfig
  """
  model_type: str
  model_name: str
  model_version: str
  image_shape: List  # (H, W)
  input: List[InputConfig] = field(default_factory=list)
  output: List[OutputConfig] = field(default_factory=list)
  instance_group: Device = field(default_factory=Device)
  dynamic_batching: DynamicBatching = field(default_factory=DynamicBatching)
  max_batch_size: int = 1
  backend: str = "python"

  def __post_init__(self):
    if "image" in [each.name for each in self.input]:
      image_dims = self.image_shape
      image_dims.append(3)  # add channel dim
      self.input[0].dims = image_dims
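    # Illustrative note, not part of the committed file: with an "image" input,
    # e.g. image_shape=[224, 224] ends up as input dims [224, 224, 3].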


### General Model Config classes & functions ###


# Clarifai model types
@dataclass
class ModelTypes:
  visual_detector: str = "visual-detector"
  visual_classifier: str = "visual-classifier"
  text_classifier: str = "text-classifier"
  text_to_text: str = "text-to-text"
  text_embedder: str = "text-embedder"
  text_to_image: str = "text-to-image"
  visual_embedder: str = "visual-embedder"
  visual_segmenter: str = "visual-segmenter"

  def __post_init__(self):
    self.all = list(asdict(self).values())
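    # Illustrative note, not part of the committed file:
    # ModelTypes().all == ["visual-detector", "visual-classifier", ..., "visual-segmenter"]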


@dataclass
class InferenceConfig:
  wrap_func: callable
  return_type: dataclass


@dataclass
class FieldMapsConfig:
  input_fields_map: dict
  output_fields_map: dict


@dataclass
class DefaultTritonConfig:
  input: List[InputConfig] = field(default_factory=list)
  output: List[OutputConfig] = field(default_factory=list)


@dataclass
class ModelConfigClass:
  type: str = field(init=False)
  triton: DefaultTritonConfig
  inference: InferenceConfig
  field_maps: FieldMapsConfig

  def make_triton_model_config(
      self,
      model_name: str,
      model_version: str,
      image_shape: List = None,
      instance_group: Device = Device(),
      dynamic_batching: DynamicBatching = DynamicBatching(),
      max_batch_size: int = 1,
      backend: str = "python",
  ) -> TritonModelConfig:

    return TritonModelConfig(
        model_type=self.type,
        model_name=model_name,
        model_version=model_version,
        image_shape=image_shape,
        instance_group=instance_group,
        dynamic_batching=dynamic_batching,
        max_batch_size=max_batch_size,
        backend=backend,
        input=self.triton.input,
        output=self.triton.output)


def read_config(cfg: str):
  with open(cfg, encoding="utf-8") as f:
    config = yaml.safe_load(f)  # model dict

  # parse default triton
  input_triton_configs = config["triton"]["input"]
  output_triton_configs = config["triton"]["output"]
  triton = DefaultTritonConfig(
      input=[
          InputConfig(
              name=input["name"],
              data_type=eval(f"DType.{input['data_type']}"),
              dims=input["dims"]) for input in input_triton_configs
      ],
      output=[
          OutputConfig(
              name=output["name"],
              data_type=eval(f"DType.{output['data_type']}"),
              dims=output["dims"],
              labels=output["labels"],
          ) for output in output_triton_configs
      ])

  # parse inference config
  inference = InferenceConfig(
      wrap_func=eval(config["inference"]["wrap_func"]),
      return_type=eval(config["inference"]["return_type"]),
  )

  # parse field maps for deployment
  field_maps = FieldMapsConfig(**config["field_maps"])

  return ModelConfigClass(triton=triton, inference=inference, field_maps=field_maps)


def get_model_config(model_type: str) -> ModelConfigClass:
  """
  Get model config by model type.
  Args:
    model_type (str): One of the field values of ModelTypes
  Return:
    ModelConfigClass
  ### Example:
  >>> cfg = get_model_config(ModelTypes.text_classifier)
  >>> custom_triton_config = cfg.make_triton_model_config(**kwargs)
  """
  import os
  assert model_type in MODEL_TYPES, f"`model_type` must be in {MODEL_TYPES}"
  cfg = read_config(
      os.path.join(os.path.dirname(__file__), "model_types_config", f"{model_type}.yaml"))
  cfg.type = model_type
  return cfg


_model_types = ModelTypes()
MODEL_TYPES = _model_types.all
del _model_types
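Taken together, `read_config` implies a per-type YAML schema with `triton`, `inference`, and `field_maps` sections, resolved against `DType` and the wrapper/output symbols star-imported at the top of the module. A minimal sketch of such a file, parsed the same way (all concrete names, dims, and mapping values here are illustrative assumptions, not contents of the commit):

```python
import yaml

# Hypothetical model_types_config/<model-type>.yaml payload; the keys follow
# what read_config() indexes, the values are illustrative only.
sample = yaml.safe_load("""
triton:
  input:
    - {name: text, data_type: TYPE_STRING, dims: [1]}
  output:
    - {name: softmax_predictions, data_type: TYPE_FP32, dims: [-1], labels: true}
inference:
  wrap_func: text_classifier    # eval()-ed against ..models.model_types
  return_type: ClassifierOutput # eval()-ed against ..models.output
field_maps:
  input_fields_map: {text: text}
  output_fields_map: {softmax_predictions: concepts}
""")
assert set(sample) == {"triton", "inference", "field_maps"}
```

`get_model_config` then stamps the requested type string onto the resulting `ModelConfigClass`, so `make_triton_model_config` only has to fill in the per-deployment fields.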