Skip to content

Commit

Permalink
[DEVX:122] Added Workflow Interface (#145)
Browse files Browse the repository at this point in the history
* added workflow interface
  • Loading branch information
sainivedh authored Aug 21, 2023
1 parent 6df2eb1 commit dc5fbdc
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 16 deletions.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ model_prediction = model.predict_by_url(url="VIDEO_URL", input_type="video")
### Models Listing
```python
# Note: CLARIFAI_PAT must be set as env variable.

# List all model versions
all_model_versions = model.list_versions()

Expand All @@ -72,3 +73,37 @@ all_models = app.list_models()
all_llm_community_models = App().list_models(filter_by={"query": "LLM",
"model_type_id": "text-to-text"}, only_in_app=False)
```

## Interacting with Workflows

### Workflow Predict
```python
# Note: CLARIFAI_PAT must be set as env variable.
from clarifai.client.workflow import Workflow

# Workflow Predict
workflow = Workflow(user_id="user_id", app_id="app_id", workflow_id="workflow_id")
workflow_prediction = workflow.predict_by_url(url="url", input_type="image") # Supports image, text, audio, video

# Customizing Workflow Inference Output
workflow = Workflow(user_id="user_id", app_id="app_id", workflow_id="workflow_id",
output_config={"min_value": 0.98}) # Return predictions having prediction confidence > 0.98
workflow_prediction = workflow.predict_by_filepath(filepath="local_filepath", input_type="text") # Supports image, text, audio, video
```

### Workflows Listing
```python
# Note: CLARIFAI_PAT must be set as env variable.

# List all workflow versions
all_workflow_versions = workflow.list_versions()

# Go to specific workflow version
workflow_v1 = Workflow(workflow_id="workflow_id", workflow_version=dict(id="workflow_version_id"), app_id="app_id", user_id="user_id")

# List all workflows in an app
all_workflows = app.list_workflows()

# List all workflows in community filtered by description
all_face_community_workflows = App().list_workflows(filter_by={"query": "face"}, only_in_app=False) # Get all face related workflows
```
32 changes: 30 additions & 2 deletions clarifai/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,34 @@ def list_models(self, filter_by: Dict[str, Any] = {}, only_in_app: bool = True)

return [Model(**model_info) for model_info in filtered_models_info]

def list_workflows(self, filter_by: Dict[str, Any] = {},
                   only_in_app: bool = True) -> List[Workflow]:
    """Lists all the workflows for the app.

    Args:
        filter_by (dict): A dictionary of filters to apply to the list of workflows.
            Its items are forwarded as extra fields of the ListWorkflows request.
        only_in_app (bool): If True, only return workflows that are in the app.

    Returns:
        List[Workflow]: A list of Workflow objects for the workflows in the app.

    Example:
        >>> from clarifai.client.app import App
        >>> app = App(app_id="app_id", user_id="user_id")
        >>> all_workflows = app.list_workflows()
    """
    # NOTE: the `{}` default is shared between calls; it is only unpacked here,
    # never mutated, so sharing it is harmless.
    request_data = dict(user_app_id=self.user_app_id, per_page=self.default_page_size, **filter_by)
    all_workflows_info = self.list_all_pages_generator(
        self.STUB.ListWorkflows, service_pb2.ListWorkflowsRequest, request_data)

    # When restricted to this app, drop every workflow whose app_id differs
    # from ours; otherwise keep everything the listing returned.
    filtered_workflows_info = [
        workflow_info for workflow_info in all_workflows_info
        if not only_in_app or workflow_info['app_id'] == self.id
    ]

    # Build the result from the *filtered* entries (the unfiltered list was
    # returned before, which made `only_in_app` a no-op).
    return [Workflow(**workflow_info) for workflow_info in filtered_workflows_info]

def list_concepts(self):
"""
Expand Down Expand Up @@ -184,6 +207,11 @@ def workflow(self, workflow_id: str, **kwargs) -> Workflow:
workflow_id (str): The workflow ID for the workflow to interact with.
Returns:
Workflow: A Workflow object for the existing workflow ID.
Example:
>>> from clarifai.client.app import App
>>> app = App(app_id="app_id", user_id="user_id")
>>> workflow = app.workflow(workflow_id="workflow_id")
"""
request = service_pb2.GetWorkflowRequest(user_app_id=self.user_app_id, workflow_id=workflow_id)
response = self._grpc_request(self.STUB.GetWorkflow, request)
Expand Down
2 changes: 1 addition & 1 deletion clarifai/client/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Dataset(Lister, BaseClient):
"""

def __init__(self, dataset_id: str, **kwargs):
"""Initializes an Dataset object.
"""Initializes a Dataset object.
Args:
dataset_id (str): The Dataset ID within the App to interact with.
**kwargs: Additional keyword arguments to be passed to the ClarifaiAuthHelper.
Expand Down
16 changes: 9 additions & 7 deletions clarifai/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self,
model_version: Dict = {'id': ""},
output_config: Dict = {'min_value': 0},
**kwargs):
"""Initializes an Model object.
"""Initializes a Model object.
Args:
model_id (str): The Model ID to interact with.
model_version (dict): The Model Version to interact with.
Expand Down Expand Up @@ -98,27 +98,29 @@ def predict_by_filepath(self, filepath: str, input_type: str):

return self.predict_by_bytes(file_bytes, input_type)

def predict_by_bytes(self, input_bytes: bytes, input_type: str):
    """Predicts the model based on the given bytes.

    Args:
        input_bytes (bytes): File Bytes to predict on.
        input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.

    Raises:
        UserError: If `input_type` is not one of the supported types or
            `input_bytes` is not a `bytes` object.
    """
    if input_type not in {'image', 'text', 'video', 'audio'}:
        raise UserError('Invalid input type it should be image, text, video or audio.')
    if not isinstance(input_bytes, bytes):
        raise UserError('Invalid bytes.')

    # TODO will obtain proto from input class
    # Text carries its payload in `raw`; every other media type uses `base64`.
    if input_type == "text":
        data_proto = resources_pb2.Data(text=resources_pb2.Text(raw=input_bytes))
    elif input_type == "image":
        data_proto = resources_pb2.Data(image=resources_pb2.Image(base64=input_bytes))
    elif input_type == "video":
        data_proto = resources_pb2.Data(video=resources_pb2.Video(base64=input_bytes))
    else:
        # Only "audio" can reach this branch after the validation above.
        data_proto = resources_pb2.Data(audio=resources_pb2.Audio(base64=input_bytes))

    single_input = resources_pb2.Input(data=data_proto)
    return self.predict(inputs=[single_input])

Expand Down
157 changes: 151 additions & 6 deletions clarifai/client/workflow.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,167 @@
from clarifai_grpc.grpc.api import resources_pb2, service_pb2 # noqa: F401
import os
from typing import Dict, List

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.resources_pb2 import Input
from clarifai_grpc.grpc.api.status import status_code_pb2

from clarifai.client.base import BaseClient
from clarifai.client.lister import Lister
from clarifai.errors import UserError
from clarifai.utils.logging import get_logger


class Workflow(Lister, BaseClient):
    """
    Workflow is a class that provides access to Clarifai API endpoints related to Workflow information.
    Inherits from BaseClient for authentication purposes.

    Attribute lookups that fail on the client itself are forwarded to the
    wrapped `resources_pb2.Workflow` proto (see `__getattr__`); this is how
    `self.id`, `self.version`, `self.user_id` and `self.app_id` resolve.
    """

    def __init__(self,
                 workflow_id: str = "",
                 workflow_version: Dict = {'id': ""},
                 output_config: Dict = {'min_value': 0},
                 **kwargs):
        """Initializes a Workflow object.

        Args:
            workflow_id (str): The Workflow ID to interact with.
            workflow_version (dict): The Workflow Version to interact with.
            output_config (dict): The output config to interact with.
                min_value (float): The minimum value of the prediction confidence to filter.
                max_concepts (int): The maximum number of concepts to return.
                select_concepts (list[Concept]): The concepts to select.
                sample_ms (int): The number of milliseconds to sample.
            **kwargs: Additional keyword arguments to be passed to the ClarifaiAuthHelper.
        """
        # NOTE: the dict defaults in the signature are shared between calls.
        # They are only read here and must never be mutated in place.
        # The explicit parameters win over same-named entries in **kwargs.
        self.kwargs = {**kwargs, 'id': workflow_id, 'version': workflow_version}
        self.output_config = output_config
        # Must be assigned before any delegated attribute (self.user_id, ...)
        # is read, because __getattr__ forwards to this proto.
        self.workflow_info = resources_pb2.Workflow(**self.kwargs)
        self.logger = get_logger(logger_level="INFO")
        BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id)
        Lister.__init__(self)

    def predict(self, inputs: List[Input]):
        """Predicts the workflow based on the given inputs.

        Args:
            inputs (list[Input]): The inputs to predict.

        Returns:
            The response of the PostWorkflowResults call.

        Raises:
            UserError: If more than 128 inputs are given.
            Exception: If the response status code is not SUCCESS.
        """
        if len(inputs) > 128:
            raise UserError("Too many inputs. Max is 128.")  # TODO Use Chunker for inputs len > 128
        request = service_pb2.PostWorkflowResultsRequest(
            user_app_id=self.user_app_id,
            workflow_id=self.id,
            version_id=self.version.id,
            inputs=inputs,
            output_config=self.output_config)

        response = self._grpc_request(self.STUB.PostWorkflowResults, request)
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(f"Workflow Predict failed with response {response.status!r}")

        return response

    def predict_by_filepath(self, filepath: str, input_type: str):
        """Predicts the workflow based on the given filepath.

        Args:
            filepath (str): The filepath to predict.
            input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.

        Raises:
            UserError: If `input_type` is unsupported or `filepath` is not an existing file.

        Example:
            >>> from clarifai.client.workflow import Workflow
            >>> workflow = Workflow(user_id='user_id', app_id='app_id', workflow_id='workflow_id')
            >>> workflow_prediction = workflow.predict_by_filepath('filepath', 'image')
        """
        if input_type not in {'image', 'text', 'video', 'audio'}:
            raise UserError('Invalid input type it should be image, text, video or audio.')
        if not os.path.isfile(filepath):
            raise UserError('Invalid filepath.')

        with open(filepath, "rb") as f:
            file_bytes = f.read()

        return self.predict_by_bytes(file_bytes, input_type)

    def predict_by_bytes(self, input_bytes: bytes, input_type: str):
        """Predicts the workflow based on the given bytes.

        Args:
            input_bytes (bytes): Bytes to predict on.
            input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.

        Raises:
            UserError: If `input_type` is unsupported or `input_bytes` is not `bytes`.
        """
        if input_type not in {'image', 'text', 'video', 'audio'}:
            raise UserError('Invalid input type it should be image, text, video or audio.')
        if not isinstance(input_bytes, bytes):
            raise UserError('Invalid bytes.')

        # Text carries its payload in `raw`; the media types use `base64`.
        if input_type == "image":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(image=resources_pb2.Image(base64=input_bytes)))
        elif input_type == "text":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(text=resources_pb2.Text(raw=input_bytes)))
        elif input_type == "video":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(video=resources_pb2.Video(base64=input_bytes)))
        elif input_type == "audio":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(audio=resources_pb2.Audio(base64=input_bytes)))

        return self.predict(inputs=[input_proto])

    def predict_by_url(self, url: str, input_type: str):
        """Predicts the workflow based on the given URL.

        Args:
            url (str): The URL to predict.
            input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.

        Raises:
            UserError: If `input_type` is unsupported.

        Example:
            >>> from clarifai.client.workflow import Workflow
            >>> workflow = Workflow(user_id='user_id', app_id='app_id', workflow_id='workflow_id')
            >>> workflow_prediction = workflow.predict_by_url('url', 'image')
        """
        if input_type not in {'image', 'text', 'video', 'audio'}:
            raise UserError('Invalid input type it should be image, text, video or audio.')

        if input_type == "image":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(image=resources_pb2.Image(url=url)))
        elif input_type == "text":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(text=resources_pb2.Text(url=url)))
        elif input_type == "video":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(video=resources_pb2.Video(url=url)))
        elif input_type == "audio":
            input_proto = resources_pb2.Input(
                data=resources_pb2.Data(audio=resources_pb2.Audio(url=url)))

        return self.predict(inputs=[input_proto])

    def list_versions(self) -> List['Workflow']:
        """Lists all the versions of the workflow.

        Returns:
            list[Workflow]: A list of Workflow objects, one per listed version.

        Example:
            >>> from clarifai.client.workflow import Workflow
            >>> workflow = Workflow(user_id='user_id', app_id='app_id', workflow_id='workflow_id')
            >>> workflow_versions = workflow.list_versions()
        """
        request_data = dict(
            user_app_id=self.user_app_id,
            workflow_id=self.id,
            per_page=self.default_page_size,
        )

        # 'id' and 'version' are passed explicitly below. Forwarding them via
        # **kwargs is pointless because __init__ overwrites both keys with its
        # own `workflow_id` / `workflow_version` parameters.
        passthrough_kwargs = {
            key: value for key, value in self.kwargs.items() if key not in ('id', 'version')
        }

        workflow_versions = []
        for workflow_version_info in self.list_all_pages_generator(
                self.STUB.ListWorkflowVersions, service_pb2.ListWorkflowVersionsRequest,
                request_data):
            # The listing reports the version id under 'workflow_version_id';
            # the version dict handed to the constructor expects it under 'id'.
            workflow_version_info['id'] = workflow_version_info.pop('workflow_version_id')
            # Pass the version through `workflow_version`: a `version=` entry in
            # **kwargs would be discarded in favour of the default empty version.
            # NOTE(review): assumes the remaining keys of the listed dict are
            # valid WorkflowVersion fields - confirm against the paging helper.
            workflow_versions.append(
                Workflow(
                    workflow_id=self.id,
                    workflow_version=workflow_version_info,
                    **passthrough_kwargs))

        return workflow_versions

    def __getattr__(self, name):
        """Forwards attribute lookups that failed on the client to the workflow proto."""
        # __getattr__ only runs after normal lookup failed. If `workflow_info`
        # itself is missing (instance not initialised yet), reading it below
        # would re-enter this method forever, so fail fast instead.
        if name == 'workflow_info':
            raise AttributeError(name)
        return getattr(self.workflow_info, name)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_workflow_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os

import pytest
from clarifai_grpc.grpc.api import resources_pb2

from clarifai.client.workflow import Workflow

DOG_IMAGE_URL = "https://samples.clarifai.com/dog2.jpeg"
NON_EXISTING_IMAGE_URL = "http://example.com/non-existing.jpg"
RED_TRUCK_IMAGE_FILE_PATH = os.path.dirname(__file__) + "/assets/red-truck.png"
BEER_VIDEO_URL = "https://samples.clarifai.com/beer.mp4"

MAIN_APP_ID = "main"
MAIN_APP_USER_ID = "clarifai"
WORKFLOW_ID = "General"


@pytest.fixture
def workflow():
    """Provides a client for the public General workflow, capped at 3 concepts per output."""
    three_concepts_config = resources_pb2.OutputConfig(max_concepts=3)
    return Workflow(
        user_id=MAIN_APP_USER_ID,
        app_id=MAIN_APP_ID,
        workflow_id=WORKFLOW_ID,
        output_config=three_concepts_config,
    )


def test_workflow_predict_image_url(workflow):
    """Predicting on an image URL yields at least one concept in the first output."""
    response = workflow.predict_by_url(DOG_IMAGE_URL, input_type="image")
    first_output = response.results[0].outputs[0]

    assert len(first_output.data.concepts) > 0


def test_workflow_predict_image_bytes(workflow):
    """Predicting on raw image bytes yields at least one concept in the first output."""
    with open(RED_TRUCK_IMAGE_FILE_PATH, "rb") as image_file:
        image_bytes = image_file.read()

    response = workflow.predict_by_bytes(image_bytes, input_type="image")
    first_output = response.results[0].outputs[0]

    assert len(first_output.data.concepts) > 0


def test_workflow_predict_max_concepts():
    """An output config with max_concepts=3 yields exactly three concepts in the first output."""
    # Built explicitly (not via the fixture) so the limit under test is visible here.
    capped_config = resources_pb2.OutputConfig(max_concepts=3)
    workflow = Workflow(
        user_id=MAIN_APP_USER_ID,
        app_id=MAIN_APP_ID,
        workflow_id=WORKFLOW_ID,
        output_config=capped_config,
    )

    response = workflow.predict_by_url(DOG_IMAGE_URL, input_type="image")
    first_output = response.results[0].outputs[0]

    assert len(first_output.data.concepts) == 3

0 comments on commit dc5fbdc

Please sign in to comment.