-
Notifications
You must be signed in to change notification settings - Fork 541
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
58c84da
commit 36f84a9
Showing
2 changed files
with
188 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import time | ||
from textwrap import dedent, indent | ||
|
||
import coverage | ||
from coverage.misc import import_local_file | ||
|
||
from outlines.models.openai import OpenAIChatCompletion | ||
|
||
completer = OpenAIChatCompletion("gpt-3.5-turbo") | ||
|
||
|
||
def get_missing_coverage_lines(code: str): | ||
cov = coverage.Coverage(branch=True) | ||
cov.erase() | ||
|
||
modname = "test_mod" | ||
with open(modname + ".py", "wb") as f: | ||
f.write(code.encode("utf-8")) | ||
|
||
cov.start() | ||
try: | ||
mod = import_local_file(modname) | ||
finally: | ||
cov.stop() | ||
|
||
analysis = cov._analyze(mod) | ||
|
||
return sorted(analysis.missing) | ||
|
||
|
||
def collect_test_functions(string): | ||
"""Collects the names of the test functions in a given string. | ||
Args: | ||
string: The string to collect the test functions from. | ||
Returns: | ||
A list of the names of the test functions. | ||
(This was generated by Bard!) | ||
""" | ||
|
||
# Split the string into lines. | ||
lines = string.splitlines() | ||
|
||
# Create a list to store the test function names. | ||
test_function_names = [] | ||
|
||
# Iterate over the lines, looking for lines that start with the `def` keyword. | ||
for line in lines: | ||
if line.startswith("def"): | ||
# Get the name of the test function from the line. | ||
test_function_name = line.split("def")[1].split("(")[0].strip() | ||
|
||
# Add the test function name to the list. | ||
test_function_names.append(test_function_name) | ||
|
||
# Return the list of test function names. | ||
return test_function_names | ||
|
||
|
||
def construct_prompt(target_code, test_code, lines): | ||
# target_code = target_code.strip() | ||
target_code = indent(target_code, " " * 4 * 3) | ||
|
||
if not test_code: | ||
prompt = dedent( | ||
f""" | ||
The Python module code is as follows: | ||
{target_code} | ||
Print only a completed version of the following Python function named test_some_function that achieves full coverage over the Python module code: | ||
def test_some_function(): | ||
""" | ||
) | ||
else: | ||
test_code = indent(test_code, " " * 4 * 3) | ||
|
||
prompt = dedent( | ||
f""" | ||
The module code is as follows: | ||
{target_code} | ||
Print only a completed version of the following Python function named test_some_function so that lines {", ".join(lines)} in the Python module code are covered: | ||
{test_code} | ||
""" | ||
) | ||
|
||
return prompt | ||
|
||
|
||
def query(test_code, lines): | ||
"""Sample a query completion.""" | ||
prompt = construct_prompt(question_code, test_code, lines) | ||
answer = completer(prompt) | ||
return prompt, answer | ||
|
||
|
||
def get_missing_lines_for_completion(answer): | ||
"""Get the missing lines for the given completion.""" | ||
|
||
def create_call(name): | ||
return dedent( | ||
f""" | ||
try: | ||
{name}() | ||
except AssertionError: | ||
pass""" | ||
) | ||
|
||
test_function_names = collect_test_functions(answer) | ||
run_statements = "\n".join([create_call(fname) for fname in test_function_names]) | ||
|
||
c1 = f""" | ||
{question_code} | ||
{answer} | ||
{run_statements} | ||
""" | ||
|
||
lines = get_missing_coverage_lines(c1) | ||
lines = [ln for ln in lines if ln < question_code.count("\n")] | ||
return lines | ||
|
||
|
||
lines: list = [] | ||
test_code = "" | ||
dialog = [] | ||
|
||
|
||
question_examples = [ | ||
r""" | ||
def some_function(x, y): | ||
if x < 0: | ||
z = y - x | ||
else: | ||
z = y + x | ||
if z > y: | ||
return True | ||
return False | ||
""", | ||
r""" | ||
def some_function(x: int, y: int) -> bool: | ||
for i in range(x): | ||
if x < 3: | ||
z = y - x | ||
else: | ||
z = y + x | ||
if z > y: | ||
return True | ||
return False | ||
""", | ||
] | ||
|
||
|
||
for question_code in question_examples: | ||
question_dialog = [] | ||
for i in range(3): | ||
time.sleep(1) | ||
# TODO: Check the format of the response | ||
prompt, answer_code = query(question_code, lines) | ||
|
||
print(f"QUESTION:\n{question_code}") | ||
print(f"ANSWER:\n{answer_code}") | ||
|
||
question_dialog.append((question_code, prompt, answer_code)) | ||
|
||
lines.extend(get_missing_lines_for_completion(answer_code)) | ||
|
||
if not lines: | ||
break | ||
|
||
dialog.append([question_code, question_dialog]) | ||
|
||
i = 0 | ||
print(dialog[i][0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,5 +63,7 @@ module = [ | |
"scipy.*", | ||
"torch", | ||
"transformers", | ||
"coverage", | ||
"coverage.misc", | ||
] | ||
ignore_missing_imports = true |