diff --git a/nemo_skills/evaluation/fill_majority_answer.py b/nemo_skills/evaluation/fill_majority_answer.py
index 8ef3b14cd..cd6e5e7d3 100644
--- a/nemo_skills/evaluation/fill_majority_answer.py
+++ b/nemo_skills/evaluation/fill_majority_answer.py
@@ -14,6 +14,7 @@
 
 import json
 import logging
+import shutil
 import sys
 from collections import Counter
 from itertools import zip_longest
@@ -75,7 +76,7 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
     new_answers = []
     all_predictions = []
     for idx, predictions in enumerate(tqdm(zip_longest(*file_handles))):
-        data = read_predictions(predictions)
+        data = read_predictions(predictions, idx, file_handles)
         for elem in data:
             if 'predicted_answer' not in elem:
                 elem['predicted_answer'] = extract_answer(elem['generation'])
@@ -106,8 +107,12 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
     for file_handle in file_handles:
         file_handle.close()
 
-    # writing the majority answers back to the files
-    file_handles = [open(file, "wt", encoding="utf-8") for file in unroll_files(cfg.input_files)]
+    # TODO: change to instead write to a fully new set of files
+    # Create temp filenames and open temp files for writing
+    input_files = unroll_files(cfg.input_files)
+    temp_files = [f"{file}-tmp" for file in input_files]
+    file_handles = [open(temp_file, "wt", encoding="utf-8") for temp_file in temp_files]
+
     for idx, predictions in enumerate(all_predictions):
         for lidx, handle in enumerate(file_handles):
             if cfg.ignore_if_not_none and predictions[lidx][cfg.fill_key] is not None:
@@ -118,7 +123,6 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
                 predictions[lidx]["majority_votes"], predictions[lidx]["total_votes"] = new_answers[idx][1]
             else:
                 predictions[lidx]["answer_rm_score"] = new_answers[idx][1]
-            # this is just a string match check, so for full correctness need to rerun the evaluator
             if cfg.fill_is_correct:
                 predictions[lidx]["is_correct"] = (
                     predictions[lidx]["predicted_answer"] == predictions[lidx]["expected_answer"]
@@ -127,8 +131,14 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
                 predictions[lidx].pop("is_correct")
             handle.write(json.dumps(predictions[lidx]) + "\n")
 
-    for file_handle in file_handles:
-        file_handle.close()
+    # Close all files before moving
+    for handle in file_handles:
+        handle.close()
+
+    # Move temp files to original files
+    input_files = unroll_files(cfg.input_files)
+    for temp_file, orig_file in zip(temp_files, input_files):
+        shutil.move(temp_file, orig_file)
 
 
 HELP_MESSAGE = get_help_message(FillMajorityAnswerConfig)
diff --git a/nemo_skills/evaluation/metrics/compute_metrics.py b/nemo_skills/evaluation/metrics/compute_metrics.py
index 1a4e5b289..3f89b5cb1 100644
--- a/nemo_skills/evaluation/metrics/compute_metrics.py
+++ b/nemo_skills/evaluation/metrics/compute_metrics.py
@@ -53,7 +53,7 @@ def compute_metrics(self, input_files):
         for idx, predictions in enumerate(zip_longest(*file_handles)):
             if idx == self.max_samples:
                 break
-            data = read_predictions(predictions)
+            data = read_predictions(predictions, idx, file_handles)
             # checking if we need to create a new metrics calculator
             data_subset = data[0].get('subset_for_metrics', 'all')
             if data_subset not in self.calculators:
diff --git a/nemo_skills/evaluation/metrics/utils.py b/nemo_skills/evaluation/metrics/utils.py
index fb6f3c120..2771a2330 100644
--- a/nemo_skills/evaluation/metrics/utils.py
+++ b/nemo_skills/evaluation/metrics/utils.py
@@ -12,12 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import json
+import logging
+LOG = logging.getLogger(__file__)
 
 
-def read_predictions(predictions):
+
+def read_predictions(predictions, line_idx, file_handles):
     data = []
-    for prediction in predictions:
-        prediction_dict = json.loads(prediction)
+    for file_idx, prediction in enumerate(predictions):
+        try:
+            prediction_dict = json.loads(prediction)
+        except Exception as e:
+            LOG.error(f"Error reading line %s in file %s: %s", line_idx + 1, file_handles[file_idx].name, e)
+            raise
         data.append(prediction_dict)
     return data
diff --git a/nemo_skills/pipeline/utils.py b/nemo_skills/pipeline/utils.py
index c5f4529d9..e302d6e52 100644
--- a/nemo_skills/pipeline/utils.py
+++ b/nemo_skills/pipeline/utils.py
@@ -19,6 +19,7 @@
 import sys
 import tarfile
 from dataclasses import dataclass
+from functools import lru_cache
 from pathlib import Path
 from typing import Optional
 
@@ -52,6 +53,9 @@ def get_unmounted_path(cluster_config, path):
     raise ValueError(f"The path '{path}' is not mounted. Check cluster config.")
 
 
+# caching the status assuming it doesn't change while experiment is being scheduled
+# otherwise this results in too many ssh calls
+@lru_cache
 def get_exp_handles(expname: str, ignore_finished=True, ignore_exp_not_exists=True) -> list[str]:
     """Will return the handles of the tasks in the experiment.
 
@@ -68,7 +72,7 @@ def _get_handles(exp):
         handles = []
         for job in exp.jobs:
             if not ignore_finished or (
-                job.status(exp._runner) in [AppState.RUNNING, AppState.PENDING, AppState.SUBMITTED]
+                job.status(exp._runner) in [AppState.RUNNING, AppState.PENDING, AppState.SUBMITTED, AppState.UNKNOWN]
             ):
                 handles.append(job.handle)
                 continue
@@ -314,8 +318,27 @@ def get_cluster_config(cluster=None, config_dir=None):
     return read_config(config_file)
 
 
+@lru_cache
+def _get_tunnel_cached(
+    job_dir: str,
+    host: str,
+    user: str,
+    identity: str | None = None,
+    shell: str | None = None,
+    pre_command: str | None = None,
+):
+    return run.SSHTunnel(
+        host=host,
+        user=user,
+        identity=identity,
+        shell=shell,
+        pre_command=pre_command,
+        job_dir=job_dir,
+    )
+
+
 def get_tunnel(cluster_config):
-    return run.SSHTunnel(**cluster_config["ssh_tunnel"])
+    return _get_tunnel_cached(**cluster_config["ssh_tunnel"])
 
 
 # Helper class and function to support streaming updates
@@ -429,7 +452,8 @@ def cluster_upload(tunnel: SSHTunnel, local_file: str, remote_dir: str, verbose:
     print(f"\nTransfer complete")
 
 
-def get_packager(extra_package_dirs: list[str] | None = None):
+@lru_cache
+def get_packager(extra_package_dirs: tuple[str] | None = None):
     """Will check if we are running from a git repo and use git packager or default packager otherwise."""
     nemo_skills_dir = Path(__file__).absolute().parents[1]
 
@@ -607,13 +631,15 @@ def get_executor(
     partition=None,
     time_min=None,
     dependencies=None,
-    extra_package_dirs: list[str] | None = None,
+    extra_package_dirs: tuple[str] | None = None,
     slurm_kwargs: dict | None = None,
 ):
     env_vars = get_env_variables(cluster_config)
     config_mounts = get_mounts_from_config(cluster_config, env_vars)
 
     mounts = mounts or config_mounts
+    if extra_package_dirs is not None:
+        extra_package_dirs = tuple(extra_package_dirs)
     packager = get_packager(extra_package_dirs=extra_package_dirs)
     if cluster_config["executor"] == "local":
         if num_nodes > 1:
diff --git a/nemo_skills/prompt/utils.py b/nemo_skills/prompt/utils.py
index 623a8b67b..c286140ea 100644
--- a/nemo_skills/prompt/utils.py
+++ b/nemo_skills/prompt/utils.py
@@ -266,14 +266,20 @@ def fill(
             return prompt_string
         else:
             if multi_turn_key is None:
-                messages = [
-                    {"role": "system", "content": self.config.system},
-                    {"role": "user", "content": self.build_user_message(input_dict)},
"user", "content": self.build_user_message(input_dict)}, - ] + if self.config.system: + messages = [ + {"role": "system", "content": self.config.system}, + ] + else: + messages = [] + messages.append({"role": "user", "content": self.build_user_message(input_dict)}) if generation and include_generation: messages.append({"role": "assistant", "content": generation}) else: - messages = [{"role": "system", "content": self.config.system}] + if self.config.system: + messages = [{"role": "system", "content": self.config.system}] + else: + messages = [] for turn in input_dict[multi_turn_key][:-1]: messages.append({"role": "user", "content": self.build_user_message(turn)}) messages.append({"role": "assistant", "content": turn["assistant"]}) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 649321c0f..c3a321c7a 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -205,7 +205,6 @@ def test_generic_codegen_prompt(): prompt = get_prompt('generic/codegen') expected_prompt = [ - {'role': 'system', 'content': ''}, { 'role': 'user', 'content': ''' @@ -228,7 +227,6 @@ def test_generic_default_prompt(): prompt = get_prompt('generic/default') expected_prompt = [ - {'role': 'system', 'content': ''}, { 'role': 'user', 'content': 'How are you?', @@ -911,10 +909,6 @@ def test_judge_math(): prompt = get_prompt('judge/math') expected_prompt = [ - { - 'role': 'system', - 'content': '', - }, { 'role': 'user', 'content': ''' @@ -1012,10 +1006,6 @@ def test_judge_check_contamination(): prompt = get_prompt('judge/check-contamination') expected_prompt = [ - { - 'role': 'system', - 'content': '', - }, { 'role': 'user', 'content': ''' @@ -1209,6 +1199,7 @@ def test_generic_lean4_fewshot_prompt(): == expected_prompt ) + def test_mmlu_pro_llama_few_shot_prompt(): prompt = get_prompt('llama3-instruct/mmlu_pro', 'llama3-instruct', 'mmlu_pro_few_shot_llama_math') expected_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> @@ -1340,10 +1331,10 @@ def test_mmlu_pro_llama_few_shot_prompt(): """ inputs = { - 'question' : 'If the half-life of a radioactive substance is 8 years, how long will it take, in years, for two thirds of the substance to decay?', - 'options' :'A. 22.45\nB. 12.21\nC. 12.68\nD. 7.69\nE. 16.89\nF. 20.76\nG. 4.68\nH. 14.12\nI. 10.33\nJ. 18.34', + 'question': 'If the half-life of a radioactive substance is 8 years, how long will it take, in years, for two thirds of the substance to decay?', + 'options': 'A. 22.45\nB. 12.21\nC. 12.68\nD. 7.69\nE. 16.89\nF. 20.76\nG. 4.68\nH. 14.12\nI. 10.33\nJ. 18.34', } - assert (prompt.fill(inputs) == expected_prompt) + assert prompt.fill(inputs) == expected_prompt def test_mmlu_pro_tigerlab_few_shot_prompt(): @@ -1441,10 +1432,9 @@ def test_mmlu_pro_tigerlab_few_shot_prompt(): """ inputs = { - 'question' : "What is the square root of 81 squared?", - 'options' : "A. 9^2\nB. 27\nC. 81^2\nD. 729\nE. 6561\nF. 12\nG. 162\nH. 243\nI. 81\nJ. 9", - "subset_for_metrics": "other" + 'question': "What is the square root of 81 squared?", + 'options': "A. 9^2\nB. 27\nC. 81^2\nD. 729\nE. 6561\nF. 12\nG. 162\nH. 243\nI. 81\nJ. 9", + "subset_for_metrics": "other", } print(prompt.fill(inputs)) assert prompt.fill(inputs) == expected_prompt -