Skip to content

Commit

Permalink
fix: side effect when determining model changes
Browse files Browse the repository at this point in the history
don't switch current working git branch

Signed-off-by: jerryzhuang <[email protected]>
  • Loading branch information
zhuangqh committed Dec 20, 2024
1 parent f65fc92 commit bd1bcd2
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions .github/workflows/kind-cluster/determine_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ def read_yaml(file_path):
# Format: {falcon-7b : {model_name:falcon-7b, type:text-generation, version: #, tag: #}}
MODELS = {model['name']: model for model in YAML_PR['models']}
KAITO_REPO_URL = "https://github.com/kaito-project/kaito.git"
GITREMOTE_TARGET = "_ciupstream"

def set_multiline_output(name, value):
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
if not os.getenv('GITHUB_OUTPUT'):
print(f"Not in github env, skip writing to $GITHUB_OUTPUT .")
return

with open(os.getenv('GITHUB_OUTPUT'), 'a') as fh:
delimiter = uuid.uuid1()
print(f'{name}<<{delimiter}', file=fh)
print(value, file=fh)
Expand All @@ -51,9 +56,11 @@ def run_command(command):

def get_yaml_from_branch(branch, file_path):
"""Read YAML from a branch"""
subprocess.run(['git', 'fetch', 'origin', branch], check=True)
subprocess.run(['git', 'checkout', 'origin/' + branch], check=True)
return read_yaml(file_path)
subprocess.run(['git', 'fetch', GITREMOTE_TARGET, branch], check=True)
subprocess.run(['git', 'checkout', f"{GITREMOTE_TARGET}/" + branch], check=True)
content = read_yaml(file_path)
subprocess.run(['git', 'checkout', '-'], check=True)
return content

def detect_changes_in_yaml(yaml_main, yaml_pr):
"""Detecting relevant changes in support_models.yaml"""
Expand Down Expand Up @@ -90,33 +97,27 @@ def models_to_build(files_changed):
seen_model_types.add(model_info["type"])
return list(models)

def check_modified_models(pr_branch):
def check_modified_models():
"""Check for modified models in the repository."""
repo_dir = Path.cwd() / "repo"

if repo_dir.exists():
shutil.rmtree(repo_dir)

run_command(f"git clone {KAITO_REPO_URL} {repo_dir}")
os.chdir(repo_dir)

run_command("git checkout --detach")
run_command("git fetch origin main:main")
run_command(f"git fetch origin {pr_branch}:{pr_branch}")
run_command(f"git checkout {pr_branch}")
run_command(f"git remote add {GITREMOTE_TARGET} {KAITO_REPO_URL}")
run_command(f"git fetch {GITREMOTE_TARGET}")

files = run_command("git diff --name-only origin/main") # Returns each file on newline
files = run_command(f"git diff --name-only {GITREMOTE_TARGET}/main") # Returns each file on newline
files = files.split("\n")
os.chdir(Path.cwd().parent)
print("Files Changed: ", files)

modified_models = models_to_build(files)

print("Modified Models (Images to build): ", modified_models)

return modified_models

def main():
pr_branch = os.environ.get("PR_BRANCH", "main") # If not specified default to 'main'
force_run_all = os.environ.get("FORCE_RUN_ALL", "false") # If not specified default to False
force_run_all_phi = os.environ.get("FORCE_RUN_ALL_PHI", "false") # If not specified default to False
force_run_all_public = os.environ.get("FORCE_RUN_ALL_PUBLIC", "false") # If not specified default to False
Expand All @@ -131,7 +132,7 @@ def main():
else:
# Logic to determine affected models
# Example: affected_models = ['model1', 'model2', 'model3']
affected_models = check_modified_models(pr_branch)
affected_models = check_modified_models()

# Convert the list of models into JSON matrix format
matrix = create_matrix(affected_models)
Expand Down

0 comments on commit bd1bcd2

Please sign in to comment.