diff --git a/prompts/humaneval-nagini-cot-instruct/steps/001/examples/001/question.txt b/prompts/humaneval-nagini-cot-instruct/steps/001/examples/001/question.txt index 389b628..771f87c 100644 --- a/prompts/humaneval-nagini-cot-instruct/steps/001/examples/001/question.txt +++ b/prompts/humaneval-nagini-cot-instruct/steps/001/examples/001/question.txt @@ -34,4 +34,7 @@ anywhere else you should not use `Acc()` or `list_pred()`: 5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. First argument of Forall/Exists is typically a type (i.e `int`), second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`. +6. In Nagini `Implies(e1, a2)` plays role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true. +You can use it inside invariants and asserts. + Respond with 'understood' to this message. Remember this knowledge to solve the task that will be given further \ No newline at end of file diff --git a/prompts/humaneval-nagini-cot-instruct/steps/001/question.txt b/prompts/humaneval-nagini-cot-instruct/steps/001/question.txt index 389b628..771f87c 100644 --- a/prompts/humaneval-nagini-cot-instruct/steps/001/question.txt +++ b/prompts/humaneval-nagini-cot-instruct/steps/001/question.txt @@ -34,4 +34,7 @@ anywhere else you should not use `Acc()` or `list_pred()`: 5. Nagini contains `Forall` and `Exists` constructs that can be used in invariants. First argument of Forall/Exists is typically a type (i.e `int`), second argument is a lambda. `Forall(type, lambda x : a)` denotes that assertion `a` is true for every element `x` of type `type`. +6. In Nagini `Implies(e1, a2)` plays role of implication. `Implies(e1, a2)` denotes that assertion a2 holds if boolean expression e1 is true. +You can use it inside invariants and asserts. + Respond with 'understood' to this message. Remember this knowledge to solve the task that will be given further \ No newline at end of file diff --git a/tests/test_compare_errors.py b/tests/test_compare_errors.py new file mode 100644 index 0000000..f4ecec1 --- /dev/null +++ b/tests/test_compare_errors.py @@ -0,0 +1,46 @@ +from verified_cogen.tools import compare_errors + + +def test_compare_errors1(): + err1: str = "Verification timed out" + err2: str = "Verification timed out" + assert compare_errors(err1, err2) + +def test_compare_errors2(): + err1: str = "Verification timed out" + err2: str = """\ + Translation failed + Not supported: (0 <= d_6_i_ <= len(l)) + (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/040-triples-sum-to-zero.1.py@19.18)""" + assert not compare_errors(err1, err2) + +def test_compare_errors3(): + err1: str = """\ + Translation failed + Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/042-incr-list.2.py@18.0)""" + err2: str = """\ + Translation failed + Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/042-incr-list.3.py@18.0)""" + assert compare_errors(err1, err2) + +def test_compare_errors4(): + err1: str = """\ + Translation failed + Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/042-incr-list.1.py@19.0)""" + err2: str = """\ + Translation failed + Type error: invalid syntax (/home/aleksandr/verified-cogen/log_tries/humaneval-nagini/042-incr-list.3.py@18.0)""" + assert not compare_errors(err1, err2) + +def test_compare_errors5(): + err1: str = """\ + Verification failed + Errors: + Loop invariant might not be preserved. Assertion (s == psum(0, d_2_i_, numbers)) might not hold. (008-sum-product.3.py@32.18--32.47) + Verification took 9.04 seconds.""" + err2: str = """\ + Verification failed + Errors: + Loop invariant might not be preserved. Assertion (s == psum(0, d_2_i_, numbers)) might not hold. (008-sum-product.4.py@32.18--32.47) + Verification took 13.72 seconds.""" + assert compare_errors(err1, err2) \ No newline at end of file diff --git a/verified_cogen/args.py b/verified_cogen/args.py index e4bcf4a..ed056c8 100644 --- a/verified_cogen/args.py +++ b/verified_cogen/args.py @@ -65,7 +65,7 @@ def get_default_parser(): ) parser.add_argument( "--bench-type", - help="benchmark type, available: {invariants, generic, generate, validating, step-by-step}", + help="benchmark type, available: {invariants, generic, generate, validating, step-by-step, step-by-step-flush}", default="invariants", ) parser.add_argument("--temperature", help="model temperature", default=0, type=int) diff --git a/verified_cogen/llm/llm.py b/verified_cogen/llm/llm.py index 4341b30..be5b194 100644 --- a/verified_cogen/llm/llm.py +++ b/verified_cogen/llm/llm.py @@ -56,6 +56,12 @@ def add_response(self, response: str, temporary: bool = False): self.responses.append(response) self.is_response_temporary.append(temporary) + def wipe_all(self): + self.user_prompts = [] + self.is_user_prompt_temporary = [] + self.responses = [] + self.is_response_temporary = [] + def wipe_temporary(self): self.user_prompts = [ prompt diff --git a/verified_cogen/runners/languages/nagini.py b/verified_cogen/runners/languages/nagini.py index f1f6df0..06ac472 100644 --- a/verified_cogen/runners/languages/nagini.py +++ b/verified_cogen/runners/languages/nagini.py @@ -38,6 +38,8 @@ def __init__(self, remove_annotations: list[AnnotationType]): # type: ignore def separate_validator_errors(self, errors: str) -> tuple[str, str]: lines = errors.split("\n") lines = [ - line for line in lines if "Verification successful" not in line and "Verification took" not in line + line + for line in lines + if "Verification successful" not in line and "Verification took" not in line ] return "\n".join(lines), "" diff --git a/verified_cogen/runners/step_by_step_flush.py b/verified_cogen/runners/step_by_step_flush.py new file mode 100644 index 0000000..2444741 --- /dev/null +++ b/verified_cogen/runners/step_by_step_flush.py @@ -0,0 +1,33 @@ +from typing import Optional +from verified_cogen.runners import Runner +from verified_cogen.runners.step_by_step import StepByStepRunner, StepByStepConfig +from verified_cogen.tools import compare_errors + + +class StepByStepFlushRunner(StepByStepRunner): + previous_error: str = "" + timeout: str = "Verification timed out" + + def __init__(self, wrapping: Runner, config: Optional[StepByStepConfig] = None): + super().__init__(wrapping, config) + + def flush_and_rewrite(self) -> str: + assert self.starting_prg is not None + self.llm.wipe_all() + self.previous_error = "" + self.logger.info("Encountered same error. Rewrite") + return self.rewrite(self.starting_prg) + + def ask_for_timeout(self) -> str: + if compare_errors(self.previous_error, self.timeout): + return self.flush_and_rewrite() + else: + self.previous_error = self.timeout + return self.wrapped_runner.ask_for_timeout() + + def ask_for_fixed(self, err: str) -> str: + if compare_errors(self.previous_error, err): + return self.flush_and_rewrite() + else: + self.previous_error = err + return self.wrapped_runner.ask_for_fixed(err) diff --git a/verified_cogen/runners/validating.py b/verified_cogen/runners/validating.py index 2541825..a1c1e51 100644 --- a/verified_cogen/runners/validating.py +++ b/verified_cogen/runners/validating.py @@ -38,7 +38,9 @@ def _add_validators(self, prg: str, inv_prg: str): return val_prg def preprocess(self, prg: str, mode: Mode) -> str: - return self.language.remove_conditions(prg) + res_prg = self.language.remove_conditions(prg) + self.wrapped_runner.starting_prg = res_prg + return res_prg def postprocess(self, inv_prg: str) -> str: assert self.starting_prg is not None diff --git a/verified_cogen/tools/__init__.py b/verified_cogen/tools/__init__.py index 259be67..61c9f26 100644 --- a/verified_cogen/tools/__init__.py +++ b/verified_cogen/tools/__init__.py @@ -66,3 +66,16 @@ def register_output_handler(logger: logging.Logger): formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) + + +def compare_errors(error1: str, error2: str): + pattern = r"\(.*?\.py" + pattern_time = r"Verification took \d+\.\d+ seconds\." + + cleaned_error1 = re.sub(pattern, "(", error1).strip() + cleaned_error1 = re.sub(pattern_time, "", cleaned_error1).strip() + + cleaned_error2 = re.sub(pattern, "(", error2).strip() + cleaned_error2 = re.sub(pattern_time, "", cleaned_error2).strip() + + return cleaned_error1 == cleaned_error2