Skip to content

Commit

Permalink
improve completion result, handle some special case of specific AI mo…
Browse files Browse the repository at this point in the history
…del (#44)
  • Loading branch information
xiangtianyu authored Jul 16, 2024
1 parent 69b2078 commit 234f0bd
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 2 deletions.
4 changes: 4 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ dependencies {
implementation("com.vladsch.flexmark:flexmark-all:0.64.8")
implementation("org.apache.commons:commons-text:1.10.0")
implementation("com.knuddels:jtokkit:1.0.0")
implementation("io.github.bonede:tree-sitter:0.22.6")
implementation("io.github.bonede:tree-sitter-java:0.21.0a")
implementation("io.github.bonede:tree-sitter-python:0.21.0a")
implementation("io.github.bonede:tree-sitter-go:0.21.0a")
compileOnly("com.puppycrawl.tools:checkstyle:10.9.1")
testImplementation("org.mockito:mockito-core:5.7.0")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.zhongan.devpilot.completions.inline.listeners.InlineCaretListener;
import com.zhongan.devpilot.completions.inline.render.DevPilotInlay;
import com.zhongan.devpilot.completions.prediction.DevPilotCompletion;
import com.zhongan.devpilot.treesitter.TreeSitterParser;
import com.zhongan.devpilot.util.TelemetryUtils;

import java.util.List;
Expand Down Expand Up @@ -162,9 +163,13 @@ private void applyPreviewInternal(@NotNull Integer cursorOffset, Project project
editor.getDocument().deleteString(cursorOffset, cursorOffset + completion.oldSuffix.length());
}

//TODO 代码自动格式化
var name = file.getName();
var fileExtension = name.substring(name.lastIndexOf(".") + 1);
suffix = TreeSitterParser.getInstance(fileExtension)
.parse(editor.getDocument().getText(), cursorOffset, suffix);

editor.getDocument().insertString(cursorOffset, suffix);
editor.getCaretModel().moveToOffset(startOffset + completion.newPrefix.length());
editor.getCaretModel().moveToOffset(startOffset + suffix.length());

PsiDocumentManager.getInstance(project).commitAllDocuments();

Expand Down
111 changes: 111 additions & 0 deletions src/main/java/com/zhongan/devpilot/treesitter/TreeSitterParser.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package com.zhongan.devpilot.treesitter;

import com.zhongan.devpilot.util.LanguageUtil;

import java.util.Locale;

import org.treesitter.TSLanguage;
import org.treesitter.TSParser;
import org.treesitter.TSTree;
import org.treesitter.TreeSitterGo;
import org.treesitter.TreeSitterJava;
import org.treesitter.TreeSitterPython;

public class TreeSitterParser {
private final TSLanguage language;

public TreeSitterParser(TSLanguage language) {
this.language = language;
}

public String clearRedundantWhitespace(String originCode, int position, String output) {
if (language == null) {
return output;
}

var result = new StringBuilder(output);
while (result.length() != 0 && result.charAt(0) == ' ') {
result.deleteCharAt(0);
if (containsError(buildFullCode(originCode, position, result.toString()))) {
return " " + result;
}
}

return result.toString();
}

public String parse(String originCode, int position, String output) {
if (!output.startsWith(" ")) {
return parseInner(originCode, position, output);
}

// handle special case : start with several whitespace
var noWhitespaceResult = parseInner(originCode, position, output.trim());
var whitespaceResult = parseInner(originCode, position, " " + output.trim());

var result = whitespaceResult.length() < noWhitespaceResult.length()
? noWhitespaceResult : whitespaceResult;

return clearRedundantWhitespace(originCode, position, result);
}

private String parseInner(String originCode, int position, String output) {
if (language == null) {
return output;
}

var result = new StringBuilder(output);
while (result.length() != 0) {
if (containsError(buildFullCode(originCode, position, result.toString()))) {
result.deleteCharAt(result.length() - 1);
} else {
return result.toString();
}
}

return output;
}

private String buildFullCode(String originCode, int position, String output) {
StringBuilder stringBuilder = new StringBuilder(originCode);
stringBuilder.insert(position, output);
return stringBuilder.toString();
}

private boolean containsError(String input) {
var treeString = getTree(input).getRootNode().toString();
return treeString.contains("ERROR")
|| treeString.contains("MISSING \"}\"")
|| treeString.contains("MISSING \")\"");
}

private TSTree getTree(String input) {
var parser = new TSParser();
parser.setLanguage(language);
return parser.parseString(null, input);
}

public static TreeSitterParser getInstance(String extension) {
var language = LanguageUtil.getLanguageByExtension(extension);

if (language == null) {
return new TreeSitterParser(null);
}

TSLanguage tsLanguage = null;

switch (language.getLanguageName().toLowerCase(Locale.ROOT)) {
case "java":
tsLanguage = new TreeSitterJava();
break;
case "go":
tsLanguage = new TreeSitterGo();
break;
case "python":
tsLanguage = new TreeSitterPython();
break;
}

return new TreeSitterParser(tsLanguage);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package com.zhongan.devpilot.treesitter;

import org.junit.Assert;
import org.junit.Test;

public class TreeSitterParserTest {
@Test
public void testJavaParse() {
var wholeFile = "package org.example;\n" +
"\n" +
"public class AgentTest {\n" +
" public static void main(String[] args) {\n" +
" System.out.println(\"\");\n" +
" }\n" +
"}\n" +
"\n" +
"\n";
var position = 120;
var insertCode = "AgentTest completed.\");";

var parser = TreeSitterParser.getInstance("java");
var result = parser.parse(wholeFile, position, insertCode);

Assert.assertEquals(result, "AgentTest completed.");

wholeFile = "package org.example;\n" +
"\n" +
"public class AgentTest {\n" +
" public static void main(String[] args) {\n" +
" System.out.println(\"\");\n" +
" }\n" +
"}\n" +
"\n" +
"\n";
insertCode = " AgentTest completed.\");";

parser = TreeSitterParser.getInstance("java");
result = parser.parse(wholeFile, position, insertCode);

Assert.assertEquals(result, "AgentTest completed.");

wholeFile = "package org.example;\n" +
"\n" +
"public class AgentTest {\n" +
" public static void main(String[] args) {\n" +
" String" +
" System.out.println(\"AgentTest completed\");\n" +
" }\n" +
"}\n" +
"\n" +
"\n";
position = 106;
insertCode = " str = \"AgentTest\";";

parser = TreeSitterParser.getInstance("java");
result = parser.parse(wholeFile, position, insertCode);

Assert.assertEquals(result, "str = \"AgentTest\";");

wholeFile = "package org.example;\n" +
"\n" +
"public class AgentTest {\n" +
" public static void main(String[] args) {\n" +
" Sys" +
" System.out.println(\"AgentTest completed2\");\n" +
" }\n" +
"}\n" +
"\n" +
"\n";
position = 103;
insertCode = " tem.out.println(\"AgentTest completed1\");";

parser = TreeSitterParser.getInstance("java");
result = parser.parse(wholeFile, position, insertCode);

Assert.assertEquals(result, "tem.out.println(\"AgentTest completed1\");");
}

@Test
public void testPythonParse() {
var wholeFile = "import pandas as pd\n" +
"\n" +
"if __name__ == '__main__':\n" +
" df = pd.read_excel(r'test.xlsx')\n" +
" for idx, row in tqdm(df.iterrows(), t ):\n" +
"\n" +
"\n" +
" print('done')";
var position = 126;
var insertCode = "otal=df.shape[0]):";

var parser = TreeSitterParser.getInstance("py");
var result = parser.parse(wholeFile, position, insertCode);

Assert.assertEquals(result, "otal=df.shape[0]");
}
}

0 comments on commit 234f0bd

Please sign in to comment.