python: add llmcommand.py
levy authored and avarga committed Jul 3, 2024
1 parent 061f8a5 commit 3af5f7e
Showing 1 changed file with 290 additions and 0 deletions.
290 changes: 290 additions & 0 deletions python/inet/llmcommand.py
@@ -0,0 +1,290 @@
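"""Apply LLM-based editing tasks (proofreading, language improvement, neddoc
generation, etc.) to documentation and source files, splitting large files
into chunks that fit the model's context window."""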
import argparse
import logging
import os
import re
import subprocess

import llm
import tiktoken
import unidiff

# Initialize logger
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

def collect_matching_file_paths(directory, name_regex, content_pattern=None):
    matching_file_paths = []
    content_regex = re.compile(content_pattern) if content_pattern else None

    for root, _, files in os.walk(directory):
        for file in files:
            if name_regex.match(file):
                file_path = os.path.join(root, file)
                if not content_regex:
                    matching_file_paths.append(file_path)
                else:
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        try:
                            content = f.read()
                            if content_regex.search(content):
                                matching_file_paths.append(file_path)
                        except Exception as e:
                            print(f"Could not read file {file_path}: {e}")

    return matching_file_paths
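# Example usage (illustrative; the directory and pattern are made up):
#   collect_matching_file_paths("src", re.compile(r".*\.ned$"), content_pattern="TcpApp")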

def read_files(file_list):
    contents = []
    for file_path in file_list:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
            contents.append(f"File `{file_path}`:\n```\n{content}\n```\n\n")
    return "".join(contents)

def get_llm_context_window(model):
    size = model.default_max_tokens
    if size:  # not always populated
        return size
    model_name = model.model_id
    context_window_sizes = {
        "gpt-3.5-turbo": 4096,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-4": 8192,
        "gpt-4-32k": 32768,
        "gpt-4o": 128*1024,
    }
    if model_name not in context_window_sizes:
        raise Exception(f"Context window size for llm '{model_name}' is not known, please add it to the table")
    return context_window_sizes[model_name]

def check_token_count(prompt, model):
    encoder = tiktoken.encoding_for_model(model.model_id)
    num_tokens = len(encoder.encode(prompt))
    max_tokens = get_llm_context_window(model)
    _logger.info(f"Number of tokens in the prompt: {num_tokens}")
    if max_tokens and num_tokens > max_tokens:
        print(f"WARNING: Prompt of {num_tokens} tokens exceeds the model's context window size of {max_tokens} tokens")

def generate_command_text(task, file_type):
    file_type_commands = {
        "md": "The following Markdown file belongs to the OMNeT++ project.",
        "rst": "The following reStructuredText file belongs to the OMNeT++ project. Blocks starting with '..' are comments, do not touch them!",
        "tex": "The following LaTeX file belongs to the OMNeT++ project.",
        "ned": "The following is an OMNeT++ NED file. DO NOT CHANGE ANYTHING IN NON-COMMENT LINES.",
        "py": "The following Python source file belongs to the OMNeT++ project.",
    }
    what = "docstrings" if file_type == "py" else "text"
    task_commands = {
        "proofread": f"You are an editor of technical documentation. You are tireless and diligent. Fix any English mistakes in the {what}. Keep all markup, line breaks and indentation intact as much as possible!",
        "improve-language": f"You are an editor of technical documentation. You are tireless and diligent. Improve the English in the {what}. Keep all other markup and line breaks intact as much as possible.",
        "eliminate-you-addressing": f"You are an editor of technical documentation. You are tireless and diligent. At places where the {what} addresses the user as 'you', change it to neutral, e.g., to passive voice or 'one' as subject. Keep all markup and line breaks intact as much as possible.",
        "neddoc": "You are a technical writer. You are tireless and diligent. Write a new neddoc comment for the module in the NED file. NED comments use // marks instead of /*..*/.",
        # "neddoc": """
        # You are a technical writer. You are tireless and diligent. You are working in the context of OMNeT++ and the INET Framework.
        # Your task is to write a single sentence capturing the most important aspect of the following simple module.
        # Ignore the operational details of the module and focus on the aspects that help the user understand what this module is good for.
        # """
    }

    if file_type not in file_type_commands:
        raise ValueError(f'Unsupported file type "{file_type}"')
    if task not in task_commands:
        raise ValueError(f'Unsupported task "{task}"')
    if "ned" in task and file_type != "ned":
        raise ValueError(f'Task "{task}" is only supported for the "ned" file type')

    return file_type_commands[file_type] + " " + task_commands[task]

def get_recommended_model(task):
    recommended_models = {
        None: "gpt-3.5-turbo-16k",
        "proofread": "gpt-3.5-turbo-16k",
        "improve-language": "gpt-3.5-turbo-16k",
        "eliminate-you-addressing": "gpt-4o",
        "neddoc": "gpt-4o"
    }
    if task not in recommended_models:
        raise ValueError(f'No info on which model is recommended for "{task}", specify one explicitly via --model or update the tool')
    return recommended_models[task]

def find_additional_context_files(file_path, file_type, task):
    context_files = []
    if task == "neddoc":
        # fname_without_ext = os.path.splitext(os.path.basename(file_path))[0]
        # h_fname = fname_without_ext + ".h"
        # cc_fname = fname_without_ext + ".cc"
        # for root, _, files in os.walk(os.path.dirname(file_path) or "."):
        #     if h_fname in files:
        #         context_files.append(os.path.join(root, h_fname))
        #     if cc_fname in files:
        #         context_files.append(os.path.join(root, cc_fname))

        # Collect related code snippets with ripgrep (assumes 'rg' and 'git' are on the PATH)
        fname_without_ext = os.path.splitext(os.path.basename(file_path))[0]
        grep_command = f"""rg --heading -g '*.cc' -g '*.h' -g '*.ini' -g '*.ned' -g '!{os.path.basename(file_path)}' -C 20 '{fname_without_ext}' $(git rev-parse --show-toplevel) | head -n 3000 > {file_path}.ctx"""
        os.system(grep_command)
        context_files.append(file_path + ".ctx")

    return context_files

def create_prompt(content, context, task, prompt, file_type):
    command_text = prompt if prompt else generate_command_text(task, file_type)
    prompt = f"Update a file delimited by triple quotes. {command_text}\n\n"
    if context:
        prompt += f"Here is the context:\n{context}\n\n"
    prompt += f'Here is the file that should be updated:\n\n```\n{content}\n```\n\n'
    prompt += "Respond with the updated file verbatim without any additional commentary.\n"
    return prompt
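# The assembled prompt has roughly this shape (illustrative):
#
#   Update a file delimited by triple quotes. <task instructions>
#
#   Here is the context:
#   <context file contents>
#
#   Here is the file that should be updated:
#   ```
#   <file content>
#   ```
#
#   Respond with the updated file verbatim without any additional commentary.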

def invoke_llm(prompt, model):
    check_token_count(prompt, model)

    _logger.debug(f"Sending prompt to LLM: {prompt}")
    reply = model.prompt(prompt)
    reply_text = reply.text()
    _logger.debug(f"Received result from LLM: {reply_text}")
    return reply_text

def split_content(file_content, file_type, max_chars):
    if file_type == "tex":
        return split_latex_by_sections(file_content, max_chars)
    elif file_type == "rst":
        return split_rst_by_headings(file_content, max_chars)
    elif file_type == "py":
        return split_python_by_defs(file_content, max_chars)
    else:
        return [file_content]

def split_latex_by_sections(latex_source, max_chars):
    # Regex to match section headers
    pattern = r'(\\section\{.*?\})'
    return split_by_regex(latex_source, pattern, max_chars)

def split_rst_by_headings(rst_source, max_chars):
    # Regular expression to match RST headings (title and underline)
    pattern = r'(^.*\n(={3,}|-{3,}|`{3,}|:{3,}|\+{3,}|\*{3,}|\#{3,}|\^{3,}|"{3,}|~{3,})$)'
    return split_by_regex(rst_source, pattern, max_chars)

def split_python_by_defs(python_source, max_chars):
    # Regular expression to match Python top-level functions and classes
    pattern = r'(^(def|class) )'
    return split_by_regex(python_source, pattern, max_chars)

def split_by_regex(text, regex_pattern, max_chars):
    # Split into chunks at each match, keeping any text before the first match
    regex = re.compile(regex_pattern, re.MULTILINE)
    matches = list(regex.finditer(text))
    split_indices = [0] + [match.start() for match in matches] + [len(text)]
    chunks = [text[split_indices[i]:split_indices[i+1]] for i in range(len(split_indices)-1)]
    chunks = [chunk for chunk in chunks if chunk]

    # Merge smaller chunks
    parts = []
    current_part = ""
    for chunk in chunks:
        if current_part and len(current_part) + len(chunk) > max_chars:
            parts.append(current_part)
            current_part = ""
        current_part += chunk
    if current_part:
        parts.append(current_part)

    return parts
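# For example (illustrative), with max_chars=10 and chunks ["aaaa", "bbbbbbb", "cc"],
# the parts become ["aaaa", "bbbbbbbcc"]: chunks are accumulated until adding the
# next one would exceed max_chars; a single oversized chunk is kept whole.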

def extract(reply_text, original_content):
    content = reply_text
    if content.count("```") >= 2:
        content = re.sub(r"^.*```.*?\n(.*\n)```.*$", r"\1", content, count=1, flags=re.DOTALL)

    # Restore the original file's trailing whitespace; note that slicing with a
    # computed negative length would return the whole string when the length is
    # zero ([-0:]), so take the suffix starting at the rstripped length instead.
    original_trailing_whitespace = original_content[len(original_content.rstrip()):]
    content = content.rstrip() + original_trailing_whitespace
    return content

def apply_command_to_file(file_path, context_files, file_type, task, prompt, model, chunk_size=None, save_prompt=False):
    with open(file_path, 'r', encoding='utf-8') as file:
        file_content = file.read()

    context_files = context_files or []
    context_files += find_additional_context_files(file_path, file_type, task)
    if context_files:
        print(" context files: " + " ".join(context_files))
    context = read_files(context_files)

    if chunk_size is None:
        chunk_size = get_llm_context_window(model)*2 - len(context)  # rough budget: assume ~2 characters per token
    modified_content = ""
    parts = split_content(file_content, file_type, chunk_size)
    for i, part in enumerate(parts):
        if len(parts) > 1:
            print(f" part {i + 1}/{len(parts)}")
        # use a separate variable so the user-supplied prompt is not clobbered between parts
        part_prompt = create_prompt(part, context, task, prompt, file_type)

        if save_prompt:
            with open(file_path+".prompt"+str(i), 'w', encoding='utf-8') as file:
                file.write(part_prompt)

        reply_text = invoke_llm(part_prompt, model)
        modified_content += extract(reply_text, part)

    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(modified_content)
    _logger.debug(f"Modified {file_path} successfully.")

def apply_command_to_files(file_list, context_files, file_type, task, prompt, model, chunk_size=None, save_prompt=False):
    n = len(file_list)
    for i, file_path in enumerate(file_list):
        try:
            print(f"Processing file {i + 1}/{n} {file_path}")
            apply_command_to_file(file_path, context_files, file_type, task, prompt, model, chunk_size=chunk_size, save_prompt=save_prompt)
        except Exception as e:
            print(f"-> Exception: {e}")

def resolve_file_list(paths, file_type, file_ext=None):
    file_extension_patterns = {
        "md": r".*\.md$",
        "rst": r".*\.rst$",
        "tex": r".*\.tex$",
        "ned": r".*\.ned$",
        "py": r".*\.py$"
    }

    if file_ext:
        filename_regex = re.compile(rf".*\.{file_ext}$")
    else:
        if file_type not in file_extension_patterns:
            raise ValueError("Unsupported file type.")
        filename_regex = re.compile(file_extension_patterns[file_type])

    file_list = []
    for path in paths:
        if os.path.isdir(path):
            file_list.extend(collect_matching_file_paths(path, filename_regex))
        elif os.path.isfile(path) and filename_regex.match(path):
            file_list.append(path)
    return sorted(file_list)

def process_files(paths, context_files, file_type, file_ext, task, prompt, model_name, chunk_size=None, save_prompt=False):
    if not model_name:
        model_name = get_recommended_model(task)
    model = llm.get_model(model_name)
    model.key = ''
    file_list = resolve_file_list(paths, file_type, file_ext)
    print("Files to process: " + " ".join(file_list))
    print("Using LLM: " + model_name)
    apply_command_to_files(file_list, context_files, file_type, task, prompt, model,
                           chunk_size=chunk_size, save_prompt=save_prompt)

def main():
    parser = argparse.ArgumentParser(description="Process and improve specific types of files in the given directories or files.")
    parser.add_argument("paths", type=str, nargs='+', help="The directories or files to process.")
    parser.add_argument("--file-type", type=str, choices=["md", "rst", "tex", "ned", "py"], help="The type of files to process.")
    parser.add_argument("--file-ext", type=str, help="The extension of files to process. Takes precedence over --file-type.")
    parser.add_argument("--task", type=str, choices=["proofread", "improve-language", "eliminate-you-addressing", "neddoc"], help="The task to perform on the files.")
    parser.add_argument("--prompt", type=str, help="The LLM prompt to use. Generic instructions and the content of context files will be appended. Takes precedence over --task.")
    parser.add_argument("--model", type=str, dest="model_name", default=None, help="The name of the LLM model to use.")
    parser.add_argument("--context", type=str, nargs='*', dest="context_files", help="The context files to be used.")
    parser.add_argument("--chunk-size", type=int, default=None, help="The maximum number of characters to be sent to the LLM model at once.")
    parser.add_argument("--save-prompt", action='store_true', help="Save the LLM prompt for each input file as <filename>.prompt.")

    args = parser.parse_args()
    process_files(**vars(args))

if __name__ == "__main__":
    main()
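# Example invocations (illustrative; the paths are made up):
#   python3 llmcommand.py doc/src --file-type rst --task proofread
#   python3 llmcommand.py src --file-type ned --task neddoc --model gpt-4o --save-prompt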
