Skip to content

Commit

Permalink
Improve experience for cv models (#408)
Browse files Browse the repository at this point in the history
* improve experience for cv models

* added checkpoint loader checkpoint validate and improve classes
  • Loading branch information
luv-bansal authored Oct 8, 2024
1 parent b2d6c36 commit 8469c39
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
67 changes: 42 additions & 25 deletions clarifai/runners/models/model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rich import print

from clarifai.client import BaseClient
from clarifai.runners.utils.loader import HuggingFaceLoarder
from clarifai.runners.utils.loader import HuggingFaceLoader
from clarifai.urls.helper import ClarifaiUrlHelper
from clarifai.utils.logging import logger

Expand Down Expand Up @@ -59,6 +59,28 @@ def _load_config(config_file: str):
config = yaml.safe_load(file)
return config

def _validate_config_checkpoints(self):
    """Validate the ``checkpoints`` section of the config and extract its settings.

    This is a regular instance method (not a ``@staticmethod``): it reads
    ``self.config`` and is invoked as ``self._validate_config_checkpoints()``.

    Returns:
        A ``(repo_id, hf_token)`` tuple. Both entries are ``None`` when the
        config has no ``checkpoints`` section, so callers can always unpack
        two values. ``hf_token`` is ``None`` when neither the ``HF_TOKEN``
        environment variable nor the config provides one.

    Raises:
        AssertionError: if the ``checkpoints`` section has no ``type`` key,
            if the type is not ``"huggingface"``, or if ``repo_id`` is missing.
    """
    checkpoints = self.config.get("checkpoints")
    if not checkpoints:
        logger.info("No checkpoints specified in the config file")
        # Return a pair (not a bare None) because callers unpack the result.
        return None, None

    assert "type" in checkpoints, "No loader type specified in the config file"
    loader_type = checkpoints.get("type")
    if not loader_type:
        logger.info("No loader type specified in the config file for checkpoints")
    # Past this assert the loader type is guaranteed to be "huggingface".
    assert loader_type == "huggingface", "Only huggingface loader supported for now"

    assert "repo_id" in checkpoints, "No repo_id specified in the config file"
    repo_id = checkpoints.get("repo_id")

    # prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any.
    if 'HF_TOKEN' in os.environ:
        hf_token = os.environ['HF_TOKEN']
    else:
        hf_token = checkpoints.get("hf_token", None)
    return repo_id, hf_token

@property
def client(self):
if self._client is None:
Expand Down Expand Up @@ -180,26 +202,9 @@ def tar_file(self):
return f"{self.folder}.tar.gz"

def download_checkpoints(self):
if not self.config.get("checkpoints"):
logger.info("No checkpoints specified in the config file")
return

assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
loader_type = self.config.get("checkpoints").get("type")
if not loader_type:
logger.info("No loader type specified in the config file for checkpoints")
assert loader_type == "huggingface", "Only huggingface loader supported for now"
if loader_type == "huggingface":
assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
repo_id = self.config.get("checkpoints").get("repo_id")

# prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any.
if 'HF_TOKEN' in os.environ:
hf_token = os.environ['HF_TOKEN']
else:
hf_token = self.config.get("checkpoints").get("hf_token", None)
assert hf_token != 'hf_token', "The default 'hf_token' is not valid. Please provide a valid token or leave that field out of config.yaml if not needed."
loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token)
repo_id, hf_token = self._validate_config_checkpoints()
if repo_id and hf_token:
loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)

success = loader.download_checkpoints(self.checkpoint_path)

Expand Down Expand Up @@ -242,8 +247,7 @@ def _get_model_version_proto(self):
model_type_id = self.config.get('model').get('model_type_id')
if model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE:

loader = HuggingFaceLoarder()
labels = loader.fetch_labels(self.checkpoint_path)
labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path)
# sort the concepts by id and then update the config file
labels = sorted(labels.items(), key=lambda x: int(x[0]))

Expand All @@ -258,6 +262,21 @@ def upload_model_version(self, download_checkpoints):
file_path = f"{self.folder}.tar.gz"
logger.info(f"Will tar it into file: {file_path}")

model_type_id = self.config.get('model').get('model_type_id')
repo_id, hf_token = self._validate_config_checkpoints()

loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)

if not download_checkpoints and not loader.validate_download(self.checkpoint_path) and (
model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
logger.error(
f"Model type {model_type_id} requires concepts to be specified in the config file or download the model checkpoints to infer the concepts."
)
input("Press Enter to download the checkpoints to infer the concepts and continue...")
self.download_checkpoints()

model_version_proto = self._get_model_version_proto()

if download_checkpoints:
tar_cmd = f"tar --exclude=*~ -czvf {self.tar_file} -C {self.folder} ."
else: # we don't want to send the checkpoints up even if they are in the folder.
Expand All @@ -268,8 +287,6 @@ def upload_model_version(self, download_checkpoints):
os.system(tar_cmd)
logger.info("Tarring complete, about to start upload.")

model_version_proto = self._get_model_version_proto()

file_size = os.path.getsize(self.tar_file)
logger.info(f"Size of the tar is: {file_size} bytes")

Expand Down
5 changes: 3 additions & 2 deletions clarifai/runners/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from clarifai.utils.logging import logger


class HuggingFaceLoarder:
class HuggingFaceLoader:

def __init__(self, repo_id=None, token=None):
self.repo_id = repo_id
Expand Down Expand Up @@ -67,7 +67,8 @@ def validate_download(self, checkpoint_path: str):
return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len(
list_repo_files(self.repo_id)) > 0

def fetch_labels(self, checkpoint_path: str):
@staticmethod
def fetch_labels(checkpoint_path: str):
# Fetch labels for classification, detection and segmentation models
config_path = os.path.join(checkpoint_path, 'config.json')
with open(config_path, 'r') as f:
Expand Down

0 comments on commit 8469c39

Please sign in to comment.