Skip to content

Commit

Permalink
add flushing on repeating errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alex28sh committed Oct 23, 2024
1 parent 57d2c71 commit b19eb4f
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ anywhere else you should not use `Acc()` or `list_pred()`:
5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. The first argument of Forall/Exists is typically a type (e.g. `int`),
second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`.

6. In Nagini, `Implies(e1, a2)` plays the role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true.
You can use it inside invariants and asserts.

Respond with 'understood' to this message. Remember this knowledge to solve the task that will be given further
3 changes: 3 additions & 0 deletions prompts/humaneval-nagini-cot-instruct/steps/001/question.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ anywhere else you should not use `Acc()` or `list_pred()`:
5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. The first argument of Forall/Exists is typically a type (e.g. `int`),
second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`.

6. In Nagini, `Implies(e1, a2)` plays the role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true.
You can use it inside invariants and asserts.

Respond with 'understood' to this message. Remember this knowledge to solve the task that will be given further
46 changes: 46 additions & 0 deletions tests/test_compare_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from verified_cogen.tools import compare_errors


def test_compare_errors1():
    """Two identical timeout messages must be reported as the same error."""
    err1: str = "Verification timed out"
    err2: str = "Verification timed out"
    assert compare_errors(err1, err2)

def test_compare_errors2():
    """A timeout and a translation failure must be reported as different errors."""
    err1: str = "Verification timed out"
    # The literal's continuation lines stay at column 0 so its content is unchanged.
    err2: str = """\
Translation failed
Not supported: (0 <= d_6_i_ <= len(l))
(/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/[email protected])"""
    assert not compare_errors(err1, err2)

def test_compare_errors3():
    """Two translation failures with the same message must compare as equal."""
    # NOTE(review): the "[email protected]" token looks like an obfuscated
    # "<file>.py@<line>.<col>" location -- confirm against the repository copy.
    err1: str = """\
Translation failed
Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/[email protected])"""
    err2: str = """\
Translation failed
Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/[email protected])"""
    assert compare_errors(err1, err2)

def test_compare_errors4():
    """The same message reported at different source positions must differ."""
    # NOTE(review): in the copy this was taken from, both locations had been
    # replaced by the same obfuscated token, which made the two inputs identical
    # and the assertion below impossible to satisfy. The file name and positions
    # used here are placeholders -- restore the real values from the repository.
    err1: str = """\
Translation failed
Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/example.py@13.10)"""
    err2: str = """\
Translation failed
Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/example.py@14.2)"""
    assert not compare_errors(err1, err2)

def test_compare_errors5():
    """Reports that differ only in the 'Verification took ...' line must be equal."""
    err1: str = """\
Verification failed
Errors:
Loop invariant might not be preserved. Assertion (s == psum(0, d_2_i_, numbers)) might not hold. ([email protected])
Verification took 9.04 seconds."""
    err2: str = """\
Verification failed
Errors:
Loop invariant might not be preserved. Assertion (s == psum(0, d_2_i_, numbers)) might not hold. ([email protected])
Verification took 13.72 seconds."""
    assert compare_errors(err1, err2)
2 changes: 1 addition & 1 deletion verified_cogen/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_default_parser():
)
parser.add_argument(
"--bench-type",
help="benchmark type, available: {invariants, generic, generate, validating, step-by-step}",
help="benchmark type, available: {invariants, generic, generate, validating, step-by-step, step-by-step-flush}",
default="invariants",
)
parser.add_argument("--temperature", help="model temperature", default=0, type=int)
Expand Down
6 changes: 6 additions & 0 deletions verified_cogen/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def add_response(self, response: str, temporary: bool = False):
self.responses.append(response)
self.is_response_temporary.append(temporary)

def wipe_all(self) -> None:
    """Drop the whole recorded conversation, temporary or not.

    All four parallel lists (user prompts, responses and their "temporary"
    flags) are rebound to fresh empty lists, so any list object handed out
    earlier is left untouched.
    """
    self.user_prompts = []
    self.is_user_prompt_temporary = []
    self.responses = []
    self.is_response_temporary = []

def wipe_temporary(self):
self.user_prompts = [
prompt
Expand Down
6 changes: 6 additions & 0 deletions verified_cogen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from verified_cogen.runners.languages import register_basic_languages
from verified_cogen.runners.languages.language import AnnotationType, LanguageDatabase
from verified_cogen.runners.step_by_step import StepByStepRunner
from verified_cogen.runners.step_by_step_flush import StepByStepFlushRunner
from verified_cogen.runners.validating import ValidatingRunner
from verified_cogen.tools import (
ext_glob,
Expand Down Expand Up @@ -115,6 +116,11 @@ def runner_cls(llm: LLM, logger: Logger, verifier: Verifier):
StepByStepRunner(InvariantRunner(llm, logger, verifier, config)),
LanguageDatabase().get(extension),
)
elif bench_type == "step-by-step-flush":
return ValidatingRunner(
StepByStepFlushRunner(InvariantRunner(llm, logger, verifier, config)),
LanguageDatabase().get(extension),
)
else:
raise ValueError(f"Unexpected bench_type: {bench_type}")

Expand Down
4 changes: 3 additions & 1 deletion verified_cogen/runners/languages/nagini.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore
def separate_validator_errors(self, errors: str) -> tuple[str, str]:
    """Strip verifier status lines from *errors*.

    Every line containing "Verification successful" or "Verification took"
    is removed; the remaining lines are re-joined with newlines.

    Returns:
        A pair whose first item is the filtered text and whose second item
        is always the empty string.
    """
    kept = [
        line
        for line in errors.split("\n")
        if "Verification successful" not in line and "Verification took" not in line
    ]
    return "\n".join(kept), ""
33 changes: 33 additions & 0 deletions verified_cogen/runners/step_by_step_flush.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional
from verified_cogen.runners import Runner
from verified_cogen.runners.step_by_step import StepByStepRunner, StepByStepConfig
from verified_cogen.tools import compare_errors


class StepByStepFlushRunner(StepByStepRunner):
    """Step-by-step runner that starts over when the same error shows up twice.

    Each error passed to ``ask_for_fixed`` / ``ask_for_timeout`` is compared
    (via ``compare_errors``) with the one remembered from the previous call.
    On a repeat the LLM conversation is wiped and the starting program is
    rewritten from scratch; otherwise the error is remembered and the request
    is delegated to the wrapped runner.

    ``self.llm``, ``self.logger``, ``self.starting_prg``, ``self.rewrite`` and
    ``self.wrapped_runner`` are not defined in this class; presumably they come
    from the base classes -- verify against ``StepByStepRunner``.
    """

    # Error remembered from the previous request; "" means nothing remembered.
    previous_error: str = ""
    # Message that stands in for a verifier timeout when comparing errors.
    timeout: str = "Verification timed out"

    def __init__(self, wrapping: Runner, config: Optional[StepByStepConfig] = None):
        super().__init__(wrapping, config)

    def _is_repeated(self, err: str) -> bool:
        """Return True if *err* matches the remembered error; else remember it."""
        if compare_errors(self.previous_error, err):
            return True
        self.previous_error = err
        return False

    def flush_and_rewrite(self) -> str:
        """Wipe the LLM history, forget the remembered error and rewrite the starting program."""
        assert self.starting_prg is not None
        self.llm.wipe_all()
        self.previous_error = ""
        self.logger.info("Encountered same error. Rewrite")
        return self.rewrite(self.starting_prg)

    def ask_for_timeout(self) -> str:
        """Handle a timeout: restart on a repeated timeout, otherwise delegate."""
        if self._is_repeated(self.timeout):
            return self.flush_and_rewrite()
        return self.wrapped_runner.ask_for_timeout()

    def ask_for_fixed(self, err: str) -> str:
        """Handle error *err*: restart if it repeats the previous one, otherwise delegate."""
        if self._is_repeated(err):
            return self.flush_and_rewrite()
        return self.wrapped_runner.ask_for_fixed(err)
4 changes: 3 additions & 1 deletion verified_cogen/runners/validating.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def _add_validators(self, prg: str, inv_prg: str):
return val_prg

def preprocess(self, prg: str, mode: Mode) -> str:
    """Remove conditions from *prg* and record the result on the wrapped runner.

    The stripped program is produced by ``self.language.remove_conditions``
    and is also stored as ``self.wrapped_runner.starting_prg``, so the wrapped
    runner has a program to restart from (``StepByStepFlushRunner`` reads
    ``starting_prg`` when it rewrites). *mode* is not used here.

    Returns:
        The program with its conditions removed.
    """
    res_prg = self.language.remove_conditions(prg)
    self.wrapped_runner.starting_prg = res_prg
    return res_prg

def postprocess(self, inv_prg: str) -> str:
assert self.starting_prg is not None
Expand Down
13 changes: 13 additions & 0 deletions verified_cogen/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ def register_output_handler(logger: logging.Logger):
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


def _strip_volatile_error_parts(error: str) -> str:
    """Return *error* without file paths and without the verification-time line."""
    # Drop everything from an opening parenthesis up to the first following
    # ".py", so "(/some/dir/file.py@12.3)" becomes "(@12.3)": the position is
    # kept while the file path is discarded.
    without_paths = re.sub(r"\(.*?\.py", "(", error).strip()
    # The elapsed time changes from run to run and must not affect comparison.
    return re.sub(r"Verification took \d+\.\d+ seconds\.", "", without_paths).strip()


def compare_errors(error1: str, error2: str) -> bool:
    """Tell whether two verifier error reports describe the same error.

    Both reports are normalised first: parenthesised file paths ending in
    ``.py`` are removed, "Verification took N.NN seconds." is removed, and
    surrounding whitespace is stripped. The normalised texts are then compared
    for exact equality.

    Args:
        error1: First error report.
        error2: Second error report.

    Returns:
        True if the normalised reports are identical, False otherwise.
    """
    return _strip_volatile_error_parts(error1) == _strip_volatile_error_parts(error2)

0 comments on commit b19eb4f

Please sign in to comment.