Cache operations + fix system message + safer fill majority (#294)
Signed-off-by: Igor Gitman <[email protected]>
Kipok authored Dec 14, 2024
1 parent 197d090 commit d94bc47
Showing 6 changed files with 75 additions and 36 deletions.
22 changes: 16 additions & 6 deletions nemo_skills/evaluation/fill_majority_answer.py
@@ -14,6 +14,7 @@

import json
import logging
import shutil
import sys
from collections import Counter
from itertools import zip_longest
@@ -75,7 +76,7 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
new_answers = []
all_predictions = []
for idx, predictions in enumerate(tqdm(zip_longest(*file_handles))):
data = read_predictions(predictions)
data = read_predictions(predictions, idx, file_handles)
for elem in data:
if 'predicted_answer' not in elem:
elem['predicted_answer'] = extract_answer(elem['generation'])
@@ -106,8 +107,12 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
for file_handle in file_handles:
file_handle.close()

# writing the majority answers back to the files
file_handles = [open(file, "wt", encoding="utf-8") for file in unroll_files(cfg.input_files)]
# TODO: change to instead write to a fully new set of files
# Create temp filenames and open temp files for writing
input_files = unroll_files(cfg.input_files)
temp_files = [f"{file}-tmp" for file in input_files]
file_handles = [open(temp_file, "wt", encoding="utf-8") for temp_file in temp_files]

for idx, predictions in enumerate(all_predictions):
for lidx, handle in enumerate(file_handles):
if cfg.ignore_if_not_none and predictions[lidx][cfg.fill_key] is not None:
@@ -118,7 +123,6 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
predictions[lidx]["majority_votes"], predictions[lidx]["total_votes"] = new_answers[idx][1]
else:
predictions[lidx]["answer_rm_score"] = new_answers[idx][1]
# this is just a string match check, so for full correctness need to rerun the evaluator
if cfg.fill_is_correct:
predictions[lidx]["is_correct"] = (
predictions[lidx]["predicted_answer"] == predictions[lidx]["expected_answer"]
@@ -127,8 +131,14 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
predictions[lidx].pop("is_correct")
handle.write(json.dumps(predictions[lidx]) + "\n")

for file_handle in file_handles:
file_handle.close()
# Close all files before moving
for handle in file_handles:
handle.close()

# Move temp files to original files
input_files = unroll_files(cfg.input_files)
for temp_file, orig_file in zip(temp_files, input_files):
shutil.move(temp_file, orig_file)


HELP_MESSAGE = get_help_message(FillMajorityAnswerConfig)
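The "safer fill majority" part of this commit replaces in-place rewriting of the input files with a write-to-temp-then-move pattern, so a crash mid-write can no longer corrupt the original prediction files. A minimal, self-contained sketch of the same idea (hypothetical file names and a simplified record format, not the repository's API):

import json
import shutil

def rewrite_jsonl_safely(path: str, records: list[dict]) -> None:
    """Write records to a temp file first, then replace the original only on success."""
    temp_path = f"{path}-tmp"
    with open(temp_path, "wt", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
    # The original file stays untouched until the temp file has been fully written and closed.
    shutil.move(temp_path, path)

if __name__ == "__main__":
    # Hypothetical demo file, not part of the repository.
    with open("demo-predictions.jsonl", "wt", encoding="utf-8") as f:
        f.write(json.dumps({"generation": "The answer is 42."}) + "\n")

    with open("demo-predictions.jsonl", encoding="utf-8") as f:
        records = [json.loads(line) for line in f]
    for record in records:
        record.setdefault("predicted_answer", "42")
    rewrite_jsonl_safely("demo-predictions.jsonl", records)

Because shutil.move only replaces the original after the temp file is complete, an interrupted run leaves the source data intact, which is the property the change above relies on.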
2 changes: 1 addition & 1 deletion nemo_skills/evaluation/metrics/compute_metrics.py
@@ -53,7 +53,7 @@ def compute_metrics(self, input_files):
for idx, predictions in enumerate(zip_longest(*file_handles)):
if idx == self.max_samples:
break
data = read_predictions(predictions)
data = read_predictions(predictions, idx, file_handles)
# checking if we need to create a new metrics calculator
data_subset = data[0].get('subset_for_metrics', 'all')
if data_subset not in self.calculators:
13 changes: 10 additions & 3 deletions nemo_skills/evaluation/metrics/utils.py
@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and

import json
import logging

LOG = logging.getLogger(__file__)

def read_predictions(predictions):

def read_predictions(predictions, line_idx, file_handles):
data = []
for prediction in predictions:
prediction_dict = json.loads(prediction)
for file_idx, prediction in enumerate(predictions):
try:
prediction_dict = json.loads(prediction)
except Exception as e:
LOG.error(f"Error reading line %s in file %s: %s", line_idx + 1, file_handles[file_idx].name, e)
raise
data.append(prediction_dict)

return data
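The updated read_predictions now reports which file and which 1-based line failed to parse instead of raising a bare JSON error. A minimal sketch of reading several prediction files in lockstep with the same kind of error context (hypothetical helper, simplified relative to the repository code):

import json
import logging
from itertools import zip_longest

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)

def read_parallel_jsonl(paths: list[str]) -> list[list[dict]]:
    """Read several JSONL files line-by-line in parallel, reporting parse failures with context."""
    handles = [open(path, encoding="utf-8") for path in paths]
    try:
        rows = []
        for line_idx, lines in enumerate(zip_longest(*handles)):
            row = []
            for file_idx, line in enumerate(lines):
                try:
                    row.append(json.loads(line))
                except Exception as e:
                    # Pinpoint the offending file and line number before re-raising.
                    LOG.error("Error reading line %s in file %s: %s", line_idx + 1, handles[file_idx].name, e)
                    raise
            rows.append(row)
        return rows
    finally:
        for handle in handles:
            handle.close()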
34 changes: 30 additions & 4 deletions nemo_skills/pipeline/utils.py
@@ -19,6 +19,7 @@
import sys
import tarfile
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional

@@ -52,6 +53,9 @@ def get_unmounted_path(cluster_config, path):
raise ValueError(f"The path '{path}' is not mounted. Check cluster config.")


# caching the status assuming it doesn't change while experiment is being scheduled
# otherwise this results in too many ssh calls
@lru_cache
def get_exp_handles(expname: str, ignore_finished=True, ignore_exp_not_exists=True) -> list[str]:
"""Will return the handles of the tasks in the experiment.
@@ -68,7 +72,7 @@ def _get_handles(exp):
handles = []
for job in exp.jobs:
if not ignore_finished or (
job.status(exp._runner) in [AppState.RUNNING, AppState.PENDING, AppState.SUBMITTED]
job.status(exp._runner) in [AppState.RUNNING, AppState.PENDING, AppState.SUBMITTED, AppState.UNKNOWN]
):
handles.append(job.handle)
continue
@@ -314,8 +318,27 @@ def get_cluster_config(cluster=None, config_dir=None):
return read_config(config_file)


@lru_cache
def _get_tunnel_cached(
job_dir: str,
host: str,
user: str,
identity: str | None = None,
shell: str | None = None,
pre_command: str | None = None,
):
return run.SSHTunnel(
host=host,
user=user,
identity=identity,
shell=shell,
pre_command=pre_command,
job_dir=job_dir,
)


def get_tunnel(cluster_config):
return run.SSHTunnel(**cluster_config["ssh_tunnel"])
return _get_tunnel_cached(**cluster_config["ssh_tunnel"])


# Helper class and function to support streaming updates
@@ -429,7 +452,8 @@ def cluster_upload(tunnel: SSHTunnel, local_file: str, remote_dir: str, verbose:
print(f"\nTransfer complete")


def get_packager(extra_package_dirs: list[str] | None = None):
@lru_cache
def get_packager(extra_package_dirs: tuple[str] | None = None):
"""Will check if we are running from a git repo and use git packager or default packager otherwise."""
nemo_skills_dir = Path(__file__).absolute().parents[1]

@@ -607,13 +631,15 @@ def get_executor(
partition=None,
time_min=None,
dependencies=None,
extra_package_dirs: list[str] | None = None,
extra_package_dirs: tuple[str] | None = None,
slurm_kwargs: dict | None = None,
):
env_vars = get_env_variables(cluster_config)
config_mounts = get_mounts_from_config(cluster_config, env_vars)

mounts = mounts or config_mounts
if extra_package_dirs is not None:
extra_package_dirs = tuple(extra_package_dirs)
packager = get_packager(extra_package_dirs=extra_package_dirs)
if cluster_config["executor"] == "local":
if num_nodes > 1:
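The "cache operations" part of this commit memoizes several helpers (get_exp_handles, _get_tunnel_cached, get_packager) with functools.lru_cache, which is also why get_executor converts extra_package_dirs from a list to a tuple: lru_cache keys the cache on the call arguments, so they must be hashable. A minimal sketch of the same pattern (hypothetical names, not the repository's API):

from functools import lru_cache

@lru_cache
def build_connection(host: str, user: str, identity: str | None = None) -> dict:
    # Imagine an expensive SSH handshake here; lru_cache reuses the result
    # for identical (host, user, identity) arguments.
    print(f"connecting to {user}@{host}")
    return {"host": host, "user": user, "identity": identity}

@lru_cache
def build_packager(extra_package_dirs: tuple[str, ...] | None = None) -> list[str]:
    # Tuples are hashable; lists are not, so passing a list here raises TypeError.
    return list(extra_package_dirs or ())

if __name__ == "__main__":
    build_connection("login.cluster", "igor")   # prints once
    build_connection("login.cluster", "igor")   # cache hit, no print
    build_packager(("/extra/dir",))              # callers convert lists to tuples first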
16 changes: 11 additions & 5 deletions nemo_skills/prompt/utils.py
@@ -266,14 +266,20 @@ def fill(
return prompt_string
else:
if multi_turn_key is None:
messages = [
{"role": "system", "content": self.config.system},
{"role": "user", "content": self.build_user_message(input_dict)},
]
if self.config.system:
messages = [
{"role": "system", "content": self.config.system},
]
else:
messages = []
messages.append({"role": "user", "content": self.build_user_message(input_dict)})
if generation and include_generation:
messages.append({"role": "assistant", "content": generation})
else:
messages = [{"role": "system", "content": self.config.system}]
if self.config.system:
messages = [{"role": "system", "content": self.config.system}]
else:
messages = []
for turn in input_dict[multi_turn_key][:-1]:
messages.append({"role": "user", "content": self.build_user_message(turn)})
messages.append({"role": "assistant", "content": turn["assistant"]})
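The prompt change above fixes the system message handling: a system turn is only prepended when config.system is non-empty, which is why the expected prompts in the tests below drop their {'role': 'system', 'content': ''} entries. A minimal sketch of the same logic (hypothetical helper, simplified relative to the repository's Prompt class):

def build_messages(system: str, user_content: str, generation: str | None = None) -> list[dict]:
    """Build an OpenAI-style message list, skipping the system turn when it is empty."""
    messages = [{"role": "system", "content": system}] if system else []
    messages.append({"role": "user", "content": user_content})
    if generation:
        messages.append({"role": "assistant", "content": generation})
    return messages

if __name__ == "__main__":
    print(build_messages("", "How are you?"))
    # -> [{'role': 'user', 'content': 'How are you?'}]
    print(build_messages("You are a helpful assistant.", "How are you?"))
    # -> [{'role': 'system', ...}, {'role': 'user', ...}]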
24 changes: 7 additions & 17 deletions tests/test_prompts.py
@@ -205,7 +205,6 @@ def test_generic_codegen_prompt():
prompt = get_prompt('generic/codegen')

expected_prompt = [
{'role': 'system', 'content': ''},
{
'role': 'user',
'content': '''
@@ -228,7 +227,6 @@ def test_generic_default_prompt():
prompt = get_prompt('generic/default')

expected_prompt = [
{'role': 'system', 'content': ''},
{
'role': 'user',
'content': 'How are you?',
@@ -911,10 +909,6 @@ def test_judge_math():
prompt = get_prompt('judge/math')

expected_prompt = [
{
'role': 'system',
'content': '',
},
{
'role': 'user',
'content': '''
@@ -1012,10 +1006,6 @@ def test_judge_check_contamination():
prompt = get_prompt('judge/check-contamination')

expected_prompt = [
{
'role': 'system',
'content': '',
},
{
'role': 'user',
'content': '''
@@ -1209,6 +1199,7 @@ def test_generic_lean4_fewshot_prompt():
== expected_prompt
)


def test_mmlu_pro_llama_few_shot_prompt():
prompt = get_prompt('llama3-instruct/mmlu_pro', 'llama3-instruct', 'mmlu_pro_few_shot_llama_math')
expected_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
@@ -1340,10 +1331,10 @@ def test_mmlu_pro_llama_few_shot_prompt():
"""

inputs = {
'question' : 'If the half-life of a radioactive substance is 8 years, how long will it take, in years, for two thirds of the substance to decay?',
'options' :'A. 22.45\nB. 12.21\nC. 12.68\nD. 7.69\nE. 16.89\nF. 20.76\nG. 4.68\nH. 14.12\nI. 10.33\nJ. 18.34',
'question': 'If the half-life of a radioactive substance is 8 years, how long will it take, in years, for two thirds of the substance to decay?',
'options': 'A. 22.45\nB. 12.21\nC. 12.68\nD. 7.69\nE. 16.89\nF. 20.76\nG. 4.68\nH. 14.12\nI. 10.33\nJ. 18.34',
}
assert (prompt.fill(inputs) == expected_prompt)
assert prompt.fill(inputs) == expected_prompt


def test_mmlu_pro_tigerlab_few_shot_prompt():
@@ -1441,10 +1432,9 @@ def test_mmlu_pro_tigerlab_few_shot_prompt():
"""
inputs = {
'question' : "What is the square root of 81 squared?",
'options' : "A. 9^2\nB. 27\nC. 81^2\nD. 729\nE. 6561\nF. 12\nG. 162\nH. 243\nI. 81\nJ. 9",
"subset_for_metrics": "other"
'question': "What is the square root of 81 squared?",
'options': "A. 9^2\nB. 27\nC. 81^2\nD. 729\nE. 6561\nF. 12\nG. 162\nH. 243\nI. 81\nJ. 9",
"subset_for_metrics": "other",
}
print(prompt.fill(inputs))
assert prompt.fill(inputs) == expected_prompt
