Skip to content

Commit

Permalink
modify generate, stream funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
sainivedh committed Oct 9, 2024
1 parent 7b299eb commit 16d3597
Showing 1 changed file with 50 additions and 20 deletions.
70 changes: 50 additions & 20 deletions clarifai/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def predict_by_filepath(self,
Args:
filepath (str): The filepath to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand Down Expand Up @@ -529,7 +529,7 @@ def predict_by_bytes(self,
Args:
input_bytes (bytes): File Bytes to predict on.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand Down Expand Up @@ -596,7 +596,7 @@ def predict_by_url(self,
Args:
url (str): The URL to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand Down Expand Up @@ -709,7 +709,7 @@ def generate(self,

def generate_by_filepath(self,
filepath: str,
input_type: str,
input_type: str = None,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
Expand All @@ -719,7 +719,7 @@ def generate_by_filepath(self,
Args:
filepath (str): The filepath to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand Down Expand Up @@ -754,7 +754,7 @@ def generate_by_filepath(self,

def generate_by_bytes(self,
input_bytes: bytes,
input_type: str,
input_type: str = None,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
Expand All @@ -764,7 +764,7 @@ def generate_by_bytes(self,
Args:
input_bytes (bytes): File Bytes to predict on.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand All @@ -778,11 +778,18 @@ def generate_by_bytes(self,
>>> from clarifai.client.model import Model
>>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
>>> stream_response = model.generate_by_bytes(b'Write a tweet on future of AI',
input_type='text',
deployment_id='deployment_id',
inference_params=dict(temperature=str(0.7), max_tokens=30)))
>>> list_stream_response = [response for response in stream_response]
"""
if not input_type:
model_input_types = self.get_model_input_types()
if len(model_input_types) > 1:
raise UserError(
"Model has multiple input types. Please use model.predict() for this multi-modal model."
)
input_type = model_input_types[0]

if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError(
f"Got input type {input_type} but expected one of image, text, video, audio.")
Expand Down Expand Up @@ -818,7 +825,7 @@ def generate_by_bytes(self,

def generate_by_url(self,
url: str,
input_type: str,
input_type: str = None,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
Expand All @@ -828,7 +835,7 @@ def generate_by_url(self,
Args:
url (str): The URL to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand All @@ -843,9 +850,17 @@ def generate_by_url(self,
>>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
or
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
>>> stream_response = model.generate_by_url('url', 'image', deployment_id='deployment_id')
>>> stream_response = model.generate_by_url('url', deployment_id='deployment_id')
>>> list_stream_response = [response for response in stream_response]
"""
if not input_type:
model_input_types = self.get_model_input_types()
if len(model_input_types) > 1:
raise UserError(
"Model has multiple input types. Please use model.predict() for this multi-modal model."
)
input_type = model_input_types[0]

if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError(
f"Got input type {input_type} but expected one of image, text, video, audio.")
Expand Down Expand Up @@ -934,7 +949,7 @@ def stream(self,

def stream_by_filepath(self,
filepath: str,
input_type: str,
input_type: str = None,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
Expand All @@ -944,7 +959,7 @@ def stream_by_filepath(self,
Args:
filepath (str): The filepath to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand All @@ -957,7 +972,7 @@ def stream_by_filepath(self,
Example:
>>> from clarifai.client.model import Model
>>> model = Model("url")
>>> stream_response = model.stream_by_filepath('/path/to/image.jpg', 'image', deployment_id='deployment_id')
>>> stream_response = model.stream_by_filepath('/path/to/image.jpg', deployment_id='deployment_id')
>>> list_stream_response = [response for response in stream_response]
"""
if not os.path.isfile(filepath):
Expand All @@ -977,7 +992,7 @@ def stream_by_filepath(self,

def stream_by_bytes(self,
input_bytes_iterator: Iterator[bytes],
input_type: str,
input_type: str = None,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
Expand All @@ -987,7 +1002,7 @@ def stream_by_bytes(self,
Args:
input_bytes_iterator (Iterator[bytes]): Iterator of file bytes to predict on.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand All @@ -1001,11 +1016,18 @@ def stream_by_bytes(self,
>>> from clarifai.client.model import Model
>>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
>>> stream_response = model.stream_by_bytes(iter([b'Write a tweet on future of AI']),
input_type='text',
deployment_id='deployment_id',
inference_params=dict(temperature=str(0.7), max_tokens=30)))
>>> list_stream_response = [response for response in stream_response]
"""
if not input_type:
model_input_types = self.get_model_input_types()
if len(model_input_types) > 1:
raise UserError(
"Model has multiple input types. Please use model.predict() for this multi-modal model."
)
input_type = model_input_types[0]

if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError(
f"Got input type {input_type} but expected one of image, text, video, audio.")
Expand Down Expand Up @@ -1041,7 +1063,7 @@ def input_generator():

def stream_by_url(self,
url_iterator: Iterator[str],
input_type: str,
input_type: str = None,
compute_cluster_id: str = None,
nodepool_id: str = None,
deployment_id: str = None,
Expand All @@ -1051,7 +1073,7 @@ def stream_by_url(self,
Args:
url_iterator (Iterator[str]): Iterator of URLs to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio.
input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
compute_cluster_id (str): The compute cluster ID to use for the model.
nodepool_id (str): The nodepool ID to use for the model.
deployment_id (str): The deployment ID to use for the model.
Expand All @@ -1064,9 +1086,17 @@ def stream_by_url(self,
Example:
>>> from clarifai.client.model import Model
>>> model = Model("url")
>>> stream_response = model.stream_by_url(iter(['url']), 'image', deployment_id='deployment_id')
>>> stream_response = model.stream_by_url(iter(['url']), deployment_id='deployment_id')
>>> list_stream_response = [response for response in stream_response]
"""
if not input_type:
model_input_types = self.get_model_input_types()
if len(model_input_types) > 1:
raise UserError(
"Model has multiple input types. Please use model.predict() for this multi-modal model."
)
input_type = model_input_types[0]

if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError(
f"Got input type {input_type} but expected one of image, text, video, audio.")
Expand Down

0 comments on commit 16d3597

Please sign in to comment.