Skip to content

Commit

Permalink
ci: only return 1 model_name per file (#1818)
Browse files Browse the repository at this point in the history
* only return 1 model_name per file

* fix args parse

* revert test change
  • Loading branch information
isaac-chung authored Jan 16, 2025
1 parent 60c4980 commit d7a7791
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ build-docs:
model-load-test:
@echo "--- 🚀 Running model load test ---"
pip install ".[dev, speedtask, pylate,gritlm,xformers,model2vec]"
python scripts/extract_model_names.py $(BASE_BRANCH)
python scripts/extract_model_names.py $(BASE_BRANCH) --return_one_model_name_per_file
python tests/test_models/model_loading.py --model_name_file scripts/model_names.txt
41 changes: 37 additions & 4 deletions scripts/extract_model_names.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import argparse
import ast
import sys
import logging
from pathlib import Path

from git import Repo

logging.basicConfig(level=logging.INFO)


def get_changed_files(base_branch="main"):
repo_path = Path(__file__).parent.parent
Expand All @@ -28,8 +31,11 @@ def get_changed_files(base_branch="main"):
]


def extract_model_names(files: list[str]) -> list[str]:
def extract_model_names(
files: list[str], return_one_model_name_per_file=False
) -> list[str]:
model_names = []
first_model_found = False
for file in files:
with open(file) as f:
tree = ast.parse(f.read())
Expand All @@ -52,17 +58,44 @@ def extract_model_names(files: list[str]) -> list[str]:
)
if model_name:
model_names.append(model_name)
first_model_found = True
if return_one_model_name_per_file and first_model_found:
logging.info(f"Found model name {model_name} in file {file}")
break # NOTE: Only take the first model_name per file to avoid disk out of space issue.
return model_names


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"base_branch",
nargs="?",
default="main",
help="Base branch to compare changes with",
)
parser.add_argument(
"--return_one_model_name_per_file",
action="store_true",
default=False,
help="Only return one model name per file.",
)
return parser.parse_args()


if __name__ == "__main__":
"""
Can pass in base branch as an argument. Defaults to 'main'.
e.g. python extract_model_names.py mieb
"""
base_branch = sys.argv[1] if len(sys.argv) > 1 else "main"

args = parse_args()

base_branch = args.base_branch
changed_files = get_changed_files(base_branch)
model_names = extract_model_names(changed_files)
model_names = extract_model_names(
changed_files,
return_one_model_name_per_file=args.return_one_model_name_per_file,
)
output_file = Path(__file__).parent / "model_names.txt"
with output_file.open("w") as f:
f.write(" ".join(model_names))

0 comments on commit d7a7791

Please sign in to comment.