Skip to content

Commit

Permalink
allow non-HF models to be pushed to the DB
Browse files Browse the repository at this point in the history
  • Loading branch information
jmercat committed Jan 9, 2025
1 parent 04bf2e6 commit 611c713
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 20 deletions.
87 changes: 78 additions & 9 deletions database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@
from uuid import UUID
import uuid
from contextlib import contextmanager
import openai


def get_full_openai_model_name(alias):
try:
# Make a simple request using the alias
response = openai.chat.completions.create(
model=alias, messages=[{"role": "system", "content": "Identify the model name."}], max_tokens=1
)
# Extract and return the full model name from the response
return response.model
except Exception as e:
return f"An error occurred: {str(e)}"


def create_db_engine() -> Tuple[Engine, sessionmaker]:
Expand Down Expand Up @@ -254,31 +267,40 @@ def get_model_from_db(id: "UUID") -> Model:
return model_db_obj.to_dict()


def get_or_add_model_by_name(hf_model: str):
def get_or_add_model_by_name(model: str, model_source: str = "hf"):
"""
Given hf_model path, return UUID of hf_model.
Given model path, return UUID of model.
Checks for existence by using git commit hash.
If doesn't exist in DB, create an entry and return UUID of entry.
If there exists more than one entry in DB, return UUID of latest model by last_modified.
Args:
hf_model (str): The path or identifier for the Hugging Face model.
model (str): The path or identifier for the Hugging Face or other model.
"""
git_commit_hash = HfApi().model_info(hf_model).sha
if model_source == "hf":
git_commit_hash = HfApi().model_info(model).sha
else:
if "openai" in model_source:
model = get_full_openai_model_name(model)
git_commit_hash = model + "_" + datetime.now(timezone.utc).strftime("%Y-%m-%d-%H-%M-%S")

with session_scope() as session:
model_instances = (
session.query(Model)
.filter(Model.weights_location == hf_model)
.filter(Model.weights_location == model)
.filter(Model.git_commit_hash == git_commit_hash)
.all()
)
model_instances = [i.to_dict() for i in model_instances]

if len(model_instances) == 0:
print(f"{hf_model} doesn't exist in database. Creating entry:")
return register_hf_model_to_db(hf_model)
if len(model_instances) == 0 and model_source == "hf":
print(f"{model} doesn't exist in database. Creating entry:")
return register_hf_model_to_db(model)
elif len(model_instances) == 0:
print(f"{model} doesn't exist in database. Creating entry:")
return register_model_to_db(model, model_source)
elif len(model_instances) > 1:
print(f"WARNING: Model {hf_model} has multiple entries in DB. Returning latest match.")
print(f"WARNING: Model {model} has multiple entries in DB. Returning latest match.")
model_instances = sorted(model_instances, key=lambda x: (x["last_modified"] is not None, x["last_modified"]))
for i in model_instances:
print(f"id: {i['id']}, git_commit_hash: {i['git_commit_hash']}")
Expand Down Expand Up @@ -354,3 +376,50 @@ def register_hf_model_to_db(hf_model: str, force: bool = False):
print(f"Model successfully registered to db! {model}")

return id


def register_model_to_db(model_name: str, model_source: str) -> UUID:
"""
Registers a new model to the database for non-HuggingFace models.
Args:
model_name (str): The name or identifier for the model
model_source (str): Source of the model (e.g., 'openai-chat-completions', 'anthropic')
Returns:
UUID: The unique identifier assigned to the registered model
Raises:
ValueError: If the model cannot be registered due to missing metadata
"""
id = uuid.uuid4()
creation_time = datetime.now(timezone.utc)

# Create a unique git_commit_hash-like identifier using timestamp
git_commit_hash = f"{model_name}_{creation_time.strftime('%Y-%m-%d-%H-%M-%S')}"

with session_scope() as session:
model = Model(
id=id,
name=model_name,
base_model_id=id,
created_by=model_source,
creation_location=model_source,
creation_time=creation_time,
training_start=creation_time,
training_end=creation_time,
training_parameters=None,
training_status=None,
dataset_id=None,
is_external=True,
weights_location=model_name,
wandb_link=None,
git_commit_hash=git_commit_hash,
last_modified=creation_time,
)

session.add(model)
session.commit()
print(f"Model successfully registered to db! {model}")

return id
11 changes: 6 additions & 5 deletions eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,12 @@ def handle_evaluation_output(
if args.use_database and not args.debug:
evaluation_tracker.update_evalresults_db(
results,
args.model_id,
args.model_name,
args.creation_location,
args.created_by,
args.is_external_model,
model_id=args.model_id,
model_source=args.model,
model_name=args.model_name,
creation_location=args.creation_location,
created_by=args.created_by,
is_external=args.is_external_model,
)

if args.log_samples:
Expand Down
23 changes: 17 additions & 6 deletions eval/eval_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def save_results_aggregated(
else:
eval_logger.info("Output path not provided, skipping saving results aggregated")

def get_or_create_model(self, model_name: str, model_id: Optional[str]) -> Tuple[uuid.UUID, uuid.UUID]:
def get_or_create_model(
self, model_name: str, model_id: Optional[str], model_source: str = "hf"
) -> Tuple[uuid.UUID, uuid.UUID]:
"""
Retrieve an existing model or create a new one in the database.
Expand All @@ -196,7 +198,7 @@ def get_or_create_model(self, model_name: str, model_id: Optional[str]) -> Tuple
assert model_name or model_id
try:
if not model_id:
model_id = get_or_add_model_by_name(model_name)
model_id = get_or_add_model_by_name(model_name, model_source)
model_configs = get_model_from_db(model_id)
return model_id, model_configs["dataset_id"]
except Exception as e:
Expand Down Expand Up @@ -333,6 +335,7 @@ def update_evalresults_db(
self,
eval_log_dict: Dict[str, Any],
model_id: Optional[str],
model_source: str = "hf",
model_name: Optional[str] = None,
creation_location: Optional[str] = None,
created_by: Optional[str] = None,
Expand All @@ -344,6 +347,7 @@ def update_evalresults_db(
Args:
eval_log_dict: Dictionary containing evaluation logs and results
model_id: Optional UUID of the model
model_source: Source of the model (similar to the model arg in eval.py)
model_name: Optional name of the model
creation_location: Location where evaluation was run
created_by: Username who ran the evaluation
Expand All @@ -359,11 +363,18 @@ def update_evalresults_db(
args_dict = simple_parse_args_string(eval_log_dict["config"]["model_args"])
model_name = args_dict["pretrained"]

weights_location = (
f"https://huggingface.co/{model_name}" if is_external and check_hf_model_exists(model_name) else "NA"
)
if model_source == "hf":
weights_location = (
f"https://huggingface.co/{model_name}"
if is_external and check_hf_model_exists(model_name)
else "NA"
)
else:
weights_location = "NA"

model_id, dataset_id = self.get_or_create_model(model_name=model_name, model_id=model_id)
model_id, dataset_id = self.get_or_create_model(
model_name=model_name, model_id=model_id, model_source=model_source
)
eval_logger.info(f"Updating results for model_id: {str(model_id)}")

results = eval_log_dict["results"]
Expand Down

0 comments on commit 611c713

Please sign in to comment.