[DEVX:123] Support Dataset Upload #140

Merged · 5 commits · Aug 10, 2023
2 changes: 2 additions & 0 deletions README.md
@@ -31,4 +31,6 @@ apps = client.list_apps()
# Create app and dataset
app = client.create_app(app_id="demo_app")
dataset = app.create_dataset(dataset_id="demo_dataset")
# Upload data to the Clarifai app dataset
dataset.upload_dataset(task='visual_segmentation', split="train", dataset_loader='coco_segmentation')
```
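As a hedged follow-up to the README snippet above: once the upload finishes, the `App.list_datasets` and `App.delete_dataset` methods added in this PR can be called on the same `app` object. This is a sketch under that assumption, not documented usage; the `print` reporting is purely illustrative.

```python
# Continuing from the snippet above: `app` came from client.create_app(app_id="demo_app").
datasets = app.list_datasets()  # new in this PR: returns a list of Dataset objects
print(f"{len(datasets)} dataset(s) in app {app.id}")

# Remove the demo dataset once it is no longer needed.
app.delete_dataset(dataset_id="demo_dataset")
```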
56 changes: 32 additions & 24 deletions clarifai/client/app.py
@@ -1,15 +1,18 @@
from clarifai_grpc.grpc.api import resources_pb2, service_pb2 # noqa: F401
from typing import List

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf.json_format import MessageToDict

from clarifai.client.base import BaseClient
from clarifai.client.dataset import Dataset
from clarifai.client.lister import Lister
from clarifai.client.model import Model
from clarifai.client.workflow import Workflow
from clarifai.utils.logging import get_logger


class App(BaseClient):
class App(Lister, BaseClient):
"""
App is a class that provides access to Clarifai API endpoints related to App information.
Inherits from BaseClient for authentication purposes.
@@ -26,13 +29,23 @@ def __init__(self, app_id: str, **kwargs):
self.kwargs = {**kwargs, 'id': app_id}
self.app_info = resources_pb2.App(**self.kwargs)
self.logger = get_logger(logger_level="INFO", name=__name__)
super().__init__(app_id=self.id)

def list_datasets(self):
"""
Lists all the datasets for the app.
"""
pass # TODO
BaseClient.__init__(self, user_id=self.user_id, app_id=self.id)
Lister.__init__(self)

def list_datasets(self) -> List[Dataset]:
"""Lists all the datasets for the app."""
request_data = dict(
user_app_id=self.user_app_id,
per_page=self.default_page_size,
)
all_datasets_info = list(
self.list_all_pages_generator(self.STUB.ListDatasets, service_pb2.ListDatasetsRequest,
request_data))
for dataset_info in all_datasets_info:
if 'version' in list(dataset_info.keys()):
del dataset_info['version']['metrics']

return [Dataset(**dataset_info) for dataset_info in all_datasets_info]

def list_models(self):
"""
@@ -61,7 +74,7 @@ def create_dataset(self, dataset_id: str, **kwargs) -> Dataset:
Dataset: A Dataset object for the specified dataset ID.
"""
request = service_pb2.PostDatasetsRequest(
user_app_id=self.userDataObject, datasets=[resources_pb2.Dataset(id=dataset_id, **kwargs)])
user_app_id=self.user_app_id, datasets=[resources_pb2.Dataset(id=dataset_id, **kwargs)])
response = self._grpc_request(self.STUB.PostDatasets, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
@@ -78,7 +91,7 @@ def create_model(self, model_id: str, **kwargs) -> Model:
Model: A Model object for the specified model ID.
"""
request = service_pb2.PostModelsRequest(
user_app_id=self.userDataObject, models=[resources_pb2.Model(id=model_id, **kwargs)])
user_app_id=self.user_app_id, models=[resources_pb2.Model(id=model_id, **kwargs)])
response = self._grpc_request(self.STUB.PostModels, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
@@ -95,8 +108,7 @@ def create_workflow(self, workflow_id: str, **kwargs) -> Workflow:
Workflow: A Workflow object for the specified workflow ID.
"""
request = service_pb2.PostWorkflowsRequest(
user_app_id=self.userDataObject,
workflows=[resources_pb2.Workflow(id=workflow_id, **kwargs)])
user_app_id=self.user_app_id, workflows=[resources_pb2.Workflow(id=workflow_id, **kwargs)])
response = self._grpc_request(self.STUB.PostWorkflows, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
@@ -111,7 +123,7 @@ def dataset(self, dataset_id: str, **kwargs) -> Dataset:
Returns:
Dataset: A Dataset object for the existing dataset ID.
"""
request = service_pb2.GetDatasetRequest(user_app_id=self.userDataObject, dataset_id=dataset_id)
request = service_pb2.GetDatasetRequest(user_app_id=self.user_app_id, dataset_id=dataset_id)
response = self._grpc_request(self.STUB.GetDataset, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
@@ -128,14 +140,12 @@ def model(self, model_id: str, **kwargs) -> Model:
Returns:
Model: A Model object for the existing model ID.
"""
request = service_pb2.GetModelRequest(user_app_id=self.userDataObject, model_id=model_id)
request = service_pb2.GetModelRequest(user_app_id=self.user_app_id, model_id=model_id)
response = self._grpc_request(self.STUB.GetModel, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
dict_response = MessageToDict(response)
kwargs = self.convert_keys_to_snake_case(dict_response[list(dict_response.keys())[1]],
list(dict_response.keys())[1])

kwargs = self.convert_keys_to_snake_case(dict_response['model'], 'model')
return Model(**kwargs)

def workflow(self, workflow_id: str, **kwargs) -> Workflow:
@@ -145,8 +155,7 @@ def workflow(self, workflow_id: str, **kwargs) -> Workflow:
Returns:
Workflow: A Workflow object for the existing workflow ID.
"""
request = service_pb2.GetWorkflowRequest(
user_app_id=self.userDataObject, workflow_id=workflow_id)
request = service_pb2.GetWorkflowRequest(user_app_id=self.user_app_id, workflow_id=workflow_id)
response = self._grpc_request(self.STUB.GetWorkflow, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
@@ -162,7 +171,7 @@ def delete_dataset(self, dataset_id: str) -> None:
dataset_id (str): The dataset ID for the app to delete.
"""
request = service_pb2.DeleteDatasetsRequest(
user_app_id=self.userDataObject, dataset_ids=[dataset_id])
user_app_id=self.user_app_id, dataset_ids=[dataset_id])
response = self._grpc_request(self.STUB.DeleteDatasets, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
@@ -173,7 +182,7 @@ def delete_model(self, model_id: str) -> None:
Args:
model_id (str): The model ID for the app to delete.
"""
request = service_pb2.DeleteModelsRequest(user_app_id=self.userDataObject, ids=[model_id])
request = service_pb2.DeleteModelsRequest(user_app_id=self.user_app_id, ids=[model_id])
response = self._grpc_request(self.STUB.DeleteModels, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
@@ -184,8 +193,7 @@ def delete_workflow(self, workflow_id: str) -> None:
Args:
workflow_id (str): The workflow ID for the app to delete.
"""
request = service_pb2.DeleteWorkflowsRequest(
user_app_id=self.userDataObject, ids=[workflow_id])
request = service_pb2.DeleteWorkflowsRequest(user_app_id=self.user_app_id, ids=[workflow_id])
response = self._grpc_request(self.STUB.DeleteWorkflows, request)
if response.status.code != status_code_pb2.SUCCESS:
raise Exception(response.status)
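One convention the app.py changes keep throughout: every helper raises `Exception(response.status)` whenever the gRPC call does not return `SUCCESS`. A minimal, hypothetical sketch of handling that at a call site (the `missing_dataset` ID is made up for illustration):

```python
# Sketch only: App methods raise a plain Exception wrapping the protobuf status
# whenever response.status.code != status_code_pb2.SUCCESS.
try:
    dataset = app.dataset(dataset_id="missing_dataset")
except Exception as err:
    print(f"GetDataset failed: {err}")
```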
15 changes: 9 additions & 6 deletions clarifai/client/base.py
@@ -1,8 +1,9 @@
from datetime import datetime
from typing import Any, Callable

from google.protobuf.json_format import MessageToDict
from google.protobuf.json_format import MessageToDict # noqa
from google.protobuf.timestamp_pb2 import Timestamp
from google.protobuf.wrappers_pb2 import BoolValue

from clarifai.client.auth import create_stub
from clarifai.client.auth.helper import ClarifaiAuthHelper
@@ -32,7 +33,7 @@ def __init__(self, **kwargs):
self.auth_helper = ClarifaiAuthHelper(**kwargs)
self.STUB = create_stub(self.auth_helper)
self.metadata = self.auth_helper.metadata
self.userDataObject = self.auth_helper.get_user_app_id_proto()
self.user_app_id = self.auth_helper.get_user_app_id_proto()
self.base = self.auth_helper.base

def _grpc_request(self, method: Callable, argument: Any):
@@ -46,7 +47,7 @@ def _grpc_request(self, method: Callable, argument: Any):

try:
res = method(argument)
MessageToDict(res)
# MessageToDict(res) TODO global debug logger
return res
except ApiError:
raise Exception("ApiError")
@@ -93,10 +94,12 @@ def convert_recursive(item):
if isinstance(item, dict):
new_item = {}
for key, value in item.items():
if key in ['createdAt', 'modifiedAt']:
if key in ['createdAt', 'modifiedAt', 'completedAt']:
value = self.convert_string_to_timestamp(value)
if key in ['metadata', 'workflowRecommended', 'modelVersion']:
continue # TODO Fix "app_duplication",modelVersion error
elif key in ['workflowRecommended']:
value = BoolValue(value=True)
elif key in ['metadata', 'fieldsMap']:
continue # TODO Fix "app_duplication",fieldsMap(text key) error
new_key = snake_case(key)
new_item[new_key] = convert_recursive(value)
return new_item
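For readers skimming the base.py hunk above, the key-conversion change amounts to: camelCase keys from `MessageToDict` become snake_case, the `createdAt`/`modifiedAt`/`completedAt` strings are re-parsed into protobuf `Timestamp`s, `workflowRecommended` is replaced with a `BoolValue`, and `metadata`/`fieldsMap` are skipped until the noted TODOs are resolved. The standalone sketch below illustrates that shape only; `snake_case` and `to_timestamp` here are simplified stand-ins, not the client's actual helpers.

```python
import re

from google.protobuf.timestamp_pb2 import Timestamp
from google.protobuf.wrappers_pb2 import BoolValue


def snake_case(key: str) -> str:
    """Simplified stand-in: camelCase -> snake_case."""
    return re.sub(r'(?<!^)(?=[A-Z])', '_', key).lower()


def to_timestamp(value: str) -> Timestamp:
    """Parse an RFC 3339 string (as emitted by MessageToDict) into a Timestamp."""
    ts = Timestamp()
    ts.FromJsonString(value)
    return ts


def convert_recursive(item):
    """Recursively snake_case dict keys and re-type the special fields."""
    if isinstance(item, dict):
        new_item = {}
        for key, value in item.items():
            if key in ('createdAt', 'modifiedAt', 'completedAt'):
                value = to_timestamp(value)
            elif key == 'workflowRecommended':
                value = BoolValue(value=True)  # mirrors the diff, which always sets True
            elif key in ('metadata', 'fieldsMap'):
                continue  # skipped, per the TODOs in the diff
            new_item[snake_case(key)] = convert_recursive(value)
        return new_item
    if isinstance(item, list):
        return [convert_recursive(elem) for elem in item]
    return item
```

Under this sketch, `{'createdAt': '2023-08-10T00:00:00Z', 'workflowRecommended': True}` becomes `{'created_at': <Timestamp>, 'workflow_recommended': BoolValue(value=True)}`.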