Skip to content

Commit

Permalink
Add pipeline_tag to model card (#1287)
Browse files Browse the repository at this point in the history
  • Loading branch information
davanstrien authored May 23, 2023
1 parent 74e6484 commit 307a15f
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion bertopic/_save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
tags:
- bertopic
library_name: bertopic
pipeline_tag: {PIPELINE_TAG}
---
# {MODEL_NAME}
Expand Down Expand Up @@ -284,6 +285,13 @@ def generate_readme(model, repo_id: str):
model_card = model_card.replace("{HYPERPARAMS}", params)
model_card = model_card.replace("{FRAMEWORKS}", frameworks)

# Fill Pipeline tag
has_visual_aspect = check_has_visual_aspect(model)
if not has_visual_aspect:
model_card = model_card.replace("{PIPELINE_TAG}", "text-classification")
else:
model_card = model_card.replace("pipeline_tag: {PIPELINE_TAG} /n","") # TODO add proper tag for this instance

return model_card


Expand Down Expand Up @@ -363,6 +371,13 @@ def save_config(model, path: str, embedding_model):

return config

def check_has_visual_aspect(model):
"""Check if model has visual aspect"""
if _has_vision:
for aspect, value in model.topic_aspects_.items():
if isinstance(value[0], Image.Image):
visual_aspects = model.topic_aspects_[aspect]
return True

def save_images(model, path: str):
""" Save topic images """
Expand Down Expand Up @@ -470,4 +485,4 @@ def save_safetensors(path, tensors):
import safetensors
safetensors.torch.save_file(tensors, path)
except ImportError:
raise ValueError("`pip install safetensors` to save as .safetensors")
raise ValueError("`pip install safetensors` to save as .safetensors")

0 comments on commit 307a15f

Please sign in to comment.