Commit b8652ca

cast batch_size to int
jmercat committed Jan 6, 2025
1 parent 4a181e0 commit b8652ca
Showing 3 changed files with 3 additions and 15 deletions.
11 changes: 0 additions & 11 deletions eval/chat_benchmarks/LiveBench/eval_instruct.py
@@ -131,11 +131,6 @@ def get_question_list(self, model_name: str, release_set: set):
         else:
             raise ValueError(f"Bad question source {self.question_source}.")
 
-        # questions_all = [
-        #     q
-        #     for q in questions_all
-        #     if q[0]["livebench_removal_date"] == "" or q[0]["livebench_removal_date"] > self.release_date
-        # ]
         return questions_all
 
     def _get_model_name(self, model: LM) -> str:
@@ -317,12 +312,6 @@ def evaluate_responses(self, results: Dict[str, Any]) -> Dict[str, float]:
             question_file, self.all_release_dates, self.question_begin, self.question_end
         )
 
-        # questions = [
-        #     q
-        #     for q in questions
-        #     if q["livebench_removal_date"] == "" or q["livebench_removal_date"] > self.release_date
-        # ]
-
         bench_name = os.path.dirname(question_file).replace(f"{self.data_path}/", "")
 
         output_file = f"{self.data_path}/{bench_name}/model_judgment/ground_truth_judgment.jsonl"
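
For context, the comment blocks deleted in both hunks were a disabled filter that would have dropped questions already retired from LiveBench as of the run's release date. A minimal sketch of what that filter computed, with hypothetical question data shaped the way the comments assume (the dates below are made up):

    # Hypothetical data; field names taken from the commented-out filter.
    questions_all = [
        [{"question_id": 1, "livebench_removal_date": ""}],            # never removed -> kept
        [{"question_id": 2, "livebench_removal_date": "2024-06-01"}],  # retired earlier -> dropped
    ]
    release_date = "2024-06-24"

    kept = [
        q for q in questions_all
        if q[0]["livebench_removal_date"] == "" or q[0]["livebench_removal_date"] > release_date
    ]
    print([q[0]["question_id"] for q in kept])  # [1]
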
3 changes: 1 addition & 2 deletions eval/chat_benchmarks/LiveBench/livebench/common.py
@@ -647,8 +647,7 @@ def check_data(questions, model_answers, models):
     # check model answers
     for m in models:
         if not m in model_answers:
-            breakpoint()
-            # raise ValueError(f"Missing model answer for {m}")
+            raise ValueError(f"Missing model answer for {m}")
         m_answer = model_answers[m]
         for q in questions:
             assert q["question_id"] in m_answer, f"Missing model {m}'s answer to Question {q['question_id']}"
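
This hunk removes a leftover debugging `breakpoint()`, which would hang a non-interactive run waiting for pdb input, and restores the fail-fast error. A tiny sketch of the restored behavior, with made-up inputs:

    # Hypothetical inputs; the check raises as soon as a model has no answers.
    model_answers = {"model-a": {"q1": "..."}}
    models = ["model-a", "model-b"]

    for m in models:
        if m not in model_answers:
            raise ValueError(f"Missing model answer for {m}")
    # -> ValueError: Missing model answer for model-b
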
4 changes: 2 additions & 2 deletions eval/eval.py
@@ -252,10 +252,10 @@ def cli_evaluate(args: Optional[argparse.Namespace] = None) -> None:
         with open(args.config, "r") as file:
             tasks_yaml = yaml.safe_load(file)
         args.tasks = ",".join([t["task_name"] for t in tasks_yaml["tasks"]])
-        batch_sizes_list = [t["batch_size"] for t in tasks_yaml["tasks"]]
+        batch_sizes_list = [int(t["batch_size"]) for t in tasks_yaml["tasks"]]
         args.annotator_model = tasks_yaml.get("annotator_model", args.annotator_model)
     else:
-        batch_sizes_list = [args.batch_size for _ in range(len(args.tasks.split(",")))]
+        batch_sizes_list = [int(args.batch_size) for _ in range(len(args.tasks.split(",")))]
 
     # Initialize evaluation tracker
     if args.output_path:
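
The cast guards against `batch_size` arriving as a string: argparse returns strings unless a flag is declared with `type=int`, and a quoted YAML value (`batch_size: "8"`) loads as `str` too. A minimal sketch of the failure mode and the fix, assuming a plain string-typed flag (the repo's actual argparse declaration is not shown in this commit):

    import argparse
    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default="8")  # no type=int -> value stays a str
    args = parser.parse_args(["--batch_size", "16"])

    # Downstream arithmetic on a str batch size fails:
    # range(args.batch_size)  ->  TypeError: 'str' object cannot be interpreted as an integer
    batch_sizes_list = [int(args.batch_size) for _ in range(2)]  # the commit's fix
    print(batch_sizes_list)  # [16, 16]

    # Quoted YAML values have the same problem:
    cfg = yaml.safe_load('tasks:\n  - task_name: LiveBench\n    batch_size: "32"')
    print([int(t["batch_size"]) for t in cfg["tasks"]])  # [32]
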
