Skip to content

Commit

Permalink
[Devx:119] Models Interface (#143)
Browse files Browse the repository at this point in the history
* added Model Interface
  • Loading branch information
sainivedh authored Aug 18, 2023
1 parent 223c4a9 commit 6df2eb1
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ jobs:
run: |
export PYTHONPATH=.
export CLARIFAI_PAT="$(python scripts/key_for_tests.py --create-pat)"
pytest tests/ -n auto
pytest tests/test_auth.py tests/test_modules.py tests/test_stub.py -n auto
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export CLARIFAI_PAT={your personal access token}
```

```python
# Note: CLARIFAI_PAT must be set as env variable.
from clarifai.client.user import User
client = User(user_id="user_id")

Expand All @@ -34,3 +35,40 @@ dataset = app.create_dataset(dataset_id="demo_dataset")
# execute data upload to Clarifai app dataset
dataset.upload_dataset(task='visual_segmentation', split="train", dataset_loader='coco_segmentation')
```

## Interacting with Models

### Model Predict
```python
# Note: CLARIFAI_PAT must be set as env variable.
from clarifai.client.model import Model

# Model Predict
model = Model(user_id="user_id", app_id="app_id", model_id="model_id")
model_prediction = model.predict_by_url(url="url", input_type="image") # Supports image, text, audio, video

# Customizing Model Inference Output
model = Model(user_id="user_id", app_id="app_id", model_id="model_id",
output_config={"min_value": 0.98}) # Return predictions having prediction confidence > 0.98
model_prediction = model.predict_by_filepath(filepath="local_filepath", input_type="text") # Supports image, text, audio, video

model = Model(user_id="user_id", app_id="app_id", model_id="model_id",
output_config={"sample_ms": 2000}) # Return predictions for specified interval
model_prediction = model.predict_by_url(url="VIDEO_URL", input_type="video")
```
### Models Listing
```python
# Note: CLARIFAI_PAT must be set as env variable.
# List all model versions
all_model_versions = model.list_versions()

# Go to specific model version
model_v1 = client.app("app_id").model(model_id="model_id", model_version_id="model_version_id")

# List all models in an app
all_models = app.list_models()

# List all models in community filtered by model_type, description
all_llm_community_models = App().list_models(filter_by={"query": "LLM",
"model_type_id": "text-to-text"}, only_in_app=False)
```
46 changes: 38 additions & 8 deletions clarifai/client/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, Dict, List

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
Expand All @@ -18,7 +18,7 @@ class App(Lister, BaseClient):
Inherits from BaseClient for authentication purposes.
"""

def __init__(self, app_id: str, **kwargs):
def __init__(self, app_id: str = "", **kwargs):
"""Initializes an App object.
Args:
app_id (str): The App ID for the App to interact with.
Expand Down Expand Up @@ -47,11 +47,34 @@ def list_datasets(self) -> List[Dataset]:

return [Dataset(**dataset_info) for dataset_info in all_datasets_info]

def list_models(self):
"""
Lists all the models for the app.
def list_models(self, filter_by: Dict[str, Any] = {}, only_in_app: bool = True) -> List[Model]:
"""Lists all the models for the app.
Args:
filter_by (dict): A dictionary of filters to apply to the list of models.
only_in_app (bool): If True, only return models that are in the app.
Returns:
List[Model]: A list of Model objects for the models in the app.
Example:
>>> from clarifai.client.user import User
>>> app = User(user_id="user_id").app(app_id="app_id")
>>> all_models = app.list_models()
"""
pass # TODO
request_data = dict(user_app_id=self.user_app_id, per_page=self.default_page_size, **filter_by)
all_models_info = list(
self.list_all_pages_generator(self.STUB.ListModels, service_pb2.ListModelsRequest,
request_data))

filtered_models_info = []
for model_info in all_models_info:
if 'model_version' not in list(model_info.keys()):
continue
if only_in_app:
if model_info['app_id'] != self.id:
continue
filtered_models_info.append(model_info)

return [Model(**model_info) for model_info in filtered_models_info]

def list_workflows(self):
"""
Expand Down Expand Up @@ -133,14 +156,21 @@ def dataset(self, dataset_id: str, **kwargs) -> Dataset:

return Dataset(**kwargs)

def model(self, model_id: str, **kwargs) -> Model:
def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
"""Returns a Model object for the existing model ID.
Args:
model_id (str): The model ID for the model to interact with.
model_version_id (str): The model version ID for the model version to interact with.
Returns:
Model: A Model object for the existing model ID.
Example:
>>> from clarifai.client.app import App
>>> app = App(app_id="app_id", user_id="user_id")
>>> model_v1 = app.model(model_id="model_id", model_version_id="model_version_id")
"""
request = service_pb2.GetModelRequest(user_app_id=self.user_app_id, model_id=model_id)
request = service_pb2.GetModelRequest(
user_app_id=self.user_app_id, model_id=model_id, version_id=model_version_id)
response = self._grpc_request(self.STUB.GetModel, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand Down
170 changes: 165 additions & 5 deletions clarifai/client/model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,182 @@
from clarifai_grpc.grpc.api import resources_pb2, service_pb2 # noqa: F401
import os
import time
from typing import Dict, List

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.resources_pb2 import Input
from clarifai_grpc.grpc.api.status import status_code_pb2

from clarifai.client.base import BaseClient
from clarifai.client.lister import Lister
from clarifai.errors import UserError
from clarifai.utils.logging import get_logger
from clarifai.utils.misc import BackoffIterator


class Model(BaseClient):
class Model(Lister, BaseClient):
"""
Model is a class that provides access to Clarifai API endpoints related to Model information.
Inherits from BaseClient for authentication purposes.
"""

def __init__(self, model_id: str, **kwargs):
def __init__(self,
model_id: str,
model_version: Dict = {'id': ""},
output_config: Dict = {'min_value': 0},
**kwargs):
"""Initializes an Model object.
Args:
model_id (str): The Model ID to interact with.
model_version (dict): The Model Version to interact with.
output_config (dict): The output config to interact with.
min_value (float): The minimum value of the prediction confidence to filter.
max_concepts (int): The maximum number of concepts to return.
select_concepts (list[Concept]): The concepts to select.
sample_ms (int): The number of milliseconds to sample.
**kwargs: Additional keyword arguments to be passed to the ClarifaiAuthHelper.
"""
self.kwargs = {**kwargs, 'id': model_id}
self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version,
'output_info': {'output_config': output_config}}
self.model_info = resources_pb2.Model(**self.kwargs)
super().__init__(user_id=self.user_id, app_id=self.app_id)
self.logger = get_logger(logger_level="INFO")
BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id)
Lister.__init__(self)

def predict(self, inputs: List[Input]):
"""Predicts the model based on the given inputs.
Args:
inputs (list[Input]): The inputs to predict, must be less than 128.
"""
if len(inputs) > 128:
raise UserError("Too many inputs. Max is 128.") # TODO Use Chunker for inputs len > 128

request = service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_version.id,
inputs=inputs,
model=self.model_info)

start_time = time.time()
backoff_iterator = BackoffIterator()
while True:
response = self._grpc_request(self.STUB.PostModelOutputs, request)

if response.outputs and \
response.outputs[0].status.code == status_code_pb2.MODEL_DEPLOYING and \
time.time() - start_time < 60 * 10: # 10 minutes
self.logger.info(f"{self.id} model is still deploying, please wait...")
time.sleep(next(backoff_iterator))
continue

if response.status.code != status_code_pb2.SUCCESS:
raise Exception(f"Model Predict failed with response {response.status!r}")
else:
break

return response

def predict_by_filepath(self, filepath: str, input_type: str):
"""Predicts the model based on the given filepath.
Args:
filepath (str): The filepath to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
Example:
>>> from clarifai.client.model import Model
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
>>> model_prediction = model.predict_by_filepath('/path/to/image.jpg', 'image')
>>> model_prediction = model.predict_by_filepath('/path/to/text.txt', 'text')
"""
if input_type not in ['image', 'text', 'video', 'audio']:
raise UserError('Invalid input type it should be image, text, video or audio.')
if not os.path.isfile(filepath):
raise UserError('Invalid filepath.')

with open(filepath, "rb") as f:
file_bytes = f.read()

return self.predict_by_bytes(file_bytes, input_type)

def predict_by_bytes(self, file_bytes: bytes, input_type: str):
"""Predicts the model based on the given bytes.
Args:
file_bytes (bytes): File Bytes to predict on.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.
"""
if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError('Invalid input type it should be image, text, video or audio.')
# TODO will obtain proto from input class
if input_type == "image":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(image=resources_pb2.Image(base64=file_bytes)))
elif input_type == "text":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=file_bytes)))
elif input_type == "video":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(video=resources_pb2.Video(base64=file_bytes)))
elif input_type == "audio":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(audio=resources_pb2.Audio(base64=file_bytes)))

return self.predict(inputs=[input_proto])

def predict_by_url(self, url: str, input_type: str):
"""Predicts the model based on the given URL.
Args:
url (str): The URL to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
Example:
>>> from clarifai.client.model import Model
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
>>> model_prediction = model.predict_by_url('url', 'image')
"""
if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError('Invalid input type it should be image, text, video or audio.')
# TODO will be obtain proto from input class
if input_type == "image":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(image=resources_pb2.Image(url=url)))
elif input_type == "text":
input_proto = resources_pb2.Input(data=resources_pb2.Data(text=resources_pb2.Text(url=url)))
elif input_type == "video":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(video=resources_pb2.Video(url=url)))
elif input_type == "audio":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(audio=resources_pb2.Audio(url=url)))

return self.predict(inputs=[input_proto])

def list_versions(self) -> List['Model']:
"""Lists all the versions for the model.
Returns:
List[Model]: A list of Model objects for the versions of the model.
Example:
>>> from clarifai.client.model import Model
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
>>> all_model_versions = model.list_versions()
"""
request_data = dict(
user_app_id=self.user_app_id,
model_id=self.id,
per_page=self.default_page_size,
)
all_model_versions_info = list(
self.list_all_pages_generator(self.STUB.ListModelVersions,
service_pb2.ListModelVersionsRequest, request_data))

for model_version_info in all_model_versions_info:
model_version_info['id'] = model_version_info['model_version_id']
del model_version_info['model_version_id']

return [
Model(model_id=self.id, **dict(self.kwargs, model_version=model_version_info))
for model_version_info in all_model_versions_info
]

def __getattr__(self, name):
return getattr(self.model_info, name)
Expand Down
2 changes: 1 addition & 1 deletion clarifai/client/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class User(Lister, BaseClient):
Inherits from BaseClient for authentication purposes.
"""

def __init__(self, user_id: str, **kwargs):
def __init__(self, user_id: str = "", **kwargs):
"""Initializes an User object.
Args:
user_id (str): The user ID for the user to interact with.
Expand Down
Binary file added tests/assets/red-truck.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 6df2eb1

Please sign in to comment.