diff --git a/Makefile b/Makefile index 02d0ba247..9729d080f 100644 --- a/Makefile +++ b/Makefile @@ -41,5 +41,5 @@ build-docs: model-load-test: @echo "--- 🚀 Running model load test ---" pip install ".[dev, speedtask, pylate,gritlm,xformers,model2vec]" - python scripts/extract_model_names.py $(BASE_BRANCH) + python scripts/extract_model_names.py $(BASE_BRANCH) --return_one_model_name_per_file python tests/test_models/model_loading.py --model_name_file scripts/model_names.txt \ No newline at end of file diff --git a/scripts/extract_model_names.py b/scripts/extract_model_names.py index ba1bc1a8b..6cbaa2c29 100644 --- a/scripts/extract_model_names.py +++ b/scripts/extract_model_names.py @@ -1,11 +1,14 @@ from __future__ import annotations +import argparse import ast -import sys +import logging from pathlib import Path from git import Repo +logging.basicConfig(level=logging.INFO) + def get_changed_files(base_branch="main"): repo_path = Path(__file__).parent.parent @@ -28,8 +31,11 @@ def get_changed_files(base_branch="main"): ] -def extract_model_names(files: list[str]) -> list[str]: +def extract_model_names( + files: list[str], return_one_model_name_per_file=False +) -> list[str]: model_names = [] + first_model_found = False for file in files: with open(file) as f: tree = ast.parse(f.read()) @@ -52,17 +58,44 @@ def extract_model_names(files: list[str]) -> list[str]: ) if model_name: model_names.append(model_name) + first_model_found = True + if return_one_model_name_per_file and first_model_found: + logging.info(f"Found model name {model_name} in file {file}") + break # NOTE: Only take the first model_name per file to avoid disk out of space issue. return model_names +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "base_branch", + nargs="?", + default="main", + help="Base branch to compare changes with", + ) + parser.add_argument( + "--return_one_model_name_per_file", + action="store_true", + default=False, + help="Only return one model name per file.", + ) + return parser.parse_args() + + if __name__ == "__main__": """ Can pass in base branch as an argument. Defaults to 'main'. e.g. python extract_model_names.py mieb """ - base_branch = sys.argv[1] if len(sys.argv) > 1 else "main" + + args = parse_args() + + base_branch = args.base_branch changed_files = get_changed_files(base_branch) - model_names = extract_model_names(changed_files) + model_names = extract_model_names( + changed_files, + return_one_model_name_per_file=args.return_one_model_name_per_file, + ) output_file = Path(__file__).parent / "model_names.txt" with output_file.open("w") as f: f.write(" ".join(model_names))