Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to OverleafGitPaperRemote #25

Merged
merged 13 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 92 additions & 43 deletions llm4papers/paper_remote/OverleafGitPaperRemote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,46 @@
import datetime
from urllib.parse import quote
from git import Repo # type: ignore

from llm4papers.models import EditTrigger, EditResult, EditType, DocumentID, RevisionID
from typing import Iterable
import re

from llm4papers.models import (
EditTrigger,
EditResult,
EditType,
DocumentID,
RevisionID,
LineRange,
)
from llm4papers.paper_remote.MultiDocumentPaperRemote import MultiDocumentPaperRemote
from llm4papers.logger import logger


diff_line_edit_re = re.compile(
r"@{2,}\s*-(?P<old_line>\d+),(?P<old_count>\d+)\s*\+(?P<new_line>\d+),(?P<new_count>\d+)\s*@{2,}"
wrongu marked this conversation as resolved.
Show resolved Hide resolved
)


def _diff_to_ranges(diff: str) -> Iterable[LineRange]:
"""Given a git diff, return LineRange object(s) indicating which lines in the
original document were changed.
"""
for match in diff_line_edit_re.finditer(diff):
git_line_number = int(match.group("new_line"))
git_line_count = int(match.group("new_count"))
# Git counts from 1 and gives (start, length), inclusive. LineRange counts from
# 0 and gives start:end indices (exclusive).
zero_index_start = git_line_number - 1
yield zero_index_start, zero_index_start + git_line_count


def _ranges_overlap(a: LineRange, b: LineRange) -> bool:
"""Given two LineRanges, return True if they overlap, False otherwise."""
return not (a[1] < b[0] or b[1] < a[0])


def _too_close_to_human_edits(
repo: Repo, filename: str, line_number: int, last_n: int = 2
repo: Repo, filename: str, line_range: LineRange, last_n: int = 2
) -> bool:
"""
Determine if the line `line_number` of the file `filename` was changed in
Expand Down Expand Up @@ -44,19 +76,13 @@ def _too_close_to_human_edits(
# Get the diff for HEAD~n:
total_diff = repo.git.diff(f"HEAD~{last_n}", filename, unified=0)

# Get the current repo state of that line:
current_line = repo.git.show(f"HEAD:{filename}").split("\n")[line_number]

logger.debug("Diff: " + total_diff)
logger.debug("Current line: " + current_line)

# Match the line in the diff:
if current_line in total_diff:
logger.info(
f"Found current line ({current_line[:10]}...) in diff, rejecting edit."
)
return True

for git_line_range in _diff_to_ranges(total_diff):
if _ranges_overlap(git_line_range, line_range):
logger.info(
f"Line range {line_range} overlaps with git-edited {git_line_range}, "
f"rejecting edit."
)
return True
return False


Expand Down Expand Up @@ -84,8 +110,6 @@ class OverleafGitPaperRemote(MultiDocumentPaperRemote):
PaperRemote protocol for use by the AI editor.
"""

current_revision_id: RevisionID

def __init__(self, git_cached_repo: str):
"""
Saves the git repo to a local temporary directory using gitpython.
Expand All @@ -100,6 +124,10 @@ def __init__(self, git_cached_repo: str):
self._cached_repo: Repo | None = None
self.refresh()

@property
def current_revision_id(self):
return self._get_repo().head.commit.hexsha

def _get_repo(self) -> Repo:
if self._cached_repo is None:
# TODO - this makes me anxious about race conditions. every time we refresh,
Expand All @@ -119,7 +147,7 @@ def _doc_id_to_path(self, doc_id: DocumentID) -> pathlib.Path:
# so we can cast to a string on this next line:
return pathlib.Path(git_root) / str(doc_id)

def refresh(self):
def refresh(self, retry: bool = True):
"""
This is a fallback method (that likely needs some love) to ensure that
the repo is up to date with the latest upstream changes.
Expand All @@ -143,7 +171,6 @@ def refresh(self):
f"Latest change at {self._get_repo().head.commit.committed_datetime}"
)
logger.info(f"Repo dirty: {self._get_repo().is_dirty()}")
self.current_revision_id = self._get_repo().head.commit.hexsha
try:
self._get_repo().git.stash("pop")
except Exception as e:
Expand All @@ -161,7 +188,10 @@ def refresh(self):
self._cached_repo = None
# recursively delete the repo
shutil.rmtree(f"/tmp/{self._reposlug}")
self.refresh()
if retry:
self.refresh(retry=False)
else:
raise e

def list_doc_ids(self) -> list[DocumentID]:
"""
Expand Down Expand Up @@ -196,14 +226,15 @@ def is_edit_ok(self, edit: EditTrigger) -> bool:
# want to wait for the user to move on to the next line.
for doc_range in edit.input_ranges + edit.output_ranges:
repo_scoped_file = str(self._doc_id_to_path(doc_range.doc_id))
for i in range(doc_range.selection[0], doc_range.selection[1]):
if _too_close_to_human_edits(self._get_repo(), repo_scoped_file, i):
logger.info(
f"Temporarily skipping edit request in {doc_range.doc_id}"
" at line {i} because it was still in progress"
" in the last commit."
)
return False
if _too_close_to_human_edits(
self._get_repo(), repo_scoped_file, doc_range.selection
):
logger.info(
f"Temporarily skipping edit request in {doc_range.doc_id}"
" at line {i} because it was still in progress"
" in the last commit."
)
return False
return True

def to_dict(self):
Expand All @@ -223,22 +254,15 @@ def perform_edit(self, edit: EditResult) -> bool:
"""
logger.info(f"Performing edit {edit} on remote {self._reposlug}")

if edit.type == EditType.replace:
success = self._perform_replace(edit)
elif edit.type == EditType.comment:
success = self._perform_comment(edit)
else:
raise ValueError(f"Unknown edit type {edit.type}")
with self.rewind(edit.range.revision_id, message="AI edit") as paper:
if edit.type == EditType.replace:
success = paper._perform_replace(edit)
elif edit.type == EditType.comment:
success = paper._perform_comment(edit)
else:
raise ValueError(f"Unknown edit type {edit.type}")

if success:
# TODO - apply edit relative to the edit.range.revision_id commit and then
# rebase onto HEAD for poor-man's operational transforms
self._get_repo().index.add([self._doc_id_to_path(str(edit.range.doc_id))])
self._get_repo().index.commit("AI edit completed.")
# Instead of just pushing, we need to rebase and then push.
# This is because we want to make sure that the AI edits are always
# on top of the stack.
self._get_repo().git.pull()
# TODO: We could do a better job catching WARNs here and then maybe setting
# success = False
self._get_repo().git.push()
Expand Down Expand Up @@ -284,3 +308,28 @@ def _perform_comment(self, edit: EditResult) -> bool:
# TODO - implement this for real
logger.info(f"Performing comment edit {edit} on remote {self._reposlug}")
return True

def rewind(self, commit: str, message: str):
return self.RewindContext(self, commit, message)

# Create an inner class for "with" semantics so that we can do
# `with remote.rewind(commit)` to rewind to a particular commit and play some edits
# onto it, then merge when the 'with' context exits.
class RewindContext:
wrongu marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, remote: "OverleafGitPaperRemote", commit: str, message: str):
self._remote = remote
self._message = message
self._rewind_commit = commit
self._restore_branch = remote._get_repo().active_branch

def __enter__(self):
self._remote._get_repo().git.checkout(self._rewind_commit)
self._remote._get_repo().git.checkout(b="tmp-edit-branch")
return self._remote

def __exit__(self, exc_type, exc_val, exc_tb):
self._remote._get_repo().git.add(all=True)
self._remote._get_repo().index.commit(self._message)
self._remote._get_repo().git.checkout(self._restore_branch.name)
self._remote._get_repo().git.merge("tmp-edit-branch")
self._remote._get_repo().git.branch("-D", "tmp-edit-branch")
Loading