Skip to content

Commit

Permalink
Minor fixes to the evaluation
Browse files: browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Dec 8, 2024
1 parent b144be2 commit a034a36
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion syncode/evaluation/fol_eval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -98,7 +98,7 @@ def run_eval(syncode, out_path: Optional[str]=None, debug_task_id=None):
for task_id, problem in enumerate(problems):
results[task_id] = []
full_prompt = FOLEval._prompt_folio(problem)
completion = syncode.model.generate_batch_completion_grammar(
completion = syncode.model.generate_grammar_constrained_completion(
full_prompt,
syncode.num_samples,
stop_words=['\n\n', '------']
Expand Down
2 changes: 1 addition & 1 deletion syncode/evaluation/json_eval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -70,7 +70,7 @@ def run_eval_for_task(syncode, num_samples_per_task, problem, samples, pbar, tas

prompt = syncode.model.tokenizer.apply_chat_template(problem["prompt"], tokenize = False)

batch_completions = syncode.model.generate_batch_completion_grammar(prompt, num_samples_per_task)
batch_completions = syncode.model.generate_grammar_constrained_completion(prompt, num_samples_per_task)
for completion_id, completion in enumerate(batch_completions):
result = dict(
task_id = task_id,
Expand Down
2 changes: 1 addition & 1 deletion syncode/evaluation/math_eval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,7 +22,7 @@ def run_math_eval(syncode, out_path: Optional[str], debug_task_id=None, logger=c

for task_id, problem in enumerate(problems):
results[task_id] = []
batch_completions = syncode.model.generate_batch_completion_grammar(
batch_completions = syncode.model.generate_grammar_constrained_completion(
problem['question'],
syncode.num_samples
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_language_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_vocab(self) -> Dict[str, int]:
return {v: i for i, v in enumerate(self.vocab)}

class TestHuggingFaceModel(unittest.TestCase):
def test_generate_batch_completion_grammar(self):
def test_generate_grammar_constrained_completion(self):
torch.manual_seed(0)
model = TestModel()
tokenizer = TestTokenizer()
Expand All @@ -65,7 +65,7 @@ def test_generate_batch_completion_grammar(self):
output = lm.generate_grammar_constrained_completion(prompt, 1)
self.assertEqual(len(output[0]), 15, "The output length does not match the expected value.")

def test_generate_batch_completion_grammar2(self):
def test_generate_grammar_constrained_completion2(self):
torch.manual_seed(0)
model = TestModel()
tokenizer = TestTokenizer()
Expand Down

0 comments on commit a034a36

Please sign in to comment.