diff --git a/build.gradle.kts b/build.gradle.kts index 188a4dd9..bcc04254 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -28,6 +28,10 @@ dependencies { implementation("com.vladsch.flexmark:flexmark-all:0.64.8") implementation("org.apache.commons:commons-text:1.10.0") implementation("com.knuddels:jtokkit:1.0.0") + implementation("io.github.bonede:tree-sitter:0.22.6") + implementation("io.github.bonede:tree-sitter-java:0.21.0a") + implementation("io.github.bonede:tree-sitter-python:0.21.0a") + implementation("io.github.bonede:tree-sitter-go:0.21.0a") compileOnly("com.puppycrawl.tools:checkstyle:10.9.1") testImplementation("org.mockito:mockito-core:5.7.0") } diff --git a/src/main/java/com/zhongan/devpilot/completions/inline/CompletionPreview.java b/src/main/java/com/zhongan/devpilot/completions/inline/CompletionPreview.java index c443c5ae..06b741b1 100644 --- a/src/main/java/com/zhongan/devpilot/completions/inline/CompletionPreview.java +++ b/src/main/java/com/zhongan/devpilot/completions/inline/CompletionPreview.java @@ -17,6 +17,7 @@ import com.zhongan.devpilot.completions.inline.listeners.InlineCaretListener; import com.zhongan.devpilot.completions.inline.render.DevPilotInlay; import com.zhongan.devpilot.completions.prediction.DevPilotCompletion; +import com.zhongan.devpilot.treesitter.TreeSitterParser; import com.zhongan.devpilot.util.TelemetryUtils; import java.util.List; @@ -162,9 +163,13 @@ private void applyPreviewInternal(@NotNull Integer cursorOffset, Project project editor.getDocument().deleteString(cursorOffset, cursorOffset + completion.oldSuffix.length()); } - //TODO 代码自动格式化 + var name = file.getName(); + var fileExtension = name.substring(name.lastIndexOf(".") + 1); + suffix = TreeSitterParser.getInstance(fileExtension) + .parse(editor.getDocument().getText(), cursorOffset, suffix); + editor.getDocument().insertString(cursorOffset, suffix); - editor.getCaretModel().moveToOffset(startOffset + completion.newPrefix.length()); + editor.getCaretModel().moveToOffset(startOffset + suffix.length()); PsiDocumentManager.getInstance(project).commitAllDocuments(); diff --git a/src/main/java/com/zhongan/devpilot/treesitter/TreeSitterParser.java b/src/main/java/com/zhongan/devpilot/treesitter/TreeSitterParser.java new file mode 100644 index 00000000..a5731aa5 --- /dev/null +++ b/src/main/java/com/zhongan/devpilot/treesitter/TreeSitterParser.java @@ -0,0 +1,111 @@ +package com.zhongan.devpilot.treesitter; + +import com.zhongan.devpilot.util.LanguageUtil; + +import java.util.Locale; + +import org.treesitter.TSLanguage; +import org.treesitter.TSParser; +import org.treesitter.TSTree; +import org.treesitter.TreeSitterGo; +import org.treesitter.TreeSitterJava; +import org.treesitter.TreeSitterPython; + +public class TreeSitterParser { + private final TSLanguage language; + + public TreeSitterParser(TSLanguage language) { + this.language = language; + } + + public String clearRedundantWhitespace(String originCode, int position, String output) { + if (language == null) { + return output; + } + + var result = new StringBuilder(output); + while (result.length() != 0 && result.charAt(0) == ' ') { + result.deleteCharAt(0); + if (containsError(buildFullCode(originCode, position, result.toString()))) { + return " " + result; + } + } + + return result.toString(); + } + + public String parse(String originCode, int position, String output) { + if (!output.startsWith(" ")) { + return parseInner(originCode, position, output); + } + + // handle special case : start with several whitespace + var noWhitespaceResult = parseInner(originCode, position, output.trim()); + var whitespaceResult = parseInner(originCode, position, " " + output.trim()); + + var result = whitespaceResult.length() < noWhitespaceResult.length() + ? noWhitespaceResult : whitespaceResult; + + return clearRedundantWhitespace(originCode, position, result); + } + + private String parseInner(String originCode, int position, String output) { + if (language == null) { + return output; + } + + var result = new StringBuilder(output); + while (result.length() != 0) { + if (containsError(buildFullCode(originCode, position, result.toString()))) { + result.deleteCharAt(result.length() - 1); + } else { + return result.toString(); + } + } + + return output; + } + + private String buildFullCode(String originCode, int position, String output) { + StringBuilder stringBuilder = new StringBuilder(originCode); + stringBuilder.insert(position, output); + return stringBuilder.toString(); + } + + private boolean containsError(String input) { + var treeString = getTree(input).getRootNode().toString(); + return treeString.contains("ERROR") + || treeString.contains("MISSING \"}\"") + || treeString.contains("MISSING \")\""); + } + + private TSTree getTree(String input) { + var parser = new TSParser(); + parser.setLanguage(language); + return parser.parseString(null, input); + } + + public static TreeSitterParser getInstance(String extension) { + var language = LanguageUtil.getLanguageByExtension(extension); + + if (language == null) { + return new TreeSitterParser(null); + } + + TSLanguage tsLanguage = null; + + switch (language.getLanguageName().toLowerCase(Locale.ROOT)) { + case "java": + tsLanguage = new TreeSitterJava(); + break; + case "go": + tsLanguage = new TreeSitterGo(); + break; + case "python": + tsLanguage = new TreeSitterPython(); + break; + } + + return new TreeSitterParser(tsLanguage); + } +} diff --git a/src/test/java/com/zhongan/devpilot/treesitter/TreeSitterParserTest.java b/src/test/java/com/zhongan/devpilot/treesitter/TreeSitterParserTest.java new file mode 100644 index 00000000..17910565 --- /dev/null +++ b/src/test/java/com/zhongan/devpilot/treesitter/TreeSitterParserTest.java @@ -0,0 +1,97 @@ +package com.zhongan.devpilot.treesitter; + +import org.junit.Assert; +import org.junit.Test; + +public class TreeSitterParserTest { + @Test + public void testJavaParse() { + var wholeFile = "package org.example;\n" + + "\n" + + "public class AgentTest {\n" + + " public static void main(String[] args) {\n" + + " System.out.println(\"\");\n" + + " }\n" + + "}\n" + + "\n" + + "\n"; + var position = 120; + var insertCode = "AgentTest completed.\");"; + + var parser = TreeSitterParser.getInstance("java"); + var result = parser.parse(wholeFile, position, insertCode); + + Assert.assertEquals(result, "AgentTest completed."); + + wholeFile = "package org.example;\n" + + "\n" + + "public class AgentTest {\n" + + " public static void main(String[] args) {\n" + + " System.out.println(\"\");\n" + + " }\n" + + "}\n" + + "\n" + + "\n"; + insertCode = " AgentTest completed.\");"; + + parser = TreeSitterParser.getInstance("java"); + result = parser.parse(wholeFile, position, insertCode); + + Assert.assertEquals(result, "AgentTest completed."); + + wholeFile = "package org.example;\n" + + "\n" + + "public class AgentTest {\n" + + " public static void main(String[] args) {\n" + + " String" + + " System.out.println(\"AgentTest completed\");\n" + + " }\n" + + "}\n" + + "\n" + + "\n"; + position = 106; + insertCode = " str = \"AgentTest\";"; + + parser = TreeSitterParser.getInstance("java"); + result = parser.parse(wholeFile, position, insertCode); + + Assert.assertEquals(result, "str = \"AgentTest\";"); + + wholeFile = "package org.example;\n" + + "\n" + + "public class AgentTest {\n" + + " public static void main(String[] args) {\n" + + " Sys" + + " System.out.println(\"AgentTest completed2\");\n" + + " }\n" + + "}\n" + + "\n" + + "\n"; + position = 103; + insertCode = " tem.out.println(\"AgentTest completed1\");"; + + parser = TreeSitterParser.getInstance("java"); + result = parser.parse(wholeFile, position, insertCode); + + Assert.assertEquals(result, "tem.out.println(\"AgentTest completed1\");"); + } + + @Test + public void testPythonParse() { + var wholeFile = "import pandas as pd\n" + + "\n" + + "if __name__ == '__main__':\n" + + " df = pd.read_excel(r'test.xlsx')\n" + + " for idx, row in tqdm(df.iterrows(), t ):\n" + + "\n" + + "\n" + + " print('done')"; + var position = 126; + var insertCode = "otal=df.shape[0]):"; + + var parser = TreeSitterParser.getInstance("py"); + var result = parser.parse(wholeFile, position, insertCode); + + Assert.assertEquals(result, "otal=df.shape[0]"); + } +}