From 8469c394b708e13770d07f821530512c3d461acc Mon Sep 17 00:00:00 2001 From: Luv Bansal <70321430+luv-bansal@users.noreply.github.com> Date: Tue, 8 Oct 2024 22:19:39 +0530 Subject: [PATCH] Improve experience for cv models (#408) * improve experience for cv models * added checkpoint loader checkpoint validate and improve classes --- clarifai/runners/models/model_upload.py | 67 ++++++++++++++++--------- clarifai/runners/utils/loader.py | 5 +- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/clarifai/runners/models/model_upload.py b/clarifai/runners/models/model_upload.py index 1ca46da..573173f 100644 --- a/clarifai/runners/models/model_upload.py +++ b/clarifai/runners/models/model_upload.py @@ -10,7 +10,7 @@ from rich import print from clarifai.client import BaseClient -from clarifai.runners.utils.loader import HuggingFaceLoarder +from clarifai.runners.utils.loader import HuggingFaceLoader from clarifai.urls.helper import ClarifaiUrlHelper from clarifai.utils.logging import logger @@ -59,6 +59,28 @@ def _load_config(config_file: str): config = yaml.safe_load(file) return config + # NOTE: instance method on purpose — it reads self.config; marking it @staticmethod would make self._validate_config_checkpoints() raise TypeError. + def _validate_config_checkpoints(self): + if not self.config.get("checkpoints"): + logger.info("No checkpoints specified in the config file") + return None, None + + assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file" + loader_type = self.config.get("checkpoints").get("type") + if not loader_type: + logger.info("No loader type specified in the config file for checkpoints") + assert loader_type == "huggingface", "Only huggingface loader supported for now" + if loader_type == "huggingface": + assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file" + repo_id = self.config.get("checkpoints").get("repo_id") + + # prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any. 
+ if 'HF_TOKEN' in os.environ: + hf_token = os.environ['HF_TOKEN'] + else: + hf_token = self.config.get("checkpoints").get("hf_token", None) + return repo_id, hf_token + @property def client(self): if self._client is None: @@ -180,26 +202,9 @@ def tar_file(self): return f"{self.folder}.tar.gz" def download_checkpoints(self): - if not self.config.get("checkpoints"): - logger.info("No checkpoints specified in the config file") - return - - assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file" - loader_type = self.config.get("checkpoints").get("type") - if not loader_type: - logger.info("No loader type specified in the config file for checkpoints") - assert loader_type == "huggingface", "Only huggingface loader supported for now" - if loader_type == "huggingface": - assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file" - repo_id = self.config.get("checkpoints").get("repo_id") - - # prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any. - if 'HF_TOKEN' in os.environ: - hf_token = os.environ['HF_TOKEN'] - else: - hf_token = self.config.get("checkpoints").get("hf_token", None) - assert hf_token != 'hf_token', "The default 'hf_token' is not valid. Please provide a valid token or leave that field out of config.yaml if not needed." 
- loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token) + repo_id, hf_token = self._validate_config_checkpoints() + if repo_id: + loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token) success = loader.download_checkpoints(self.checkpoint_path) @@ -242,8 +247,7 @@ def _get_model_version_proto(self): model_type_id = self.config.get('model').get('model_type_id') if model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE: - loader = HuggingFaceLoarder() - labels = loader.fetch_labels(self.checkpoint_path) + labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path) # sort the concepts by id and then update the config file labels = sorted(labels.items(), key=lambda x: int(x[0])) @@ -258,6 +262,21 @@ def upload_model_version(self, download_checkpoints): file_path = f"{self.folder}.tar.gz" logger.info(f"Will tar it into file: {file_path}") + model_type_id = self.config.get('model').get('model_type_id') + repo_id, hf_token = self._validate_config_checkpoints() + + loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token) + + if repo_id and not download_checkpoints and not loader.validate_download(self.checkpoint_path) and ( + model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config: + logger.error( + f"Model type {model_type_id} requires concepts to be specified in the config file or download the model checkpoints to infer the concepts." + ) + input("Press Enter to download the checkpoints to infer the concepts and continue...") + self.download_checkpoints() + + model_version_proto = self._get_model_version_proto() + if download_checkpoints: tar_cmd = f"tar --exclude=*~ -czvf {self.tar_file} -C {self.folder} ." else: # we don't want to send the checkpoints up even if they are in the folder. 
@@ -268,8 +287,6 @@ def upload_model_version(self, download_checkpoints): os.system(tar_cmd) logger.info("Tarring complete, about to start upload.") - model_version_proto = self._get_model_version_proto() - file_size = os.path.getsize(self.tar_file) logger.info(f"Size of the tar is: {file_size} bytes") diff --git a/clarifai/runners/utils/loader.py b/clarifai/runners/utils/loader.py index 050732f..83fbe9f 100644 --- a/clarifai/runners/utils/loader.py +++ b/clarifai/runners/utils/loader.py @@ -6,7 +6,7 @@ from clarifai.utils.logging import logger -class HuggingFaceLoarder: +class HuggingFaceLoader: def __init__(self, repo_id=None, token=None): self.repo_id = repo_id @@ -67,7 +67,8 @@ def validate_download(self, checkpoint_path: str): return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len( list_repo_files(self.repo_id)) > 0 - def fetch_labels(self, checkpoint_path: str): + @staticmethod + def fetch_labels(checkpoint_path: str): # Fetch labels for classification, detection and segmentation models config_path = os.path.join(checkpoint_path, 'config.json') with open(config_path, 'r') as f: