Skip to content

Commit

Permalink
Add a coverage bot example
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Apr 29, 2023
1 parent 58c84da commit 36f84a9
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
186 changes: 186 additions & 0 deletions examples/cover_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import time
from textwrap import dedent, indent

import coverage
from coverage.misc import import_local_file

from outlines.models.openai import OpenAIChatCompletion

completer = OpenAIChatCompletion("gpt-3.5-turbo")


def get_missing_coverage_lines(code: str):
cov = coverage.Coverage(branch=True)
cov.erase()

modname = "test_mod"
with open(modname + ".py", "wb") as f:
f.write(code.encode("utf-8"))

cov.start()
try:
mod = import_local_file(modname)
finally:
cov.stop()

analysis = cov._analyze(mod)

return sorted(analysis.missing)


def collect_test_functions(string):
"""Collects the names of the test functions in a given string.
Args:
string: The string to collect the test functions from.
Returns:
A list of the names of the test functions.
(This was generated by Bard!)
"""

# Split the string into lines.
lines = string.splitlines()

# Create a list to store the test function names.
test_function_names = []

# Iterate over the lines, looking for lines that start with the `def` keyword.
for line in lines:
if line.startswith("def"):
# Get the name of the test function from the line.
test_function_name = line.split("def")[1].split("(")[0].strip()

# Add the test function name to the list.
test_function_names.append(test_function_name)

# Return the list of test function names.
return test_function_names


def construct_prompt(target_code, test_code, lines):
# target_code = target_code.strip()
target_code = indent(target_code, " " * 4 * 3)

if not test_code:
prompt = dedent(
f"""
The Python module code is as follows:
{target_code}
Print only a completed version of the following Python function named test_some_function that achieves full coverage over the Python module code:
def test_some_function():
"""
)
else:
test_code = indent(test_code, " " * 4 * 3)

prompt = dedent(
f"""
The module code is as follows:
{target_code}
Print only a completed version of the following Python function named test_some_function so that lines {", ".join(lines)} in the Python module code are covered:
{test_code}
"""
)

return prompt


def query(test_code, lines):
"""Sample a query completion."""
prompt = construct_prompt(question_code, test_code, lines)
answer = completer(prompt)
return prompt, answer


def get_missing_lines_for_completion(answer):
"""Get the missing lines for the given completion."""

def create_call(name):
return dedent(
f"""
try:
{name}()
except AssertionError:
pass"""
)

test_function_names = collect_test_functions(answer)
run_statements = "\n".join([create_call(fname) for fname in test_function_names])

c1 = f"""
{question_code}
{answer}
{run_statements}
"""

lines = get_missing_coverage_lines(c1)
lines = [ln for ln in lines if ln < question_code.count("\n")]
return lines


lines: list = []
test_code = ""
dialog = []


question_examples = [
r"""
def some_function(x, y):
if x < 0:
z = y - x
else:
z = y + x
if z > y:
return True
return False
""",
r"""
def some_function(x: int, y: int) -> bool:
for i in range(x):
if x < 3:
z = y - x
else:
z = y + x
if z > y:
return True
return False
""",
]


for question_code in question_examples:
question_dialog = []
for i in range(3):
time.sleep(1)
# TODO: Check the format of the response
prompt, answer_code = query(question_code, lines)

print(f"QUESTION:\n{question_code}")
print(f"ANSWER:\n{answer_code}")

question_dialog.append((question_code, prompt, answer_code))

lines.extend(get_missing_lines_for_completion(answer_code))

if not lines:
break

dialog.append([question_code, question_dialog])

i = 0
print(dialog[i][0])
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,7 @@ module = [
"scipy.*",
"torch",
"transformers",
"coverage",
"coverage.misc",
]
ignore_missing_imports = true

0 comments on commit 36f84a9

Please sign in to comment.