Skip to content

Commit

Permalink
Merge pull request #143 from iCSawyer/main
Browse files Browse the repository at this point in the history
Add --save_references_path into args
  • Loading branch information
loubnabnl authored Dec 26, 2023
2 parents be2a44c + a5fc279 commit 8d9f667
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
4 changes: 2 additions & 2 deletions bigcode_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def evaluate(self, task_name):
f"generations were saved at {self.args.save_generations_path}"
)
if self.args.save_references:
with open("references.json", "w") as fp:
with open(self.args.save_references_path, "w") as fp:
json.dump(references, fp)
print("references were saved at references.json")
print(f"references were saved at {self.args.save_references_path}")

# make sure tokenizer plays nice with multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand Down
10 changes: 8 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@ def parse_args():
action="store_true",
help="Whether to save reference solutions/tests",
)
parser.add_argument(
"--save_references_path",
type=str,
default="references.json",
help="Path for saving the references solutions/tests",
)
parser.add_argument(
"--prompt",
type=str,
Expand Down Expand Up @@ -335,9 +341,9 @@ def main():
json.dump(generations, fp)
print(f"generations were saved at {args.save_generations_path}")
if args.save_references:
with open("references.json", "w") as fp:
with open(args.save_references_path, "w") as fp:
json.dump(references, fp)
print("references were saved")
print(f"references were saved at {args.save_references_path}")
else:
results[task] = evaluator.evaluate(task)

Expand Down
1 change: 1 addition & 0 deletions tests/test_generation_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def update_args(args):
args.save_generations = False
args.save_generations_path = ""
args.save_references = False
args.save_references_path = ""
args.metric_output_path = TMPDIR
args.load_generations_path = None
args.generation_only = False
Expand Down

0 comments on commit 8d9f667

Please sign in to comment.