Fix download_checkpoints and fix run model locally (#415)
* fix download_checkpoints and run model locally

* fix concepts protos during model upload
luv-bansal authored Oct 9, 2024
1 parent 0d7c3b7 commit 6f56f51
Showing 2 changed files with 27 additions and 16 deletions.
8 changes: 5 additions & 3 deletions clarifai/runners/models/model_run_locally.py
@@ -22,8 +22,6 @@ class ModelRunLocally:
   def __init__(self, model_path):
     self.model_path = model_path
     self.requirements_file = os.path.join(self.model_path, "requirements.txt")
-    self.venv_dir, self.temp_dir = self.create_temp_venv()
-    self.python_executable = os.path.join(self.venv_dir, "bin", "python")
 
   def create_temp_venv(self):
     """Create a temporary virtual environment."""
@@ -32,6 +30,10 @@ def create_temp_venv(self):
     venv_dir = os.path.join(temp_dir, "venv")
     venv.create(venv_dir, with_pip=True)
 
+    self.venv_dir = venv_dir
+    self.temp_dir = temp_dir
+    self.python_executable = os.path.join(venv_dir, "bin", "python")
+
     logger.info(f"Created temporary virtual environment at {venv_dir}")
     return venv_dir, temp_dir
 
@@ -125,7 +127,6 @@ def _run_test(self):
         nodepool_id="n/a",
         compute_cluster_id="n/a",
     )
-    runner.load_model()
 
     # send an inference.
     response = self._run_model_inference(runner)
@@ -182,6 +183,7 @@ def main():
 
   model_path = args.model_path
   manager = ModelRunLocally(model_path)
+  manager.create_temp_venv()
 
   try:
     manager.install_requirements()
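For context, a minimal sketch of how the refactored flow is driven after this change: the virtual environment is now created explicitly after construction instead of inside __init__. The test_model() and clean_up() helpers are assumed names from the surrounding main(), which this diff does not show:

from clarifai.runners.models.model_run_locally import ModelRunLocally

manager = ModelRunLocally("/path/to/local/model")  # placeholder path
manager.create_temp_venv()  # now also sets venv_dir, temp_dir, python_executable
try:
  manager.install_requirements()
  manager.test_model()  # assumed helper, not shown in this diff
finally:
  manager.clean_up()  # assumed helper: removes the temporary venv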
35 changes: 22 additions & 13 deletions clarifai/runners/models/model_upload.py
@@ -202,14 +202,15 @@ def tar_file(self):
 
   def download_checkpoints(self):
     repo_id, hf_token = self._validate_config_checkpoints()
-    if repo_id and hf_token:
-      loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
-      success = loader.download_checkpoints(self.checkpoint_path)
-
-      if not success:
-        logger.error(f"Failed to download checkpoints for model {repo_id}")
-        return
+    loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
+    success = loader.download_checkpoints(self.checkpoint_path)
+
+    if not success:
+      logger.error(f"Failed to download checkpoints for model {repo_id}")
+    else:
+      logger.info(f"Downloaded checkpoints for model {repo_id}")
+    return success
 
   def _concepts_protos_from_concepts(self, concepts):
     concept_protos = []
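As an aside, download_checkpoints() now reports success instead of returning None early, so a caller can stop on failure. A hedged sketch of such a call site; the ModelUploader class name and the folder path are assumptions, not shown in this hunk:

uploader = ModelUploader("/path/to/model/folder")  # assumed constructor
if not uploader.download_checkpoints():
  raise RuntimeError("Checkpoint download failed; aborting model upload.")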
@@ -245,15 +246,23 @@ def get_model_version_proto(self):
     model_type_id = self.config.get('model').get('model_type_id')
     if model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE:
 
-      labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path)
-      # sort the concepts by id and then update the config file
-      labels = sorted(labels.items(), key=lambda x: int(x[0]))
+      if 'concepts' in self.config:
+        labels = self.config.get('concepts')
+        logger.info(f"Found {len(labels)} concepts in the config file.")
+        for concept in labels:
+          concept_proto = json_format.ParseDict(concept, resources_pb2.Concept())
+          model_version_proto.output_info.data.concepts.append(concept_proto)
+      else:
+        labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path)
+        logger.info(f"Found {len(labels)} concepts from the model checkpoints.")
+        # sort the concepts by id and then update the config file
+        labels = sorted(labels.items(), key=lambda x: int(x[0]))
 
-      config_file = os.path.join(self.folder, 'config.yaml')
-      self.hf_labels_to_config(labels, config_file)
+        config_file = os.path.join(self.folder, 'config.yaml')
+        self.hf_labels_to_config(labels, config_file)
 
-      model_version_proto.output_info.data.concepts.extend(
-          self._concepts_protos_from_concepts(labels))
+        model_version_proto.output_info.data.concepts.extend(
+            self._concepts_protos_from_concepts(labels))
     return model_version_proto
 
   def upload_model_version(self, download_checkpoints):
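The new 'concepts' branch parses each entry from the config directly into a Concept proto with json_format.ParseDict. A minimal sketch of the shape that path expects; the ids and names below are illustrative only:

from clarifai_grpc.grpc.api import resources_pb2
from google.protobuf import json_format

# Illustrative concept entries, as they might appear under `concepts:`
# in config.yaml once loaded; these values are made up.
labels = [
    {"id": "id-cat", "name": "cat"},
    {"id": "id-dog", "name": "dog"},
]
concept_protos = [json_format.ParseDict(c, resources_pb2.Concept()) for c in labels]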
