From 611c713f63a69368111686235f1ea7b0a0efe063 Mon Sep 17 00:00:00 2001 From: Jean Mercat Date: Wed, 8 Jan 2025 18:09:47 -0800 Subject: [PATCH] allow non-HF models to be pushed to the DB --- database/utils.py | 87 +++++++++++++++++++++++++++++++++++++++----- eval/eval.py | 11 +++--- eval/eval_tracker.py | 23 +++++++++--- 3 files changed, 101 insertions(+), 20 deletions(-) diff --git a/database/utils.py b/database/utils.py index c142050d..1e1de3aa 100644 --- a/database/utils.py +++ b/database/utils.py @@ -14,6 +14,19 @@ from uuid import UUID import uuid from contextlib import contextmanager +import openai + + +def get_full_openai_model_name(alias): + try: + # Make a simple request using the alias + response = openai.chat.completions.create( + model=alias, messages=[{"role": "system", "content": "Identify the model name."}], max_tokens=1 + ) + # Extract and return the full model name from the response + return response.model + except Exception as e: + return f"An error occurred: {str(e)}" def create_db_engine() -> Tuple[Engine, sessionmaker]: @@ -254,31 +267,40 @@ def get_model_from_db(id: "UUID") -> Model: return model_db_obj.to_dict() -def get_or_add_model_by_name(hf_model: str): +def get_or_add_model_by_name(model: str, model_source: str = "hf"): """ - Given hf_model path, return UUID of hf_model. + Given model path, return UUID of model. Checks for existence by using git commit hash. If doesn't exist in DB, create an entry and return UUID of entry. If there exists more than one entry in DB, return UUID of latest model by last_modified. Args: - hf_model (str): The path or identifier for the Hugging Face model. + model (str): The path or identifier for the Hugging Face or other model. """ - git_commit_hash = HfApi().model_info(hf_model).sha + if model_source == "hf": + git_commit_hash = HfApi().model_info(model).sha + else: + if "openai" in model_source: + model = get_full_openai_model_name(model) + git_commit_hash = model + "_" + datetime.now(timezone.utc).strftime("%Y-%m-%d-%H-%M-%S") + with session_scope() as session: model_instances = ( session.query(Model) - .filter(Model.weights_location == hf_model) + .filter(Model.weights_location == model) .filter(Model.git_commit_hash == git_commit_hash) .all() ) model_instances = [i.to_dict() for i in model_instances] - if len(model_instances) == 0: - print(f"{hf_model} doesn't exist in database. Creating entry:") - return register_hf_model_to_db(hf_model) + if len(model_instances) == 0 and model_source == "hf": + print(f"{model} doesn't exist in database. Creating entry:") + return register_hf_model_to_db(model) + elif len(model_instances) == 0: + print(f"{model} doesn't exist in database. Creating entry:") + return register_model_to_db(model, model_source) elif len(model_instances) > 1: - print(f"WARNING: Model {hf_model} has multiple entries in DB. Returning latest match.") + print(f"WARNING: Model {model} has multiple entries in DB. Returning latest match.") model_instances = sorted(model_instances, key=lambda x: (x["last_modified"] is not None, x["last_modified"])) for i in model_instances: print(f"id: {i['id']}, git_commit_hash: {i['git_commit_hash']}") @@ -354,3 +376,50 @@ def register_hf_model_to_db(hf_model: str, force: bool = False): print(f"Model successfully registered to db! {model}") return id + + +def register_model_to_db(model_name: str, model_source: str) -> UUID: + """ + Registers a new model to the database for non-HuggingFace models. + + Args: + model_name (str): The name or identifier for the model + model_source (str): Source of the model (e.g., 'openai-chat-completions', 'anthropic') + + Returns: + UUID: The unique identifier assigned to the registered model + + Raises: + ValueError: If the model cannot be registered due to missing metadata + """ + id = uuid.uuid4() + creation_time = datetime.now(timezone.utc) + + # Create a unique git_commit_hash-like identifier using timestamp + git_commit_hash = f"{model_name}_{creation_time.strftime('%Y-%m-%d-%H-%M-%S')}" + + with session_scope() as session: + model = Model( + id=id, + name=model_name, + base_model_id=id, + created_by=model_source, + creation_location=model_source, + creation_time=creation_time, + training_start=creation_time, + training_end=creation_time, + training_parameters=None, + training_status=None, + dataset_id=None, + is_external=True, + weights_location=model_name, + wandb_link=None, + git_commit_hash=git_commit_hash, + last_modified=creation_time, + ) + + session.add(model) + session.commit() + print(f"Model successfully registered to db! {model}") + + return id diff --git a/eval/eval.py b/eval/eval.py index 3f6dc397..88e21e6f 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -506,11 +506,12 @@ def handle_evaluation_output( if args.use_database and not args.debug: evaluation_tracker.update_evalresults_db( results, - args.model_id, - args.model_name, - args.creation_location, - args.created_by, - args.is_external_model, + model_id=args.model_id, + model_source=args.model, + model_name=args.model_name, + creation_location=args.creation_location, + created_by=args.created_by, + is_external=args.is_external_model, ) if args.log_samples: diff --git a/eval/eval_tracker.py b/eval/eval_tracker.py index 7f886233..f3188512 100644 --- a/eval/eval_tracker.py +++ b/eval/eval_tracker.py @@ -179,7 +179,9 @@ def save_results_aggregated( else: eval_logger.info("Output path not provided, skipping saving results aggregated") - def get_or_create_model(self, model_name: str, model_id: Optional[str]) -> Tuple[uuid.UUID, uuid.UUID]: + def get_or_create_model( + self, model_name: str, model_id: Optional[str], model_source: str = "hf" + ) -> Tuple[uuid.UUID, uuid.UUID]: """ Retrieve an existing model or create a new one in the database. @@ -196,7 +198,7 @@ def get_or_create_model(self, model_name: str, model_id: Optional[str]) -> Tuple assert model_name or model_id try: if not model_id: - model_id = get_or_add_model_by_name(model_name) + model_id = get_or_add_model_by_name(model_name, model_source) model_configs = get_model_from_db(model_id) return model_id, model_configs["dataset_id"] except Exception as e: @@ -333,6 +335,7 @@ def update_evalresults_db( self, eval_log_dict: Dict[str, Any], model_id: Optional[str], + model_source: str = "hf", model_name: Optional[str] = None, creation_location: Optional[str] = None, created_by: Optional[str] = None, @@ -344,6 +347,7 @@ def update_evalresults_db( Args: eval_log_dict: Dictionary containing evaluation logs and results model_id: Optional UUID of the model + model_source: Source of the model (similar to the model arg in eval.py) model_name: Optional name of the model creation_location: Location where evaluation was run created_by: Username who ran the evaluation @@ -359,11 +363,18 @@ def update_evalresults_db( args_dict = simple_parse_args_string(eval_log_dict["config"]["model_args"]) model_name = args_dict["pretrained"] - weights_location = ( - f"https://huggingface.co/{model_name}" if is_external and check_hf_model_exists(model_name) else "NA" - ) + if model_source == "hf": + weights_location = ( + f"https://huggingface.co/{model_name}" + if is_external and check_hf_model_exists(model_name) + else "NA" + ) + else: + weights_location = "NA" - model_id, dataset_id = self.get_or_create_model(model_name=model_name, model_id=model_id) + model_id, dataset_id = self.get_or_create_model( + model_name=model_name, model_id=model_id, model_source=model_source + ) eval_logger.info(f"Updating results for model_id: {str(model_id)}") results = eval_log_dict["results"]