diff --git a/clarifai/runners/models/model_run_locally.py b/clarifai/runners/models/model_run_locally.py index 471c0fa..9adfc64 100644 --- a/clarifai/runners/models/model_run_locally.py +++ b/clarifai/runners/models/model_run_locally.py @@ -22,8 +22,6 @@ class ModelRunLocally: def __init__(self, model_path): self.model_path = model_path self.requirements_file = os.path.join(self.model_path, "requirements.txt") - self.venv_dir, self.temp_dir = self.create_temp_venv() - self.python_executable = os.path.join(self.venv_dir, "bin", "python") def create_temp_venv(self): """Create a temporary virtual environment.""" @@ -32,6 +30,10 @@ def create_temp_venv(self): venv_dir = os.path.join(temp_dir, "venv") venv.create(venv_dir, with_pip=True) + self.venv_dir = venv_dir + self.temp_dir = temp_dir + self.python_executable = os.path.join(venv_dir, "bin", "python") + logger.info(f"Created temporary virtual environment at {venv_dir}") return venv_dir, temp_dir @@ -125,7 +127,6 @@ def _run_test(self): nodepool_id="n/a", compute_cluster_id="n/a", ) - runner.load_model() # send an inference. response = self._run_model_inference(runner) @@ -182,6 +183,7 @@ def main(): model_path = args.model_path manager = ModelRunLocally(model_path) + manager.create_temp_venv() try: manager.install_requirements() diff --git a/clarifai/runners/models/model_upload.py b/clarifai/runners/models/model_upload.py index 4a7e083..066544e 100644 --- a/clarifai/runners/models/model_upload.py +++ b/clarifai/runners/models/model_upload.py @@ -202,14 +202,15 @@ def tar_file(self): def download_checkpoints(self): repo_id, hf_token = self._validate_config_checkpoints() - if repo_id and hf_token: - loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token) - success = loader.download_checkpoints(self.checkpoint_path) - if not success: - logger.error(f"Failed to download checkpoints for model {repo_id}") - return + loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token) + success = loader.download_checkpoints(self.checkpoint_path) + + if not success: + logger.error(f"Failed to download checkpoints for model {repo_id}") + else: logger.info(f"Downloaded checkpoints for model {repo_id}") + return success def _concepts_protos_from_concepts(self, concepts): concept_protos = [] @@ -245,15 +246,23 @@ def get_model_version_proto(self): model_type_id = self.config.get('model').get('model_type_id') if model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE: - labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path) - # sort the concepts by id and then update the config file - labels = sorted(labels.items(), key=lambda x: int(x[0])) + if 'concepts' in self.config: + labels = self.config.get('concepts') + logger.info(f"Found {len(labels)} concepts in the config file.") + for concept in labels: + concept_proto = json_format.ParseDict(concept, resources_pb2.Concept()) + model_version_proto.output_info.data.concepts.append(concept_proto) + else: + labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path) + logger.info(f"Found {len(labels)} concepts from the model checkpoints.") + # sort the concepts by id and then update the config file + labels = sorted(labels.items(), key=lambda x: int(x[0])) - config_file = os.path.join(self.folder, 'config.yaml') - self.hf_labels_to_config(labels, config_file) + config_file = os.path.join(self.folder, 'config.yaml') + self.hf_labels_to_config(labels, config_file) - model_version_proto.output_info.data.concepts.extend( - self._concepts_protos_from_concepts(labels)) + model_version_proto.output_info.data.concepts.extend( + self._concepts_protos_from_concepts(labels)) return model_version_proto def upload_model_version(self, download_checkpoints):