diff --git a/.github/workflows/black.yaml b/.github/workflows/black.yaml
new file mode 100644
index 00000000..b2cd244f
--- /dev/null
+++ b/.github/workflows/black.yaml
@@ -0,0 +1,10 @@
+name: Lint
+
+on: [push, pull_request]
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: psf/black@stable
\ No newline at end of file
diff --git a/create_csv_helper.py b/create_csv_helper.py
index b3e932b5..fe66efa0 100644
--- a/create_csv_helper.py
+++ b/create_csv_helper.py
@@ -10,6 +10,7 @@
engine, SessionMaker = create_db_engine()
+
@contextmanager
def session_scope():
"""
@@ -25,54 +26,42 @@ def session_scope():
finally:
session.close()
+
def get_model_score(name: str, model_id: uuid.UUID, annotator_model: str) -> float:
with session_scope() as session:
rows = session.query(EvalResult).filter_by(model_id=model_id).all()
if not rows:
return None
for row in rows:
- eval_setting = session.query(EvalSetting).filter_by(
- id=row.eval_setting_id
- ).first()
- if eval_setting and name == eval_setting.name and eval_setting.parameters['annotator_model'] == annotator_model:
+ eval_setting = session.query(EvalSetting).filter_by(id=row.eval_setting_id).first()
+ if (
+ eval_setting
+ and name == eval_setting.name
+ and eval_setting.parameters["annotator_model"] == annotator_model
+ ):
return float(row.score)
return None
+
def get_model_name(model_id: uuid.UUID) -> str:
with session_scope() as session:
model = session.query(Model).filter_by(id=model_id).first()
return model.name if model else None
+
def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(description='Generate CSV of model evaluation scores')
- parser.add_argument(
- '--model-ids',
- required=True,
- nargs='+',
- help='List of model UUIDs to evaluate'
- )
- parser.add_argument(
- '--eval-tasks',
- required=True,
- nargs='+',
- help='List of evaluation task names'
- )
- parser.add_argument(
- '--annotator-model',
- required=True,
- help='Annotator model to filter results'
- )
- parser.add_argument(
- '--output',
- default='model_scores.csv',
- help='Output CSV filename (default: model_scores.csv)'
- )
+ parser = argparse.ArgumentParser(description="Generate CSV of model evaluation scores")
+ parser.add_argument("--model-ids", required=True, nargs="+", help="List of model UUIDs to evaluate")
+ parser.add_argument("--eval-tasks", required=True, nargs="+", help="List of evaluation task names")
+ parser.add_argument("--annotator-model", required=True, help="Annotator model to filter results")
+ parser.add_argument("--output", default="model_scores.csv", help="Output CSV filename (default: model_scores.csv)")
return parser.parse_args()
+
def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_model: str, output_file: str) -> None:
"""
Generate CSV file with model evaluation scores.
-
+
Args:
model_ids: List of model UUID strings
eval_tasks: List of evaluation task names
@@ -87,7 +76,7 @@ def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_mod
sys.exit(1)
# Prepare CSV headers
- headers = ['model_id', 'model_name'] + eval_tasks
+ headers = ["model_id", "model_name"] + eval_tasks
# Collect data for each model
rows = []
@@ -97,21 +86,18 @@ def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_mod
print(f"Warning: Model not found for ID {model_id}", file=sys.stderr)
continue
- row = {
- 'model_id': str(model_id),
- 'model_name': model_name
- }
-
+ row = {"model_id": str(model_id), "model_name": model_name}
+
# Get scores for each eval task
for task in eval_tasks:
score = get_model_score(task, model_id, annotator_model)
- row[task] = score if score is not None else 'N/A'
-
+ row[task] = score if score is not None else "N/A"
+
rows.append(row)
# Write to CSV
try:
- with open(output_file, 'w', newline='') as csvfile:
+ with open(output_file, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=headers)
writer.writeheader()
writer.writerows(rows)
@@ -120,9 +106,11 @@ def generate_eval_csv(model_ids: List[str], eval_tasks: List[str], annotator_mod
print(f"Error writing to CSV file: {e}", file=sys.stderr)
sys.exit(1)
+
def main():
args = parse_args()
generate_eval_csv(args.model_ids, args.eval_tasks, args.annotator_model, args.output)
-if __name__ == '__main__':
- main()
\ No newline at end of file
+
+if __name__ == "__main__":
+ main()
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/clean_sharegpt.py b/eval/chat_benchmarks/MTBench/fastchat/data/clean_sharegpt.py
index af6ffb69..7720f511 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/clean_sharegpt.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/clean_sharegpt.py
@@ -5,6 +5,7 @@
Usage:
python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json
"""
+
import argparse
from concurrent.futures import ProcessPoolExecutor
import json
@@ -19,9 +20,7 @@
div_pattern = re.compile("
")
span_pattern = re.compile("")
-code_lang_pattern = re.compile(
- "```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL
-)
+code_lang_pattern = re.compile("```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL)
code_lang_format = "```\g<1>\n\g<2>\n```"
regenerate_pattern = re.compile("\d+ / \d+")
copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words")
@@ -155,9 +154,7 @@ def clean_html_all(content, begin, end):
content = content[begin:end]
processed = []
with ProcessPoolExecutor() as executor:
- for result in tqdm(
- executor.map(clean_html_one_sample, content), total=len(content)
- ):
+ for result in tqdm(executor.map(clean_html_one_sample, content), total=len(content)):
processed.append(result)
visited = {}
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/extract_gpt4_only.py b/eval/chat_benchmarks/MTBench/fastchat/data/extract_gpt4_only.py
index bab53bcc..9bf185a2 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/extract_gpt4_only.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/extract_gpt4_only.py
@@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.extract_gpt4_only --in sharegpt.json
"""
+
import argparse
import json
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/extract_single_round.py b/eval/chat_benchmarks/MTBench/fastchat/data/extract_single_round.py
index 5da80365..a7c93ac7 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/extract_single_round.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/extract_single_round.py
@@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.extract_single_round --in sharegpt.json
"""
+
import argparse
import json
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/filter_wrong_format.py b/eval/chat_benchmarks/MTBench/fastchat/data/filter_wrong_format.py
index 46588ba8..90df80b9 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/filter_wrong_format.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/filter_wrong_format.py
@@ -5,6 +5,7 @@
python3 -m fastchat.data.filter_wrong_format --in input.json --out output.json
"""
+
import argparse
import json
import re
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/get_stats.py b/eval/chat_benchmarks/MTBench/fastchat/data/get_stats.py
index 0e0698e4..12fb646e 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/get_stats.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/get_stats.py
@@ -26,9 +26,7 @@ def tokenize_one_sample(c):
def tokenize_dataset(content):
processed = []
with ProcessPoolExecutor() as executor:
- for result in tqdm(
- executor.map(tokenize_one_sample, content), total=len(content)
- ):
+ for result in tqdm(executor.map(tokenize_one_sample, content), total=len(content)):
processed.append(result)
return processed
@@ -59,9 +57,7 @@ def compute_stats(content):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in-file", type=str)
- parser.add_argument(
- "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
- )
+ parser.add_argument("--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
args = parser.parse_args()
content = json.load(open(args.in_file, "r"))
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/hardcoded_questions.py b/eval/chat_benchmarks/MTBench/fastchat/data/hardcoded_questions.py
index a2bcff42..c6a89a45 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/hardcoded_questions.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/hardcoded_questions.py
@@ -1,6 +1,7 @@
"""
Hardcoded question and answers.
"""
+
import json
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/inspect_data.py b/eval/chat_benchmarks/MTBench/fastchat/data/inspect_data.py
index df922710..dd5f3378 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/inspect_data.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/inspect_data.py
@@ -2,6 +2,7 @@
Usage:
python3 -m fastchat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json
"""
+
import argparse
import json
import random
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/optional_clean.py b/eval/chat_benchmarks/MTBench/fastchat/data/optional_clean.py
index 47aecc11..81469f4c 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/optional_clean.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/optional_clean.py
@@ -8,6 +8,7 @@
Requirement:
pip3 install polyglot pyicu pycld2
"""
+
import argparse
import json
import re
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/optional_replace.py b/eval/chat_benchmarks/MTBench/fastchat/data/optional_replace.py
index 1114151a..ef20150e 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/optional_replace.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/optional_replace.py
@@ -7,6 +7,7 @@
Requirement:
pip3 install transformers tqdm
"""
+
import argparse
import json
import traceback
@@ -15,9 +16,7 @@
from tqdm import tqdm
-def replace_special_tokens(
- tokenizer: transformers.PreTrainedTokenizer, text: str
-) -> str:
+def replace_special_tokens(tokenizer: transformers.PreTrainedTokenizer, text: str) -> str:
if not text:
return text
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/prepare_all.py b/eval/chat_benchmarks/MTBench/fastchat/data/prepare_all.py
index 6d568703..cb8f544c 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/prepare_all.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/prepare_all.py
@@ -9,20 +9,14 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521")
- parser.add_argument(
- "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
- )
+ parser.add_argument("--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
parser.add_argument("--seq-len", type=int, default=4096)
args = parser.parse_args()
in_prefix = args.prefix
model_path = args.model_name_or_path
seq_len = args.seq_len
- prefix = (
- f"{in_prefix}_{seq_len}".replace("4096", "4k")
- .replace("8192", "8k")
- .replace("16384", "16k")
- )
+ prefix = f"{in_prefix}_{seq_len}".replace("4096", "4k").replace("8192", "8k").replace("16384", "16k")
cmd_list = [
f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json",
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/sample.py b/eval/chat_benchmarks/MTBench/fastchat/data/sample.py
index 5ea94fad..7a8501a9 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/sample.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/sample.py
@@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.sample --in sharegpt.json --out sampled.json
"""
+
import argparse
import json
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/split_long_conversation.py b/eval/chat_benchmarks/MTBench/fastchat/data/split_long_conversation.py
index 413fa8bc..9a4c04f9 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/split_long_conversation.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/split_long_conversation.py
@@ -6,6 +6,7 @@
--out sharegpt_split.json \
--model-name-or-path $
"""
+
import argparse
from concurrent.futures import ProcessPoolExecutor
import json
diff --git a/eval/chat_benchmarks/MTBench/fastchat/data/split_train_test.py b/eval/chat_benchmarks/MTBench/fastchat/data/split_train_test.py
index 60b8960b..eafbc5e1 100644
--- a/eval/chat_benchmarks/MTBench/fastchat/data/split_train_test.py
+++ b/eval/chat_benchmarks/MTBench/fastchat/data/split_train_test.py
@@ -3,6 +3,7 @@
Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json
"""
+
import argparse
import json
diff --git a/eval/chat_benchmarks/scibench/eval_instruct.py b/eval/chat_benchmarks/scibench/eval_instruct.py
index e9e76c39..9437e23f 100644
--- a/eval/chat_benchmarks/scibench/eval_instruct.py
+++ b/eval/chat_benchmarks/scibench/eval_instruct.py
@@ -18,58 +18,63 @@
Please provide a clear and step-by-step solution for a scientific problem in the categories of Chemistry, Physics, or Mathematics. The problem will specify the unit of measurement, which should not be included in the answer. Express the final answer as a decimal number with three digits after the decimal point. Conclude the answer by stating "The answer is therefore \\boxed{[ANSWER]}."
"""
+
def remove_not(x):
- match_number = re.compile('[\$]?\ *10\^[{]?\ *-?[0-9]+\ *[}]?\ *[\$]?')
+ match_number = re.compile("[\$]?\ *10\^[{]?\ *-?[0-9]+\ *[}]?\ *[\$]?")
result = re.findall(match_number, x)
if len(result) != 0:
return re.split(match_number, x)[-1]
return None
+
def parse_not(inputs):
try:
if not inputs:
- return '', ''
- if '\times' in inputs:
- x, ab = inputs.split('\times')
- elif '\\times' in inputs:
- x, ab = inputs.split('\\times')
- elif '*' in inputs:
- x, ab = inputs.split('*')
+ return "", ""
+ if "\times" in inputs:
+ x, ab = inputs.split("\times")
+ elif "\\times" in inputs:
+ x, ab = inputs.split("\\times")
+ elif "*" in inputs:
+ x, ab = inputs.split("*")
else:
return inputs
return x, ab
except:
- return '', ''
+ return "", ""
+
def cal_not(inputs):
try:
x, ab = list(inputs)
- match_number = re.compile('10\^[{]?\ *-?[0-9]+\ *[}]?')
+ match_number = re.compile("10\^[{]?\ *-?[0-9]+\ *[}]?")
ab = re.findall(match_number, ab)[0]
- ab = ab[ab.find('^')+1:]
- if '{' in ab:
- ab = ab[ab.find('{')+1:]
- if '}' in ab:
- ab = ab[:ab.find('}')]
+ ab = ab[ab.find("^") + 1 :]
+ if "{" in ab:
+ ab = ab[ab.find("{") + 1 :]
+ if "}" in ab:
+ ab = ab[: ab.find("}")]
x = x.strip()
- out = float(x) * 10**float(ab)
+ out = float(x) * 10 ** float(ab)
return str(out)
except:
- print('error')
+ print("error")
return inputs
+
def remove_boxed(s):
left = "oxed{"
try:
- assert s[:len(left)] == left
+ assert s[: len(left)] == left
assert s[-1] == "}"
- answer = s[len(left):-1]
+ answer = s[len(left) : -1]
if "=" in answer:
answer = answer.split("=")[-1].lstrip(" ")
return answer
except:
return None
+
def last_boxed_only_string(string):
idx = string.rfind("oxed")
if idx < 0:
@@ -91,15 +96,17 @@ def last_boxed_only_string(string):
if right_brace_idx == None:
retval = None
else:
- retval = string[idx:right_brace_idx + 1]
+ retval = string[idx : right_brace_idx + 1]
return retval
+
def parse_math_answer(raw_string):
return remove_boxed(last_boxed_only_string(raw_string))
+
def equiv(model_output, answer, unit):
"""SciBench's exact equiv function"""
- model_output = model_output.replace(',', '')
+ model_output = model_output.replace(",", "")
try:
ans = float(answer.strip())
first = math.isclose(float(model_output.strip()), ans, rel_tol=0.05)
@@ -114,19 +121,20 @@ def equiv(model_output, answer, unit):
return True
return False
+
@dataclass
class SciBenchConfig:
"""Configuration for SciBench evaluation."""
- categories: List[str] = field(default_factory=lambda: [
- "chemmc"
- ])
+
+ categories: List[str] = field(default_factory=lambda: ["chemmc"])
temperature: float = 0.0
max_new_tokens: int = 1024
do_sample: bool = False
+
class SciBenchBenchmark(BaseBenchmark):
"""SciBench benchmark implementation."""
-
+
def __init__(
self,
categories: List[str] = None,
@@ -143,38 +151,38 @@ def __init__(
def _load_dataset(self, category: str):
"""Load dataset from JSON files in the specified data directory"""
try:
- data_dir = Path("./data") # TODO: CHANGE TO DATA DIRECTORY IF NEEDED
+ data_dir = Path("./data") # TODO: CHANGE TO DATA DIRECTORY IF NEEDED
file_path = data_dir / f"{category}.json"
-
- with open(file_path, 'r') as f:
+
+ with open(file_path, "r") as f:
dataset = json.load(f)
-
+
# Filter problems for the specific category
problems = [
{
- 'problem_text': item['problem_text'],
- 'answer_number': item['answer_number'],
- 'unit': item['unit'],
- 'original_unit': item['unit'], # Using same unit as original since dataset doesn't distinguish
- 'source': item['source']
+ "problem_text": item["problem_text"],
+ "answer_number": item["answer_number"],
+ "unit": item["unit"],
+ "original_unit": item["unit"], # Using same unit as original since dataset doesn't distinguish
+ "source": item["source"],
}
- for item in dataset
- if item['source'] == category
+ for item in dataset
+ if item["source"] == category
]
-
+
# Process units if needed
processed_problems = []
for problem_data in problems:
- unit = problem_data['unit']
+ unit = problem_data["unit"]
base_unit = remove_not(unit)
if base_unit:
unit = base_unit
- problem_data['unit'] = unit
+ problem_data["unit"] = unit
processed_problems.append(problem_data)
-
+
self.logger.info(f"Loaded {len(processed_problems)} problems for category {category}")
return processed_problems
-
+
except Exception as e:
self.logger.error(f"Error loading dataset: {e}")
raise
@@ -182,46 +190,43 @@ def _load_dataset(self, category: str):
def call_engine(self, messages, temperature=0, n=1, patience=100000, sleep_time=0):
"""Match eval_zero.py's implementation for API calls using new OpenAI API"""
client = OpenAI() # This will automatically use OPENAI_API_KEY from env
-
+
while patience > 0:
patience -= 1
try:
response = client.chat.completions.create(
- model="gpt-4", # or use self.config.model
- messages=messages,
- temperature=temperature,
- n=n
+ model="gpt-4", messages=messages, temperature=temperature, n=n # or use self.config.model
)
if n == 1:
prediction = response.choices[0].message.content.strip()
if prediction != "" and prediction is not None:
return prediction
else:
- prediction = [choice.message.content.strip()
- for choice in response.choices]
+ prediction = [choice.message.content.strip() for choice in response.choices]
if prediction[0] != "" and prediction[0] is not None:
return prediction
except Exception as e:
self.logger.error(f"OpenAI API error: {e}")
if sleep_time > 0:
import time
+
time.sleep(sleep_time)
return ""
def generate_responses(self, model: LM) -> Dict[str, Any]:
"""Generate responses for all problems using OpenAI API."""
results = {}
-
+
try:
category = "chemmc"
problems = self._load_dataset(category)
print(f"\nProcessing category {category}")
print(f"Total problems loaded: {len(problems)}")
-
+
if self.debug:
- problems = problems[:min(10, len(problems))]
+ problems = problems[: min(10, len(problems))]
print(f"Debug mode: Using {len(problems)} problems")
-
+
ids = [f"{category}_{i}" for i in range(len(problems))]
metadata = {
"answer_number": [],
@@ -229,27 +234,23 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
"original_unit": [],
}
outputs = []
-
+
for i, problem in enumerate(problems):
- unit_prob = problem['unit']
+ unit_prob = problem["unit"]
problem_text = f"{problem['problem_text']} The unit of the answer is {unit_prob}."
-
+
messages = [
{"role": "system", "content": sys_cal_box2},
- {"role": "user", "content": f"Q: {problem_text}\nA: The answer is"}
- ]
-
+ {"role": "user", "content": f"Q: {problem_text}\nA: The answer is"},
+ ]
+
print(f"\nProblem {i}:")
print(f"Text: {problem_text}")
output = self.call_engine(
- messages,
- temperature=self.config.temperature,
- n=1,
- patience=100000,
- sleep_time=1
+ messages, temperature=self.config.temperature, n=1, patience=100000, sleep_time=1
)
print(f"Response: {output}")
-
+
outputs.append(output)
metadata["answer_number"].append(problem["answer_number"])
metadata["unit"].append(problem["unit"])
@@ -273,21 +274,21 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
def evaluate_responses(self, results: Dict[str, Any]) -> Dict[str, float]:
"""Evaluate all responses using SciBench's exact evaluation logic."""
eval_results = {}
-
+
category = "chemmc"
category_results = results.get(category, {})
correct = 0
total = 0
-
+
for output, answer, unit, original_unit in zip(
category_results.get("outputs", []),
category_results.get("metadata", {}).get("answer_number", []),
category_results.get("metadata", {}).get("unit", []),
- category_results.get("metadata", {}).get("original_unit", [])
+ category_results.get("metadata", {}).get("original_unit", []),
):
model_output = parse_math_answer(output)
if not model_output:
- numbers = re.findall(r'\\boxed{([^}]*)}', output)
+ numbers = re.findall(r"\\boxed{([^}]*)}", output)
if numbers:
model_output = numbers[-1].strip()
else:
@@ -306,7 +307,7 @@ def evaluate_responses(self, results: Dict[str, Any]) -> Dict[str, float]:
if equiv(str(model_output), answer, unit):
correct += 1
total += 1
-
+
eval_results[category] = (correct / total) * 100 if total > 0 else 0.0
return eval_results
@@ -316,21 +317,21 @@ def evaluate_responses(self, results: Dict[str, Any]) -> Dict[str, float]:
eval_results = {}
total_score = 0
num_categories = 0
-
+
category = "chemmc"
category_results = results.get(category, {})
correct = 0
total = 0
-
+
for output, answer, unit, original_unit in zip(
category_results.get("outputs", []),
category_results.get("metadata", {}).get("answer_number", []),
category_results.get("metadata", {}).get("unit", []),
- category_results.get("metadata", {}).get("original_unit", [])
+ category_results.get("metadata", {}).get("original_unit", []),
):
model_output = parse_math_answer(output)
if not model_output:
- numbers = re.findall(r'\\boxed{([^}]*)}', output)
+ numbers = re.findall(r"\\boxed{([^}]*)}", output)
if numbers:
model_output = numbers[-1].strip()
else:
@@ -349,13 +350,13 @@ def evaluate_responses(self, results: Dict[str, Any]) -> Dict[str, float]:
if equiv(str(model_output), answer, unit):
correct += 1
total += 1
-
+
category_score = (correct / total) * 100 if total > 0 else 0.0
eval_results[category] = category_score
total_score += category_score
num_categories += 1
-
+
# Add average score
eval_results["average"] = total_score / num_categories if num_categories > 0 else 0.0
- return eval_results
\ No newline at end of file
+ return eval_results