Skip to content

Commit

Permalink
Improve UX for model upload (#420)
Browse files Browse the repository at this point in the history
* improve UX for model upload

* improve logging for cv models

* improve logging for cv models

* updated cuda dockerfile

* fix runner tests

---------

Co-authored-by: Sai Nivedh <[email protected]>
  • Loading branch information
luv-bansal and sainivedh authored Oct 14, 2024
1 parent 67cf227 commit 341d26b
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ ENV CLARIFAI_API_BASE=${CLARIFAI_API_BASE}

# Set the NUMBA cache dir to /tmp
ENV NUMBA_CACHE_DIR=/tmp/numba_cache
ENV HOME=/tmp

# Set the working directory to /app
WORKDIR /app
Expand Down
32 changes: 23 additions & 9 deletions clarifai/runners/models/model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,17 +271,31 @@ def upload_model_version(self, download_checkpoints):
logger.info(f"Will tar it into file: {file_path}")

model_type_id = self.config.get('model').get('model_type_id')
repo_id, hf_token = self._validate_config_checkpoints()

loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)

if not download_checkpoints and not loader.validate_download(self.checkpoint_path) and (
model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
logger.error(
f"Model type {model_type_id} requires concepts to be specified in the config file or download the model checkpoints to infer the concepts."
if (model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
logger.info(
f"Model type {model_type_id} requires concepts to be specified in the config.yaml file.."
)
input("Press Enter to download the checkpoints to infer the concepts and continue...")
self.download_checkpoints()
if self.config.get("checkpoints"):
logger.info(
"Checkpoints specified in the config.yaml file, will download the HF model's config.json file to infer the concepts."
)

if not download_checkpoints and not HuggingFaceLoader.validate_config(
self.checkpoint_path):

input(
"Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..."
)
repo_id, hf_token = self._validate_config_checkpoints()
loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
loader.download_config(self.checkpoint_path)

else:
logger.error(
"No checkpoints specified in the config.yaml file to infer the concepts. Please either specify the concepts directly in the config.yaml file or include a checkpoints section to download the HF model's config.json file to infer the concepts."
)
return

model_version_proto = self.get_model_version_proto()

Expand Down
38 changes: 37 additions & 1 deletion clarifai/runners/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,17 @@ def __init__(self, repo_id=None, token=None):
if importlib.util.find_spec("huggingface_hub") is None:
raise ImportError(self.HF_DOWNLOAD_TEXT)
os.environ['HF_TOKEN'] = token
from huggingface_hub import HfApi

api = HfApi()
api.whoami(token=token)

subprocess.run(f'huggingface-cli login --token={os.environ["HF_TOKEN"]}', shell=True)
except Exception as e:
Exception("Error setting up Hugging Face token ", e)
logger.error(
f"Error setting up Hugging Face token, please make sure you have the correct token: {e}"
)
logger.info("Continuing without Hugging Face token")

def download_checkpoints(self, checkpoint_path: str):
# throw error if huggingface_hub wasn't installed
Expand Down Expand Up @@ -50,6 +58,28 @@ def download_checkpoints(self, checkpoint_path: str):
return False
return True

def download_config(self, checkpoint_path: str):
  """Download only the HF repo's config.json into ``checkpoint_path``.

  Used to infer model concepts without pulling the full checkpoint files.

  Args:
    checkpoint_path: Local directory that should hold config.json.

  Returns:
    bool: True if config.json is present (already on disk or freshly
    downloaded), False on any failure (repo missing/inaccessible, or
    download error).

  Raises:
    ImportError: If the ``huggingface_hub`` package is not installed.
  """
  # Fail fast with a helpful install message if huggingface_hub is missing.
  try:
    from huggingface_hub import hf_hub_download
  except ImportError:
    raise ImportError(self.HF_DOWNLOAD_TEXT)
  # Nothing to do if the config file was already downloaded.
  if os.path.exists(checkpoint_path) and os.path.exists(
      os.path.join(checkpoint_path, 'config.json')):
    logger.info("HF model's config.json already exists")
    return True
  os.makedirs(checkpoint_path, exist_ok=True)
  try:
    # Verify the repo exists and is accessible before attempting a download.
    if not self.validate_hf_model():
      logger.error(f"Model {self.repo_id} not found on Hugging Face")
      return False
    hf_hub_download(repo_id=self.repo_id, filename='config.json', local_dir=checkpoint_path)
  except Exception as e:
    logger.error(f"Error downloading model's config.json {e}")
    return False
  return True

def validate_hf_model(self,):
# check if model exists on HF
try:
Expand All @@ -70,6 +100,12 @@ def validate_download(self, checkpoint_path: str):
return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len(
list_repo_files(self.repo_id)) > 0

@staticmethod
def validate_config(checkpoint_path: str):
# check if downloaded config.json exists
return os.path.exists(checkpoint_path) and os.path.exists(
os.path.join(checkpoint_path, 'config.json'))

@staticmethod
def fetch_labels(checkpoint_path: str):
# Fetch labels for classification, detection and segmentation models
Expand Down
165 changes: 68 additions & 97 deletions tests/runners/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def init_components(
client: BaseClient,
app_id,
model_id,
runner_id,
nodepool_id,
compute_cluster_id,
):
Expand All @@ -41,96 +40,72 @@ def init_components(
new_model = model.create_version()

new_model_version = new_model.model_version.id
try:
compute_cluster_delete_request = service_pb2.DeleteComputeClustersRequest(
user_app_id=client.user_app_id,
ids=[compute_cluster_id],
)
client.STUB.DeleteComputeClusters(compute_cluster_delete_request)

nodepool_delete_request = service_pb2.DeleteNodepoolsRequest(
user_app_id=client.user_app_id, compute_cluster_id=compute_cluster_id, ids=[nodepool_id])
client.STUB.DeleteNodepools(nodepool_delete_request)

runner_delete_request = service_pb2.DeleteRunnersRequest(
user_app_id=client.user_app_id,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
ids=[runner_id],
)
client.STUB.DeleteRunners(runner_delete_request)
except Exception as _:
pass
finally:
compute_cluster = resources_pb2.ComputeCluster(
id=compute_cluster_id,
description="test runners repo",
cloud_provider=resources_pb2.CloudProvider(id="local", name="Colo 1"),
region="us-east-1",
user_id=auth.user_id,
cluster_type="local-dev",
managed_by="user",
key=resources_pb2.Key(id=os.environ["CLARIFAI_PAT"]),
)
compute_cluster_request = service_pb2.PostComputeClustersRequest(
user_app_id=client.user_app_id,
compute_clusters=[compute_cluster],
)
res = client.STUB.PostComputeClusters(compute_cluster_request)
if res.status.code != status_code_pb2.SUCCESS:
logger.error(json_format.MessageToDict(res, preserving_proto_field_name=True))
raise Exception(res.status)

nodepool = resources_pb2.Nodepool(
id=nodepool_id,
description="test runners repo",
compute_cluster=compute_cluster,
node_capacity_type=resources_pb2.NodeCapacityType(capacity_types=[1, 2]),
instance_types=[
resources_pb2.InstanceType(
id='instance-1',
compute_info=resources_pb2.ComputeInfo(
cpu_limit="1",
cpu_memory="8Gi",
num_accelerators=0,
),
)
],
max_instances=1,
)
nodepools_request = service_pb2.PostNodepoolsRequest(
user_app_id=client.user_app_id,
compute_cluster_id=compute_cluster_id,
nodepools=[nodepool])
res = client.STUB.PostNodepools(nodepools_request)
if res.status.code != status_code_pb2.SUCCESS:
logger.error(json_format.MessageToDict(res, preserving_proto_field_name=True))
raise Exception(res.status)

runner = resources_pb2.Runner(
id=runner_id,
description="test runners repo",
worker=resources_pb2.Worker(model=resources_pb2.Model(
id=model_id,
user_id=auth.user_id,
app_id=app_id,
model_version=resources_pb2.ModelVersion(id=new_model_version),
)),
num_replicas=1,
nodepool=nodepool,
)
runners_request = service_pb2.PostRunnersRequest(
user_app_id=client.user_app_id,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
runners=[runner],
)
res = client.STUB.PostRunners(runners_request)
if res.status.code != status_code_pb2.SUCCESS:
logger.error(json_format.MessageToDict(res, preserving_proto_field_name=True))
raise Exception(res.status)

return new_model_version
compute_cluster = resources_pb2.ComputeCluster(
id=compute_cluster_id,
description="test runners repo",
cloud_provider=resources_pb2.CloudProvider(id="local", name="Colo 1"),
region="us-east-1",
user_id=auth.user_id,
cluster_type="local-dev",
managed_by="user",
key=resources_pb2.Key(id=os.environ["CLARIFAI_PAT"]),
)
compute_cluster_request = service_pb2.PostComputeClustersRequest(
user_app_id=client.user_app_id,
compute_clusters=[compute_cluster],
)
res = client.STUB.PostComputeClusters(compute_cluster_request)
if res.status.code != status_code_pb2.SUCCESS:
logger.error(json_format.MessageToDict(res, preserving_proto_field_name=True))
raise Exception(res.status)

nodepool = resources_pb2.Nodepool(
id=nodepool_id,
description="test runners repo",
compute_cluster=compute_cluster,
node_capacity_type=resources_pb2.NodeCapacityType(capacity_types=[1, 2]),
instance_types=[
resources_pb2.InstanceType(
id='instance-1',
compute_info=resources_pb2.ComputeInfo(
cpu_limit="1",
cpu_memory="8Gi",
num_accelerators=0,
),
)
],
max_instances=1,
)
nodepools_request = service_pb2.PostNodepoolsRequest(
user_app_id=client.user_app_id, compute_cluster_id=compute_cluster_id, nodepools=[nodepool])
res = client.STUB.PostNodepools(nodepools_request)
if res.status.code != status_code_pb2.SUCCESS:
logger.error(json_format.MessageToDict(res, preserving_proto_field_name=True))
raise Exception(res.status)

runner = resources_pb2.Runner(
description="test runners repo",
worker=resources_pb2.Worker(model=resources_pb2.Model(
id=model_id,
user_id=auth.user_id,
app_id=app_id,
model_version=resources_pb2.ModelVersion(id=new_model_version),
)),
num_replicas=1,
nodepool=nodepool,
)
runners_request = service_pb2.PostRunnersRequest(
user_app_id=client.user_app_id,
compute_cluster_id=compute_cluster_id,
nodepool_id=nodepool_id,
runners=[runner],
)
res = client.STUB.PostRunners(runners_request)
if res.status.code != status_code_pb2.SUCCESS:
logger.error(json_format.MessageToDict(res, preserving_proto_field_name=True))
raise Exception(res.status)

return new_model_version, res.runners[0].id


@pytest.mark.requires_secrets
Expand All @@ -140,7 +115,6 @@ class TestRunnerServer:
def setup_class(cls):
NOW = uuid.uuid4().hex[:10]
cls.MODEL_ID = f"test-runner-model-{NOW}"
cls.RUNNER_ID = f"test-runner-{NOW}"
cls.NODEPOOL_ID = f"test-nodepool-{NOW}"
cls.COMPUTE_CLUSTER_ID = f"test-compute_cluster-{NOW}"
cls.APP_ID = f"ci-test-runner-app-{NOW}"
Expand All @@ -151,12 +125,11 @@ def setup_class(cls):
cls.logger = logger
cls.logger.info("Starting runner server")

cls.MODEL_VERSION_ID = init_components(
cls.MODEL_VERSION_ID, cls.RUNNER_ID = init_components(
cls.AUTH,
cls.CLIENT,
cls.APP_ID,
cls.MODEL_ID,
cls.RUNNER_ID,
cls.NODEPOOL_ID,
cls.COMPUTE_CLUSTER_ID,
)
Expand Down Expand Up @@ -267,20 +240,18 @@ class TestWrapperRunnerServer(TestRunnerServer):
def setup_class(cls):
NOW = uuid.uuid4().hex[:10]
cls.MODEL_ID = f"test-runner-model-{NOW}"
cls.RUNNER_ID = f"test-runner-{NOW}"
cls.NODEPOOL_ID = f"test-nodepool-{NOW}"
cls.COMPUTE_CLUSTER_ID = f"test-compute_cluster-{NOW}"
cls.APP_ID = f"ci-test-runner-app-{NOW}"
cls.CLIENT = BaseClient.from_env()
cls.AUTH = cls.CLIENT.auth_helper
cls.AUTH.app_id = cls.APP_ID

cls.MODEL_VERSION_ID = init_components(
cls.MODEL_VERSION_ID, cls.RUNNER_ID = init_components(
cls.AUTH,
cls.CLIENT,
cls.APP_ID,
cls.MODEL_ID,
cls.RUNNER_ID,
cls.NODEPOOL_ID,
cls.COMPUTE_CLUSTER_ID,
)
Expand Down

0 comments on commit 341d26b

Please sign in to comment.