From 92ca1ee2e033721ff14d44f3b7eaa6962b7c934d Mon Sep 17 00:00:00 2001 From: sanjaychelliah Date: Fri, 4 Aug 2023 01:18:44 +0530 Subject: [PATCH] review_changes --- clarifai/client/app.py | 6 +++--- clarifai/client/dataset.py | 6 ++++-- clarifai/client/model.py | 5 +++-- clarifai/client/workflow.py | 6 ++++-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/clarifai/client/app.py b/clarifai/client/app.py index 3ea7027b..9ce140eb 100644 --- a/clarifai/client/app.py +++ b/clarifai/client/app.py @@ -115,7 +115,7 @@ def dataset(self, dataset_id: str, **kwargs) -> Dataset: if response.status.code != status_code_pb2.SUCCESS: raise Exception(response.status) - return Dataset(dataset_id=dataset_id, **kwargs) + return Dataset(dataset_id=dataset_id, dataset_info=response.dataset) def model(self, model_id: str, **kwargs) -> Model: """Returns a Model object for the existing model ID. @@ -129,7 +129,7 @@ def model(self, model_id: str, **kwargs) -> Model: if response.status.code != status_code_pb2.SUCCESS: raise Exception(response.status) - return Model(model_id=model_id, **kwargs) + return Model(model_id=model_id, model_info=response.model) def workflow(self, workflow_id: str, **kwargs) -> Workflow: """Returns a workflow object for the existing workflow ID. @@ -144,7 +144,7 @@ def workflow(self, workflow_id: str, **kwargs) -> Workflow: if response.status.code != status_code_pb2.SUCCESS: raise Exception(response.status) - return Workflow(workflow_id=workflow_id, **kwargs) + return Workflow(workflow_id=workflow_id, workflow_info=response.workflow) def delete_dataset(self, dataset_id: str) -> None: """Deletes an dataset for the user. diff --git a/clarifai/client/dataset.py b/clarifai/client/dataset.py index 3652e653..5caa7021 100644 --- a/clarifai/client/dataset.py +++ b/clarifai/client/dataset.py @@ -8,14 +8,16 @@ class Dataset(BaseClient): Inherits from BaseClient for authentication purposes. """ - def __init__(self, dataset_id: str, **kwargs): + def __init__(self, dataset_id: str, dataset_info: resources_pb2.Dataset = None, **kwargs): """Initializes an Dataset object. Args: dataset_id (str): The Dataset ID within the App to interact with. **kwargs: Additional keyword arguments to be passed to the ClarifaiAuthHelper. """ self.kwargs = {**kwargs, 'id': dataset_id} - self.dataset_info = resources_pb2.Dataset(**self.kwargs) + self.dataset_info = resources_pb2.Dataset( + **self.kwargs) if dataset_info is None else dataset_info + super().__init__(app_id=self.app_id) def __getattr__(self, name): return getattr(self.dataset_info, name) diff --git a/clarifai/client/model.py b/clarifai/client/model.py index 7334801a..7241531d 100644 --- a/clarifai/client/model.py +++ b/clarifai/client/model.py @@ -8,14 +8,15 @@ class Model(BaseClient): Inherits from BaseClient for authentication purposes. """ - def __init__(self, model_id: str, **kwargs): + def __init__(self, model_id: str, model_info: resources_pb2.Model = None, **kwargs): """Initializes an Model object. Args: model_id (str): The Model ID to interact with. **kwargs: Additional keyword arguments to be passed to the ClarifaiAuthHelper. """ self.kwargs = {**kwargs, 'id': model_id} - self.model_info = resources_pb2.Model(**self.kwargs) + self.model_info = resources_pb2.Model(**self.kwargs) if model_info is None else model_info + super().__init__(app_id=self.app_id) def __getattr__(self, name): return getattr(self.model_info, name) diff --git a/clarifai/client/workflow.py b/clarifai/client/workflow.py index 5bd321ea..9297c580 100644 --- a/clarifai/client/workflow.py +++ b/clarifai/client/workflow.py @@ -8,14 +8,16 @@ class Workflow(BaseClient): Inherits from BaseClient for authentication purposes. """ - def __init__(self, workflow_id: str, **kwargs): + def __init__(self, workflow_id: str, workflow_info: resources_pb2.Workflow = None, **kwargs): """Initializes an Workflow object. Args: workflow_id (str): The Workflow ID to interact with. **kwargs: Additional keyword arguments to be passed to the ClarifaiAuthHelper. """ self.kwargs = {**kwargs, 'id': workflow_id} - self.workflow_info = resources_pb2.Workflow(**self.kwargs) + self.workflow_info = resources_pb2.Workflow( + **self.kwargs) if workflow_info is None else workflow_info + super().__init__(app_id=self.app_id) def __getattr__(self, name): return getattr(self.workflow_info, name)