Skip to content

Commit

Permalink
[DEVX:123] Support Dataset Upload (#140)
Browse files Browse the repository at this point in the history
* support dataset upload
  • Loading branch information
sainivedh authored Aug 10, 2023
1 parent 6ab6067 commit 30cbc1f
Show file tree
Hide file tree
Showing 19 changed files with 1,533 additions and 37 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ apps = client.list_apps()
# Create app and dataset
app = client.create_app(app_id="demo_app")
dataset = app.create_dataset(dataset_id="demo_dataset")
# Upload data to the Clarifai app dataset
dataset.upload_dataset(task='visual_segmentation', split="train", dataset_loader='coco_segmentation')
```
56 changes: 32 additions & 24 deletions clarifai/client/app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from clarifai_grpc.grpc.api import resources_pb2, service_pb2 # noqa: F401
from typing import List

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf.json_format import MessageToDict

from clarifai.client.base import BaseClient
from clarifai.client.dataset import Dataset
from clarifai.client.lister import Lister
from clarifai.client.model import Model
from clarifai.client.workflow import Workflow
from clarifai.utils.logging import get_logger


class App(BaseClient):
class App(Lister, BaseClient):
"""
App is a class that provides access to Clarifai API endpoints related to App information.
Inherits from BaseClient for authentication purposes.
Expand All @@ -26,13 +29,23 @@ def __init__(self, app_id: str, **kwargs):
self.kwargs = {**kwargs, 'id': app_id}
self.app_info = resources_pb2.App(**self.kwargs)
self.logger = get_logger(logger_level="INFO", name=__name__)
super().__init__(app_id=self.id)

def list_datasets(self):
"""
Lists all the datasets for the app.
"""
pass # TODO
BaseClient.__init__(self, user_id=self.user_id, app_id=self.id)
Lister.__init__(self)

def list_datasets(self) -> List[Dataset]:
  """Lists all the datasets for the app.

  Pages through the ListDatasets endpoint via ``list_all_pages_generator`` and
  wraps each returned dataset description in a ``Dataset`` object.

  Returns:
    List[Dataset]: One Dataset object per dataset description yielded by the
      paginated listing.
  """
  request_data = dict(
      user_app_id=self.user_app_id,
      per_page=self.default_page_size,
  )
  datasets = []
  for dataset_info in self.list_all_pages_generator(
      self.STUB.ListDatasets, service_pb2.ListDatasetsRequest, request_data):
    # The nested 'metrics' entry is stripped before the info is passed to Dataset
    # (NOTE(review): presumably Dataset cannot accept it as a kwarg -- confirm).
    # Use pop() with a default so a version without metrics does not raise KeyError.
    version_info = dataset_info.get('version')
    if version_info is not None:
      version_info.pop('metrics', None)
    datasets.append(Dataset(**dataset_info))
  return datasets

def list_models(self):
"""
Expand Down Expand Up @@ -61,7 +74,7 @@ def create_dataset(self, dataset_id: str, **kwargs) -> Dataset:
Dataset: A Dataset object for the specified dataset ID.
"""
request = service_pb2.PostDatasetsRequest(
user_app_id=self.userDataObject, datasets=[resources_pb2.Dataset(id=dataset_id, **kwargs)])
user_app_id=self.user_app_id, datasets=[resources_pb2.Dataset(id=dataset_id, **kwargs)])
response = self._grpc_request(self.STUB.PostDatasets, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand All @@ -78,7 +91,7 @@ def create_model(self, model_id: str, **kwargs) -> Model:
Model: A Model object for the specified model ID.
"""
request = service_pb2.PostModelsRequest(
user_app_id=self.userDataObject, models=[resources_pb2.Model(id=model_id, **kwargs)])
user_app_id=self.user_app_id, models=[resources_pb2.Model(id=model_id, **kwargs)])
response = self._grpc_request(self.STUB.PostModels, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand All @@ -95,8 +108,7 @@ def create_workflow(self, workflow_id: str, **kwargs) -> Workflow:
Workflow: A Workflow object for the specified workflow ID.
"""
request = service_pb2.PostWorkflowsRequest(
user_app_id=self.userDataObject,
workflows=[resources_pb2.Workflow(id=workflow_id, **kwargs)])
user_app_id=self.user_app_id, workflows=[resources_pb2.Workflow(id=workflow_id, **kwargs)])
response = self._grpc_request(self.STUB.PostWorkflows, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand All @@ -111,7 +123,7 @@ def dataset(self, dataset_id: str, **kwargs) -> Dataset:
Returns:
Dataset: A Dataset object for the existing dataset ID.
"""
request = service_pb2.GetDatasetRequest(user_app_id=self.userDataObject, dataset_id=dataset_id)
request = service_pb2.GetDatasetRequest(user_app_id=self.user_app_id, dataset_id=dataset_id)
response = self._grpc_request(self.STUB.GetDataset, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand All @@ -128,14 +140,12 @@ def model(self, model_id: str, **kwargs) -> Model:
Returns:
Model: A Model object for the existing model ID.
"""
request = service_pb2.GetModelRequest(user_app_id=self.userDataObject, model_id=model_id)
request = service_pb2.GetModelRequest(user_app_id=self.user_app_id, model_id=model_id)
response = self._grpc_request(self.STUB.GetModel, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
dict_response = MessageToDict(response)
kwargs = self.convert_keys_to_snake_case(dict_response[list(dict_response.keys())[1]],
list(dict_response.keys())[1])

kwargs = self.convert_keys_to_snake_case(dict_response['model'], 'model')
return Model(**kwargs)

def workflow(self, workflow_id: str, **kwargs) -> Workflow:
Expand All @@ -145,8 +155,7 @@ def workflow(self, workflow_id: str, **kwargs) -> Workflow:
Returns:
Workflow: A Workflow object for the existing workflow ID.
"""
request = service_pb2.GetWorkflowRequest(
user_app_id=self.userDataObject, workflow_id=workflow_id)
request = service_pb2.GetWorkflowRequest(user_app_id=self.user_app_id, workflow_id=workflow_id)
response = self._grpc_request(self.STUB.GetWorkflow, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand All @@ -162,7 +171,7 @@ def delete_dataset(self, dataset_id: str) -> None:
dataset_id (str): The dataset ID for the app to delete.
"""
request = service_pb2.DeleteDatasetsRequest(
user_app_id=self.userDataObject, dataset_ids=[dataset_id])
user_app_id=self.user_app_id, dataset_ids=[dataset_id])
response = self._grpc_request(self.STUB.DeleteDatasets, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand All @@ -173,7 +182,7 @@ def delete_model(self, model_id: str) -> None:
Args:
model_id (str): The model ID for the app to delete.
"""
request = service_pb2.DeleteModelsRequest(user_app_id=self.userDataObject, ids=[model_id])
request = service_pb2.DeleteModelsRequest(user_app_id=self.user_app_id, ids=[model_id])
response = self._grpc_request(self.STUB.DeleteModels, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand All @@ -184,8 +193,7 @@ def delete_workflow(self, workflow_id: str) -> None:
Args:
workflow_id (str): The workflow ID for the app to delete.
"""
request = service_pb2.DeleteWorkflowsRequest(
user_app_id=self.userDataObject, ids=[workflow_id])
request = service_pb2.DeleteWorkflowsRequest(user_app_id=self.user_app_id, ids=[workflow_id])
response = self._grpc_request(self.STUB.DeleteWorkflows, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
Expand Down
15 changes: 9 additions & 6 deletions clarifai/client/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import datetime
from typing import Any, Callable

from google.protobuf.json_format import MessageToDict
from google.protobuf.json_format import MessageToDict # noqa
from google.protobuf.timestamp_pb2 import Timestamp
from google.protobuf.wrappers_pb2 import BoolValue

from clarifai.client.auth import create_stub
from clarifai.client.auth.helper import ClarifaiAuthHelper
Expand Down Expand Up @@ -32,7 +33,7 @@ def __init__(self, **kwargs):
self.auth_helper = ClarifaiAuthHelper(**kwargs)
self.STUB = create_stub(self.auth_helper)
self.metadata = self.auth_helper.metadata
self.userDataObject = self.auth_helper.get_user_app_id_proto()
self.user_app_id = self.auth_helper.get_user_app_id_proto()
self.base = self.auth_helper.base

def _grpc_request(self, method: Callable, argument: Any):
Expand All @@ -46,7 +47,7 @@ def _grpc_request(self, method: Callable, argument: Any):

try:
res = method(argument)
MessageToDict(res)
# MessageToDict(res) TODO global debug logger
return res
except ApiError:
raise Exception("ApiError")
Expand Down Expand Up @@ -93,10 +94,12 @@ def convert_recursive(item):
if isinstance(item, dict):
new_item = {}
for key, value in item.items():
if key in ['createdAt', 'modifiedAt']:
if key in ['createdAt', 'modifiedAt', 'completedAt']:
value = self.convert_string_to_timestamp(value)
if key in ['metadata', 'workflowRecommended', 'modelVersion']:
continue # TODO Fix "app_duplication",modelVersion error
elif key in ['workflowRecommended']:
value = BoolValue(value=True)
elif key in ['metadata', 'fieldsMap']:
continue # TODO Fix "app_duplication",fieldsMap(text key) error
new_key = snake_case(key)
new_item[new_key] = convert_recursive(value)
return new_item
Expand Down
Loading

0 comments on commit 30cbc1f

Please sign in to comment.