Skip to content

Commit

Permalink
cast input batch size to int in cli_evaluate to fix batch size issue …
Browse files Browse the repository at this point in the history
…in some benchmark
  • Loading branch information
jmercat committed Jan 6, 2025
1 parent e937b74 commit 09b66c2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ def cli_evaluate(args: Optional[argparse.Namespace] = None) -> None:
with open(args.config, "r") as file:
tasks_yaml = yaml.safe_load(file)
args.tasks = ",".join([t["task_name"] for t in tasks_yaml["tasks"]])
batch_sizes_list = [t["batch_size"] for t in tasks_yaml["tasks"]]
batch_sizes_list = [int(t["batch_size"]) for t in tasks_yaml["tasks"]]
args.annotator_model = tasks_yaml.get("annotator_model", args.annotator_model)
else:
batch_sizes_list = [args.batch_size for _ in range(len(args.tasks.split(",")))]
batch_sizes_list = [int(args.batch_size) for _ in range(len(args.tasks.split(",")))]

# Initialize evaluation tracker
if args.output_path:
Expand Down

0 comments on commit 09b66c2

Please sign in to comment.