diff --git a/src/main/java/com/zhongan/devpilot/DevPilotStartupActivity.java b/src/main/java/com/zhongan/devpilot/DevPilotStartupActivity.java index a107df4d..49c0cb0b 100644 --- a/src/main/java/com/zhongan/devpilot/DevPilotStartupActivity.java +++ b/src/main/java/com/zhongan/devpilot/DevPilotStartupActivity.java @@ -3,7 +3,6 @@ import com.intellij.openapi.project.Project; import com.intellij.openapi.startup.StartupActivity; import com.zhongan.devpilot.actions.editor.popupmenu.PopupMenuEditorActionGroupUtil; -import com.zhongan.devpilot.listener.DevPilotFileEditorListener; import com.zhongan.devpilot.network.DevPilotAvailabilityChecker; import com.zhongan.devpilot.update.DevPilotUpdate; @@ -13,7 +12,6 @@ public class DevPilotStartupActivity implements StartupActivity { @Override public void runActivity(@NotNull Project project) { PopupMenuEditorActionGroupUtil.refreshActions(project); - DevPilotFileEditorListener.registerListener(); new DevPilotUpdate.DevPilotUpdateTask(project).queue(); new DevPilotAvailabilityChecker(project).checkNetworkAndLogStatus(); diff --git a/src/main/java/com/zhongan/devpilot/actions/changesview/GenerateGitCommitMessageAction.java b/src/main/java/com/zhongan/devpilot/actions/changesview/GenerateGitCommitMessageAction.java index c3b6d6e1..ccf3eefd 100644 --- a/src/main/java/com/zhongan/devpilot/actions/changesview/GenerateGitCommitMessageAction.java +++ b/src/main/java/com/zhongan/devpilot/actions/changesview/GenerateGitCommitMessageAction.java @@ -45,6 +45,8 @@ import git4idea.repo.GitRepository; import git4idea.repo.GitRepositoryManager; +import static com.zhongan.devpilot.constant.DefaultConst.GIT_COMMIT_PROMPT_VERSION; + public class GenerateGitCommitMessageAction extends AnAction { private static final Logger log = Logger.getInstance(GenerateGitCommitMessageAction.class); @@ -97,8 +99,8 @@ public void run(@NotNull ProgressIndicator progressIndicator) { if (editor != null) { ((EditorEx) editor).setCaretVisible(false); DevPilotChatCompletionRequest devPilotChatCompletionRequest = new DevPilotChatCompletionRequest(); - devPilotChatCompletionRequest.setVersion("V240801"); - devPilotChatCompletionRequest.getMessages().add(MessageUtil.createPromptMessage("-1", "GENERATE_COMMIT", Map.of("locale", getLocale(), "diff", diff))); + devPilotChatCompletionRequest.setVersion(GIT_COMMIT_PROMPT_VERSION); + devPilotChatCompletionRequest.getMessages().add(MessageUtil.createPromptMessage(System.currentTimeMillis() + "", "GENERATE_COMMIT", Map.of("locale", getLocale(), "diff", diff))); devPilotChatCompletionRequest.setStream(Boolean.FALSE); var llmProvider = new LlmProviderFactory().getLlmProvider(project); DevPilotChatCompletionResponse result = llmProvider.chatCompletionSync(devPilotChatCompletionRequest); diff --git a/src/main/java/com/zhongan/devpilot/actions/editor/GenerateMethodCommentAction.java b/src/main/java/com/zhongan/devpilot/actions/editor/GenerateMethodCommentAction.java deleted file mode 100644 index d648896e..00000000 --- a/src/main/java/com/zhongan/devpilot/actions/editor/GenerateMethodCommentAction.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.zhongan.devpilot.actions.editor; - -import com.zhongan.devpilot.enums.EditorActionEnum; -import com.zhongan.devpilot.util.DevPilotMessageBundle; - -public class GenerateMethodCommentAction extends SelectedCodeGenerateBaseAction { - - @Override - protected EditorActionEnum getEditorActionEnum() { - return EditorActionEnum.COMMENT_METHOD; - } - - @Override - protected String getShowText() { - return DevPilotMessageBundle.get("devpilot.inlay.shortcut.methodComments"); - } - - @Override - protected void handleValidResult(String result) { - - } -} diff --git a/src/main/java/com/zhongan/devpilot/actions/editor/SelectedCodeGenerateBaseAction.java b/src/main/java/com/zhongan/devpilot/actions/editor/SelectedCodeGenerateBaseAction.java deleted file mode 100644 index 69823a5f..00000000 --- a/src/main/java/com/zhongan/devpilot/actions/editor/SelectedCodeGenerateBaseAction.java +++ /dev/null @@ -1,72 +0,0 @@ -package com.zhongan.devpilot.actions.editor; - -import com.intellij.openapi.actionSystem.AnAction; -import com.intellij.openapi.actionSystem.AnActionEvent; -import com.intellij.openapi.editor.Editor; -import com.intellij.openapi.fileEditor.FileEditorManager; -import com.intellij.openapi.project.Project; -import com.intellij.openapi.wm.ToolWindow; -import com.intellij.openapi.wm.ToolWindowManager; -import com.zhongan.devpilot.actions.notifications.DevPilotNotification; -import com.zhongan.devpilot.enums.EditorActionEnum; -import com.zhongan.devpilot.enums.SessionTypeEnum; -import com.zhongan.devpilot.gui.toolwindows.chat.DevPilotChatToolWindowService; -import com.zhongan.devpilot.gui.toolwindows.components.EditorInfo; -import com.zhongan.devpilot.settings.state.DevPilotLlmSettingsState; -import com.zhongan.devpilot.util.DevPilotMessageBundle; -import com.zhongan.devpilot.webview.model.CodeReferenceModel; -import com.zhongan.devpilot.webview.model.MessageModel; - -import java.util.Map; -import java.util.UUID; -import java.util.function.Consumer; - -import org.jetbrains.annotations.NotNull; - -import static com.zhongan.devpilot.actions.editor.popupmenu.PopupMenuEditorActionGroupUtil.validateResult; -import static com.zhongan.devpilot.enums.EditorActionEnum.COMMENT_METHOD; - -public abstract class SelectedCodeGenerateBaseAction extends AnAction { - - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - Project project = e.getProject(); - if (project == null) { - return; - } - - ToolWindow toolWindow = ToolWindowManager.getInstance(project).getToolWindow("DevPilot"); - if (toolWindow != null) { - toolWindow.show(); - } - - Consumer callback = result -> { - if (validateResult(result)) { - DevPilotNotification.info(DevPilotMessageBundle.get("devpilot.notification.input.tooLong")); - } - handleValidResult(result); - }; - - Editor editor = FileEditorManager.getInstance(project).getSelectedTextEditor(); - String selectedText = editor.getSelectionModel().getSelectedText(); - - EditorInfo editorInfo = new EditorInfo(editor); - var service = project.getService(DevPilotChatToolWindowService.class); - var username = DevPilotLlmSettingsState.getInstance().getFullName(); - service.clearRequestSession(); - - var showText = getShowText(); - var codeReference = CodeReferenceModel.getCodeRefFromEditor(editorInfo, getEditorActionEnum()); - - var codeMessage = MessageModel.buildCodeMessage( - UUID.randomUUID().toString(), System.currentTimeMillis(), showText, username, codeReference); - - service.sendMessage(SessionTypeEnum.MULTI_TURN.getCode(), COMMENT_METHOD.name(), Map.of("selectedCode", selectedText), null, callback, codeMessage); - } - - protected abstract EditorActionEnum getEditorActionEnum(); - - protected abstract String getShowText(); - - protected abstract void handleValidResult(String result); -} diff --git a/src/main/java/com/zhongan/devpilot/actions/editor/inlay/ChatShortcutHintCollector.java b/src/main/java/com/zhongan/devpilot/actions/editor/inlay/ChatShortcutHintCollector.java index f97c408c..6f29482b 100644 --- a/src/main/java/com/zhongan/devpilot/actions/editor/inlay/ChatShortcutHintCollector.java +++ b/src/main/java/com/zhongan/devpilot/actions/editor/inlay/ChatShortcutHintCollector.java @@ -6,21 +6,15 @@ import com.intellij.codeInsight.hints.presentation.PresentationFactory; import com.intellij.codeInsight.hints.presentation.SequencePresentation; import com.intellij.icons.AllIcons; -import com.intellij.openapi.actionSystem.ActionManager; -import com.intellij.openapi.actionSystem.AnAction; -import com.intellij.openapi.actionSystem.AnActionEvent; -import com.intellij.openapi.actionSystem.DataContext; -import com.intellij.openapi.actionSystem.Presentation; -import com.intellij.openapi.actionSystem.impl.SimpleDataContext; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.editor.CaretModel; import com.intellij.openapi.editor.Document; import com.intellij.openapi.editor.Editor; -import com.intellij.openapi.fileEditor.FileDocumentManager; import com.intellij.openapi.fileEditor.FileEditorManager; import com.intellij.openapi.fileTypes.FileType; import com.intellij.openapi.fileTypes.FileTypeManager; import com.intellij.openapi.project.Project; +import com.intellij.openapi.roots.FileIndexFacade; import com.intellij.openapi.ui.popup.JBPopup; import com.intellij.openapi.ui.popup.JBPopupFactory; import com.intellij.openapi.ui.popup.PopupStep; @@ -72,7 +66,7 @@ public boolean collect(@NotNull PsiElement psiElement, @NotNull Editor editor, @ return true; } - boolean isSourceCode = isSourceCode(editor); + boolean isSourceCode = isSourceCode(psiElement, editor); var elementType = PsiUtilCore.getElementType(psiElement).toString(); if ("CLASS".equals(elementType)) { @@ -174,14 +168,15 @@ private List buildInlayPresentationGroupData() { return options; } - public static boolean isSourceCode(Editor editor) { + public static boolean isSourceCode(PsiElement psiElement, Editor editor) { Document document = editor.getDocument(); - VirtualFile file = FileDocumentManager.getInstance().getFile(document); - if (file == null || !file.isWritable()) { + VirtualFile virtualFile = PsiUtilCore.getVirtualFile(psiElement); + if (virtualFile == null || !virtualFile.isWritable()) { return false; } - FileType fileType = FileTypeManager.getInstance().getFileTypeByFileName(file.getName()); - return !fileType.isBinary() && !document.getText().trim().isEmpty(); + FileIndexFacade indexFacade = FileIndexFacade.getInstance(psiElement.getProject()); + FileType fileType = FileTypeManager.getInstance().getFileTypeByFileName(virtualFile.getName()); + return !fileType.isBinary() && !document.getText().trim().isEmpty() && (!indexFacade.isInLibrarySource(virtualFile) && !indexFacade.isInLibraryClasses(virtualFile)); } private InlayPresentation buildClickableInlayPresentation(String displayPrefixText, String displaySuffixText, EditorActionEnum actionEnum, PsiElement psiElement) { @@ -197,13 +192,10 @@ private void handleActionCallback(EditorActionEnum actionEnum, PsiElement psiEle if (EditorActionEnum.COMMENT_METHOD.equals(actionEnum)) { ApplicationManager.getApplication().invokeLater(() -> { moveCareToPreviousLineStart(editor, textRange.getStartOffset()); - AnAction action = ActionManager.getInstance().getAction("com.zhongan.devpilot.actions.editor.generate.method.comments"); - DataContext context = SimpleDataContext.getProjectContext(editor.getProject()); - action.actionPerformed(new AnActionEvent(null, context, "", new Presentation(), ActionManager.getInstance(), 0)); }); - } else { - service.handleActions(actionEnum, psiElement); } + + service.handleActions(actionEnum, psiElement); } private static int getAnchorOffset(@NotNull PsiElement psiElement) { diff --git a/src/main/java/com/zhongan/devpilot/actions/editor/popupmenu/PopupMenuEditorActionGroupUtil.java b/src/main/java/com/zhongan/devpilot/actions/editor/popupmenu/PopupMenuEditorActionGroupUtil.java index 3a07ab08..40efadb4 100644 --- a/src/main/java/com/zhongan/devpilot/actions/editor/popupmenu/PopupMenuEditorActionGroupUtil.java +++ b/src/main/java/com/zhongan/devpilot/actions/editor/popupmenu/PopupMenuEditorActionGroupUtil.java @@ -13,22 +13,17 @@ import com.intellij.psi.PsiElement; import com.zhongan.devpilot.actions.notifications.DevPilotNotification; import com.zhongan.devpilot.constant.DefaultConst; -import com.zhongan.devpilot.constant.PromptConst; import com.zhongan.devpilot.enums.EditorActionEnum; import com.zhongan.devpilot.enums.SessionTypeEnum; -import com.zhongan.devpilot.enums.UtFrameTypeEnum; import com.zhongan.devpilot.gui.toolwindows.chat.DevPilotChatToolWindowService; import com.zhongan.devpilot.gui.toolwindows.components.EditorInfo; -import com.zhongan.devpilot.provider.ut.UtFrameworkProvider; -import com.zhongan.devpilot.provider.ut.UtFrameworkProviderFactory; +import com.zhongan.devpilot.provider.file.FileAnalyzeProviderFactory; import com.zhongan.devpilot.settings.actionconfiguration.EditorActionConfigurationState; import com.zhongan.devpilot.settings.state.DevPilotLlmSettingsState; import com.zhongan.devpilot.settings.state.LanguageSettingsState; import com.zhongan.devpilot.util.DevPilotMessageBundle; import com.zhongan.devpilot.util.DocumentUtil; import com.zhongan.devpilot.util.LanguageUtil; -import com.zhongan.devpilot.util.PsiElementUtils; -import com.zhongan.devpilot.util.PsiFileUtil; import com.zhongan.devpilot.webview.model.CodeReferenceModel; import com.zhongan.devpilot.webview.model.MessageModel; @@ -41,14 +36,9 @@ import javax.swing.Icon; -import static com.zhongan.devpilot.constant.PlaceholderConst.ADDITIONAL_MOCK_PROMPT; import static com.zhongan.devpilot.constant.PlaceholderConst.ANSWER_LANGUAGE; -import static com.zhongan.devpilot.constant.PlaceholderConst.CLASS_FULL_NAME; import static com.zhongan.devpilot.constant.PlaceholderConst.LANGUAGE; -import static com.zhongan.devpilot.constant.PlaceholderConst.MOCK_FRAMEWORK; -import static com.zhongan.devpilot.constant.PlaceholderConst.RELATED_CLASS; import static com.zhongan.devpilot.constant.PlaceholderConst.SELECTED_CODE; -import static com.zhongan.devpilot.constant.PlaceholderConst.TEST_FRAMEWORK; public class PopupMenuEditorActionGroupUtil { @@ -106,30 +96,8 @@ protected void actionPerformed(Project project, Editor editor, String selectedTe EditorInfo editorInfo = new EditorInfo(editor); if (editorActionEnum == EditorActionEnum.GENERATE_TESTS) { - if (language != null && language.isJvmPlatform() - && PsiFileUtil.isCaretInWebClass(project, editor)) { - data.put(ADDITIONAL_MOCK_PROMPT, PromptConst.MOCK_WEB_MVC); - } - UtFrameworkProvider utFrameworkProvider = UtFrameworkProviderFactory.create(language); - if (utFrameworkProvider != null) { - UtFrameTypeEnum utFramework = utFrameworkProvider.getUTFramework(project, editor); - data.put(TEST_FRAMEWORK, utFramework.getUtFrameType()); - data.put(MOCK_FRAMEWORK, utFramework.getMockFrameType()); - } - if (language != null && "java".equalsIgnoreCase(language.getLanguageName())) { - if (psiElement != null) { - var relatedClass = PsiElementUtils.getRelatedClass(psiElement); - var fullClassName = PsiElementUtils.getFullClassName(psiElement); - - if (relatedClass != null) { - data.put(RELATED_CLASS, relatedClass); - } - - if (fullClassName != null) { - data.put(CLASS_FULL_NAME, fullClassName); - } - } - } + FileAnalyzeProviderFactory.getProvider(language == null ? null : language.getLanguageName()) + .buildTestDataMap(project, editor, data); } if (LanguageSettingsState.getInstance().getLanguageIndex() == 1) { @@ -151,10 +119,15 @@ protected void actionPerformed(Project project, Editor editor, String selectedTe var codeMessage = MessageModel.buildCodeMessage( UUID.randomUUID().toString(), System.currentTimeMillis(), showText, username, codeReferenceModel); - service.sendMessage(SessionTypeEnum.MULTI_TURN.getCode(), editorActionEnum.name(), data, null, callback, codeMessage); + FileAnalyzeProviderFactory.getProvider(language == null ? null : language.getLanguageName()) + .buildChatDataMap(project, psiElement, codeReferenceModel, data); + + service.chat(SessionTypeEnum.MULTI_TURN.getCode(), editorActionEnum.name(), data, null, callback, codeMessage); } }; - group.add(action); + if (!label.equals(EditorActionEnum.COMMENT_METHOD.getLabel())) { + group.add(action); + } }); } } diff --git a/src/main/java/com/zhongan/devpilot/constant/DefaultConst.java b/src/main/java/com/zhongan/devpilot/constant/DefaultConst.java index e4beef92..751999de 100644 --- a/src/main/java/com/zhongan/devpilot/constant/DefaultConst.java +++ b/src/main/java/com/zhongan/devpilot/constant/DefaultConst.java @@ -54,4 +54,20 @@ private DefaultConst() { public static final int COMPLETION_TRIGGER_INTERVAL = 1000; + public static final int CHAT_STEP_ONE = 1; + + public static final int CHAT_STEP_TWO = 2; + + public static final int CHAT_STEP_THREE = 3; + + public static final String DEFAULT_PROMPT_VERSION = "V240923"; + + public static final String CODE_PREDICT_PROMPT_VERSION = "V240923"; + + public static final String GIT_COMMIT_PROMPT_VERSION = "V240923"; + + public static final int NORMAL_CHAT_TYPE = 1; + + public static final int SMART_CHAT_TYPE = 2; + } \ No newline at end of file diff --git a/src/main/java/com/zhongan/devpilot/enums/EditorActionEnum.java b/src/main/java/com/zhongan/devpilot/enums/EditorActionEnum.java index 0b46e174..8073b84c 100644 --- a/src/main/java/com/zhongan/devpilot/enums/EditorActionEnum.java +++ b/src/main/java/com/zhongan/devpilot/enums/EditorActionEnum.java @@ -40,6 +40,18 @@ public static EditorActionEnum getEnumByLabel(String label) { return null; } + public static EditorActionEnum getEnumByName(String name) { + if (Objects.isNull(name)) { + return null; + } + for (EditorActionEnum type : EditorActionEnum.values()) { + if (type.name().equals(name)) { + return type; + } + } + return null; + } + public String getLabel() { return label; } diff --git a/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindow.java b/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindow.java index a021c2b0..ecd63252 100644 --- a/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindow.java +++ b/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindow.java @@ -119,7 +119,7 @@ private void registerJsCallJavaHandler(JBCefBrowser browser) { var message = service.getUserContentCode(messageModel); var userMessageModel = MessageModel.buildCodeMessage(uuid, time, message.getContent(), username, message.getCodeRef()); - service.sendMessage(SessionTypeEnum.MULTI_TURN.getCode(), "PURE_CHAT", null, message.getContent(), null, userMessageModel); + service.chat(SessionTypeEnum.MULTI_TURN.getCode(), "PURE_CHAT", null, message.getContent(), null, userMessageModel); return new JBCefJSQuery.Response("success"); } diff --git a/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindowService.java b/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindowService.java index 319ef61a..0c446db5 100644 --- a/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindowService.java +++ b/src/main/java/com/zhongan/devpilot/gui/toolwindows/chat/DevPilotChatToolWindowService.java @@ -7,6 +7,7 @@ import com.intellij.openapi.fileEditor.FileEditorManager; import com.intellij.openapi.project.Project; import com.intellij.openapi.ui.popup.Balloon; +import com.intellij.openapi.util.Computable; import com.intellij.psi.PsiElement; import com.zhongan.devpilot.actions.editor.popupmenu.BasicEditorAction; import com.zhongan.devpilot.constant.DefaultConst; @@ -16,11 +17,14 @@ import com.zhongan.devpilot.integrations.llms.LlmProvider; import com.zhongan.devpilot.integrations.llms.LlmProviderFactory; import com.zhongan.devpilot.integrations.llms.entity.DevPilotChatCompletionRequest; +import com.zhongan.devpilot.integrations.llms.entity.DevPilotCodePrediction; import com.zhongan.devpilot.integrations.llms.entity.DevPilotMessage; +import com.zhongan.devpilot.provider.file.FileAnalyzeProviderFactory; import com.zhongan.devpilot.util.BalloonAlertUtils; import com.zhongan.devpilot.util.DevPilotMessageBundle; import com.zhongan.devpilot.util.JsonUtils; import com.zhongan.devpilot.util.MessageUtil; +import com.zhongan.devpilot.util.PsiElementUtils; import com.zhongan.devpilot.util.TokenUtils; import com.zhongan.devpilot.webview.model.CodeReferenceModel; import com.zhongan.devpilot.webview.model.EmbeddedModel; @@ -28,16 +32,27 @@ import com.zhongan.devpilot.webview.model.LocaleModel; import com.zhongan.devpilot.webview.model.LoginModel; import com.zhongan.devpilot.webview.model.MessageModel; +import com.zhongan.devpilot.webview.model.RecallModel; import com.zhongan.devpilot.webview.model.ThemeModel; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; +import static com.zhongan.devpilot.constant.DefaultConst.CHAT_STEP_ONE; +import static com.zhongan.devpilot.constant.DefaultConst.CHAT_STEP_THREE; +import static com.zhongan.devpilot.constant.DefaultConst.CHAT_STEP_TWO; +import static com.zhongan.devpilot.constant.DefaultConst.CODE_PREDICT_PROMPT_VERSION; +import static com.zhongan.devpilot.constant.DefaultConst.NORMAL_CHAT_TYPE; +import static com.zhongan.devpilot.constant.DefaultConst.SMART_CHAT_TYPE; import static com.zhongan.devpilot.enums.SessionTypeEnum.MULTI_TURN; @Service @@ -52,6 +67,14 @@ public final class DevPilotChatToolWindowService { private final List historyRequestMessageList = new ArrayList<>(); + private final AtomicBoolean cancel = new AtomicBoolean(false); + + private MessageModel lastMessage = new MessageModel(); + + private final AtomicInteger nowStep = new AtomicInteger(1); + + private volatile String currentMessageId = null; + public DevPilotChatToolWindowService(Project project) { this.project = project; this.devPilotChatToolWindow = new DevPilotChatToolWindow(project); @@ -65,12 +88,215 @@ public Project getProject() { return this.project; } - public String sendMessage(Integer sessionType, String msgType, Map data, String message, Consumer callback, MessageModel messageModel) { + public void chat(Integer sessionType, String msgType, Map data, + String message, Consumer callback, MessageModel messageModel) { + this.cancel.set(false); + this.currentMessageId = messageModel.getId(); + this.lastMessage = messageModel; + + callWebView(messageModel); + addMessage(messageModel); + callWebView(MessageModel.buildLoadingMessage()); + + this.llmProvider = new LlmProviderFactory().getLlmProvider(project); + + // todo normal chat should be used in some case +// if (messageModel.getCodeRef() == null) { +// normalChat(sessionType, msgType, data, message, callback, messageModel); +// return; +// } + + smartChat(sessionType, msgType, data, message, callback, messageModel); + } + + public void normalChat(Integer sessionType, String msgType, Map data, + String message, Consumer callback, MessageModel messageModel) { + sendMessage(sessionType, msgType, data, message, callback, messageModel, null, null, NORMAL_CHAT_TYPE); + } + + public void smartChat(Integer sessionType, String msgType, Map data, + String message, Consumer callback, MessageModel messageModel) { + ApplicationManager.getApplication().executeOnPooledThread(() -> { + // step1 call model to do code prediction + if (shouldCancelChat(messageModel)) { + return; + } + + this.nowStep.set(CHAT_STEP_ONE); + var references = codePredict(messageModel.getContent(), messageModel.getCodeRef(), msgType); + + // step2 call rag to analyze code + if (shouldCancelChat(messageModel)) { + return; + } + + this.nowStep.set(CHAT_STEP_TWO); + var localRef = callRag(references, messageModel.getCodeRef()); + + // step3 call model to get the final result + if (shouldCancelChat(messageModel)) { + return; + } + + this.nowStep.set(CHAT_STEP_THREE); + + // avoid immutable map + Map newMap; + if (data != null) { + newMap = new HashMap<>(data); + } else { + newMap = new HashMap<>(); + } + final List[] localRefs = new List[1]; + + if (localRef != null) { + ApplicationManager.getApplication().runReadAction(() -> { + var relatedCode = PsiElementUtils.transformElementToString(localRef); + newMap.put("relatedContext", relatedCode); + // newMap.put("additionalRelatedContext", null); + localRefs[0] = CodeReferenceModel.getCodeRefListFromPsiElement(localRef, EditorActionEnum.getEnumByName(msgType)); + }); + } + + sendMessage(sessionType, msgType, newMap, message, callback, messageModel, null, localRefs[0], SMART_CHAT_TYPE); + }); + } + + public void regenerateChat(MessageModel messageModel, Consumer callback) { + this.cancel.set(false); + + callWebView(MessageModel.buildLoadingMessage()); + + this.llmProvider = new LlmProviderFactory().getLlmProvider(project); + + // todo normal chat should be used in some case +// if (messageModel.getCodeRef() == null) { +// regenerateNormalChat(messageModel, callback); +// return; +// } + + regenerateSmartChat(messageModel, callback); + } + + public void regenerateNormalChat(MessageModel messageModel, Consumer callback) { + sendMessage(callback, null, null, null, NORMAL_CHAT_TYPE); + } + + public void regenerateSmartChat(MessageModel messageModel, Consumer callback) { + ApplicationManager.getApplication().executeOnPooledThread(() -> { + // step1 call model to do code prediction + if (shouldCancelChat(messageModel)) { + return; + } + + this.nowStep.set(CHAT_STEP_ONE); + var references = codePredict(messageModel.getContent(), messageModel.getCodeRef(), null); + + // step2 call rag to analyze code + if (shouldCancelChat(messageModel)) { + return; + } + + this.nowStep.set(CHAT_STEP_TWO); + var localRef = callRag(references, messageModel.getCodeRef()); + + // step3 call model to get the final result + if (shouldCancelChat(messageModel)) { + return; + } + + this.nowStep.set(CHAT_STEP_THREE); + + var data = new HashMap(); + final List[] localRefs = new List[1]; + + if (localRef != null) { + ApplicationManager.getApplication().runReadAction(() -> { + var relatedCode = PsiElementUtils.transformElementToString(localRef); + data.put("relatedContext", relatedCode); +// data.put("additionalRelatedContext", null); + localRefs[0] = CodeReferenceModel.getCodeRefListFromPsiElement(localRef, messageModel.getCodeRef().getType()); + }); + } + + sendMessage(callback, data, null, localRefs[0], SMART_CHAT_TYPE); + }); + } + + private boolean shouldCancelChat(MessageModel messageModel) { + if (cancel.get()) { + return true; + } + + return !StringUtils.equals(currentMessageId, messageModel.getId()); + } + + private DevPilotCodePrediction codePredict(String content, CodeReferenceModel codeReference, String commandType) { + this.lastMessage = MessageModel + .buildAssistantMessage(System.currentTimeMillis() + "", System.currentTimeMillis(), "", true, RecallModel.create(1)); + callWebView(this.lastMessage); + + final Map dataMap = new HashMap<>(); + + if (commandType == null) { + if (codeReference == null || codeReference.getType() == null) { + commandType = "PURE_CHAT"; + } else { + commandType = codeReference.getType().name(); + } + } + + dataMap.put("commandTypeFor", commandType); + + if (codeReference != null) { + ApplicationManager.getApplication().runReadAction(() -> { + FileAnalyzeProviderFactory.getProvider(codeReference.getLanguageId()) + .buildCodePredictDataMap(project, codeReference, dataMap); + }); + } + + var devPilotChatCompletionRequest = new DevPilotChatCompletionRequest(); + devPilotChatCompletionRequest.setVersion(CODE_PREDICT_PROMPT_VERSION); + devPilotChatCompletionRequest.getMessages().addAll(removeRedundantRelatedContext(copyHistoryRequestMessageList(historyRequestMessageList))); + devPilotChatCompletionRequest.getMessages().add( + MessageUtil.createPromptMessage(System.currentTimeMillis() + "", "CODE_PREDICTION", content, dataMap)); + devPilotChatCompletionRequest.setStream(Boolean.FALSE); + var response = this.llmProvider.codePrediction(devPilotChatCompletionRequest); + if (!response.isSuccessful()) { + return null; + } + return JsonUtils.fromJson(response.getContent(), DevPilotCodePrediction.class); + } + + private List callRag(DevPilotCodePrediction codePredict, CodeReferenceModel codeReference) { + this.lastMessage = MessageModel + .buildAssistantMessage(System.currentTimeMillis() + "", System.currentTimeMillis(), "", true, RecallModel.create(2)); + callWebView(this.lastMessage); + + // call local rag + if (codePredict == null) { + return null; + } + + // todo call remote rag + return ApplicationManager.getApplication().runReadAction( + (Computable>) () -> { + var language = codeReference == null ? null : codeReference.getLanguageId(); + + return FileAnalyzeProviderFactory + .getProvider(language).callLocalRag(project, codePredict); + } + ); + } + + public String sendMessage(Integer sessionType, String msgType, Map data, + String message, Consumer callback, MessageModel messageModel, + List remoteRefs, List localRefs, int chatType) { DevPilotMessage userMessage; if (data == null || data.isEmpty()) { userMessage = MessageUtil.createUserMessage(message, msgType, messageModel.getId()); } else { - userMessage = MessageUtil.createPromptMessage(messageModel.getId(), msgType, data); + userMessage = MessageUtil.createPromptMessage(messageModel.getId(), msgType, message, data); } // check session type,default multi session @@ -81,20 +307,13 @@ public String sendMessage(Integer sessionType, String msgType, Map historyRequestMessageList.size()) { // update multi session request @@ -105,17 +324,24 @@ public String sendMessage(Integer sessionType, String msgType, Map callback) { - // check session type,default multi session + public String sendMessage(Consumer callback, Map data, + List remoteRefs, List localRefs, int chatType) { + // if data is not empty, the data should add into last history request message + if (data != null && !data.isEmpty() && !historyMessageList.isEmpty()) { + var lastHistoryRequestMessage = historyRequestMessageList.get(historyRequestMessageList.size() - 1); + if (lastHistoryRequestMessage.getPromptData() == null) { + lastHistoryRequestMessage.setPromptData(new HashMap<>()); + } + lastHistoryRequestMessage.getPromptData().putAll(data); + } + var devPilotChatCompletionRequest = new DevPilotChatCompletionRequest(); devPilotChatCompletionRequest.setStream(true); devPilotChatCompletionRequest.getMessages().addAll(copyHistoryRequestMessageList(historyRequestMessageList)); - callWebView(MessageModel.buildLoadingMessage()); - this.llmProvider = new LlmProviderFactory().getLlmProvider(project); - var chatCompletion = this.llmProvider.chatCompletion(project, devPilotChatCompletionRequest, callback); + var chatCompletion = this.llmProvider.chatCompletion(project, devPilotChatCompletionRequest, callback, remoteRefs, localRefs, chatType); if (devPilotChatCompletionRequest.getMessages().size() > historyRequestMessageList.size()) { // update multi session request historyRequestMessageList.add( @@ -126,11 +352,39 @@ public String sendMessage(Consumer callback) { } public void interruptSend() { - if (this.llmProvider != null) { + this.cancel.set(true); + if (this.lastMessage.getRecall() == null || this.nowStep.get() >= 3) { this.llmProvider.interruptSend(); + } else { + if (this.lastMessage != null) { + this.lastMessage.setStreaming(false); + this.lastMessage.setRecall(RecallModel.createTerminated(this.nowStep.get())); + addMessage(this.lastMessage); + callWebView(); + this.lastMessage = null; + } } } + /** + * Only used in CODE_PREDICTION for minimizing request data size. + * @param devPilotMessages + */ + private List removeRedundantRelatedContext(List devPilotMessages) { + if (CollectionUtils.isEmpty(devPilotMessages)) { + return Collections.emptyList(); + } + ArrayList copy = new ArrayList<>(devPilotMessages); + copy.forEach( + msg -> { + if (msg.getPromptData() != null) { + msg.getPromptData().remove("relatedContext"); + } + } + ); + return copy; + } + public List getHistoryMessageList() { return historyMessageList; } @@ -202,8 +456,15 @@ public void regenerateMessage() { var id = lastMessage.getId(); historyMessageList.removeIf(item -> item.getId().equals(id)); historyRequestMessageList.removeIf(item -> item.getId().equals(id)); + // todo handle real callback - sendMessage(null); + lastMessage = historyMessageList.get(historyMessageList.size() - 1); + + if (!lastMessage.getRole().equals("user")) { + return; + } + + regenerateChat(lastMessage, null); } public void handleActions(EditorActionEnum actionEnum, PsiElement psiElement) { diff --git a/src/main/java/com/zhongan/devpilot/integrations/llms/LlmProvider.java b/src/main/java/com/zhongan/devpilot/integrations/llms/LlmProvider.java index ad991363..84b0199c 100644 --- a/src/main/java/com/zhongan/devpilot/integrations/llms/LlmProvider.java +++ b/src/main/java/com/zhongan/devpilot/integrations/llms/LlmProvider.java @@ -11,9 +11,12 @@ import com.zhongan.devpilot.util.DevPilotMessageBundle; import com.zhongan.devpilot.util.JsonUtils; import com.zhongan.devpilot.util.OkhttpUtils; +import com.zhongan.devpilot.webview.model.CodeReferenceModel; import com.zhongan.devpilot.webview.model.MessageModel; +import com.zhongan.devpilot.webview.model.RecallModel; import java.io.IOException; +import java.util.List; import java.util.UUID; import java.util.function.Consumer; @@ -27,14 +30,19 @@ import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSources; +import static com.zhongan.devpilot.constant.DefaultConst.SMART_CHAT_TYPE; + public interface LlmProvider { - String chatCompletion(Project project, DevPilotChatCompletionRequest chatCompletionRequest, Consumer callback); + String chatCompletion(Project project, DevPilotChatCompletionRequest chatCompletionRequest, + Consumer callback, List remoteRefs, List localRefs, int type); DevPilotChatCompletionResponse chatCompletionSync(DevPilotChatCompletionRequest chatCompletionRequest); DevPilotMessage instructCompletion(DevPilotInstructCompletionRequest instructCompletionRequest); + DevPilotChatCompletionResponse codePrediction(DevPilotChatCompletionRequest chatCompletionRequest); + void interruptSend(); default void restoreMessage(MessageModel messageModel) { @@ -66,8 +74,8 @@ default void handlePluginVersionTooLow(DevPilotChatToolWindowService service, bo DevPilotNotification.upgradePluginNotification(service.getProject()); } - default EventSource buildEventSource(Request request, - DevPilotChatToolWindowService service, Consumer callback) { + default EventSource buildEventSource(Request request, DevPilotChatToolWindowService service, Consumer callback, + List remoteRefs, List localRefs, int chatType) { var time = System.currentTimeMillis(); var result = new StringBuilder(); var client = OkhttpUtils.getClient(); @@ -110,22 +118,39 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null streaming = !"stop".equals(finishReason); } + RecallModel recallModel = null; + + if (chatType == SMART_CHAT_TYPE) { + recallModel = RecallModel.create(3, remoteRefs, localRefs); + } + var assistantMessage = MessageModel - .buildAssistantMessage(response.getId(), time, result.toString(), streaming); + .buildAssistantMessage(response.getId(), time, result.toString(), streaming, recallModel); restoreMessage(assistantMessage); service.callWebView(assistantMessage); if (!streaming) { + if (chatType == SMART_CHAT_TYPE) { + recallModel = RecallModel.create(4, remoteRefs, localRefs); + } + assistantMessage = MessageModel + .buildAssistantMessage(response.getId(), time, result.toString(), streaming, recallModel); + service.callWebView(assistantMessage); + service.addMessage(assistantMessage); var devPilotMessage = new DevPilotMessage(); devPilotMessage.setId(response.getId()); devPilotMessage.setRole("assistant"); devPilotMessage.setContent(result.toString()); service.addRequestMessage(devPilotMessage); + if (callback != null) { callback.accept(result.toString()); } + + // clear message cache + restoreMessage(null); } } @@ -164,7 +189,7 @@ public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @ } var assistantMessage = MessageModel - .buildAssistantMessage(UUID.randomUUID().toString(), time, message, false); + .buildAssistantMessage(UUID.randomUUID().toString(), time, message, false, null); service.callWebView(assistantMessage); service.addMessage(assistantMessage); diff --git a/src/main/java/com/zhongan/devpilot/integrations/llms/aigateway/AIGatewayServiceProvider.java b/src/main/java/com/zhongan/devpilot/integrations/llms/aigateway/AIGatewayServiceProvider.java index 188af054..df77ee7f 100644 --- a/src/main/java/com/zhongan/devpilot/integrations/llms/aigateway/AIGatewayServiceProvider.java +++ b/src/main/java/com/zhongan/devpilot/integrations/llms/aigateway/AIGatewayServiceProvider.java @@ -19,13 +19,14 @@ import com.zhongan.devpilot.settings.state.AIGatewaySettingsState; import com.zhongan.devpilot.settings.state.LanguageSettingsState; import com.zhongan.devpilot.util.DevPilotMessageBundle; -import com.zhongan.devpilot.util.EditorUtils; import com.zhongan.devpilot.util.GatewayRequestUtils; -import com.zhongan.devpilot.util.GitUtil; +import com.zhongan.devpilot.util.GatewayRequestV2Utils; import com.zhongan.devpilot.util.LoginUtils; import com.zhongan.devpilot.util.OkhttpUtils; import com.zhongan.devpilot.util.UserAgentUtils; +import com.zhongan.devpilot.webview.model.CodeReferenceModel; import com.zhongan.devpilot.webview.model.MessageModel; +import com.zhongan.devpilot.webview.model.RecallModel; import java.io.IOException; import java.util.List; @@ -57,7 +58,8 @@ public final class AIGatewayServiceProvider implements LlmProvider { private MessageModel resultModel = new MessageModel(); @Override - public String chatCompletion(Project project, DevPilotChatCompletionRequest chatCompletionRequest, Consumer callback) { + public String chatCompletion(Project project, DevPilotChatCompletionRequest chatCompletionRequest, + Consumer callback, List remoteRefs, List localRefs, int chatType) { var service = project.getService(DevPilotChatToolWindowService.class); this.toolWindowService = service; @@ -76,24 +78,23 @@ public String chatCompletion(Project project, DevPilotChatCompletionRequest chat } try { + String requestBody = GatewayRequestV2Utils.encodeRequest(chatCompletionRequest); + if (requestBody == null) { + service.callErrorInfo("Chat completion failed: request body is null"); + return ""; + } + var requestBuilder = new Request.Builder() - .url(host + "/devpilot/v1/chat/completions") + .url(host + "/devpilot/v2/chat/completions") .header("User-Agent", UserAgentUtils.buildUserAgent()) .header("Auth-Type", LoginUtils.getLoginType()); - if (isLatestUserContentContainsRepo(chatCompletionRequest)) { - String repoName = EditorUtils.getCurrentEditorRepositoryName(project); - if (repoName != null && GitUtil.isRepoEmbedded(repoName)) { - requestBuilder.header("Embedded-Repos-V2", repoName); - requestBuilder.header("X-B3-Language", LanguageSettingsState.getInstance().getLanguageIndex() == 1 ? "zh-CN" : "en-US"); - } - } var request = requestBuilder - .post(RequestBody.create(GatewayRequestUtils.chatRequestJson(chatCompletionRequest), MediaType.parse("application/json"))) + .post(RequestBody.create(requestBody, MediaType.parse("application/json"))) .build(); DevPilotNotification.debug(LoginUtils.getLoginType() + "---" + UserAgentUtils.buildUserAgent()); - this.es = this.buildEventSource(request, service, callback); + this.es = this.buildEventSource(request, service, callback, remoteRefs, localRefs, chatType); } catch (Exception e) { DevPilotNotification.debug("Chat completion failed: " + e.getMessage()); @@ -111,6 +112,11 @@ public void interruptSend() { // remember the broken message if (resultModel != null && !StringUtils.isEmpty(resultModel.getContent())) { resultModel.setStreaming(false); + var recall = resultModel.getRecall(); + if (recall != null) { + var newRecall = RecallModel.createTerminated(3, recall.getRemoteRefs(), recall.getLocalRefs()); + resultModel.setRecall(newRecall); + } toolWindowService.addMessage(resultModel); } @@ -144,11 +150,15 @@ public DevPilotChatCompletionResponse chatCompletionSync(DevPilotChatCompletionR Response response; try { - String requestBody = GatewayRequestUtils.chatRequestJson(chatCompletionRequest); + String requestBody = GatewayRequestV2Utils.encodeRequest(chatCompletionRequest); + if (requestBody == null) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: request body is null"); + } + DevPilotNotification.debug("Send Request :[" + requestBody + "]."); var request = new Request.Builder() - .url(host + "/devpilot/v1/chat/completions") + .url(host + "/devpilot/v2/chat/completions") .header("User-Agent", UserAgentUtils.buildUserAgent()) .header("Auth-Type", LoginUtils.getLoginType()) .post(RequestBody.create(requestBody, MediaType.parse("application/json"))) @@ -302,4 +312,45 @@ private Boolean isLatestUserContentContainsRepo(DevPilotChatCompletionRequest ch } return false; } + + @Override + public DevPilotChatCompletionResponse codePrediction(DevPilotChatCompletionRequest chatCompletionRequest) { + var selectedModel = AIGatewaySettingsState.getInstance().getSelectedModel(); + var host = AIGatewaySettingsState.getInstance().getModelBaseHost(selectedModel); + + if (StringUtils.isEmpty(host)) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: host is empty"); + } + + Response response; + + try { + String requestBody = GatewayRequestV2Utils.encodeRequest(chatCompletionRequest); + if (requestBody == null) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: request body is null"); + } + + DevPilotNotification.debug("Send Request :[" + requestBody + "]."); + + var request = new Request.Builder() + .url(host + "/devpilot/v2/chat/completions") + .header("User-Agent", UserAgentUtils.buildUserAgent()) + .header("Auth-Type", LoginUtils.getLoginType()) + .post(RequestBody.create(requestBody, MediaType.parse("application/json"))) + .build(); + + Call call = OkhttpUtils.getClient().newCall(request); + response = call.execute(); + } catch (Exception e) { + DevPilotNotification.debug("Chat completion failed: " + e.getMessage()); + return DevPilotChatCompletionResponse.failed("Chat completion failed: " + e.getMessage()); + } + + try { + return parseCompletionsResult(chatCompletionRequest, response); + } catch (IOException e) { + DevPilotNotification.debug("Chat completion failed: " + e.getMessage()); + return DevPilotChatCompletionResponse.failed("Chat completion failed: " + e.getMessage()); + } + } } diff --git a/src/main/java/com/zhongan/devpilot/integrations/llms/entity/DevPilotChatCompletionRequest.java b/src/main/java/com/zhongan/devpilot/integrations/llms/entity/DevPilotChatCompletionRequest.java index ef723696..71b0cf9f 100644 --- a/src/main/java/com/zhongan/devpilot/integrations/llms/entity/DevPilotChatCompletionRequest.java +++ b/src/main/java/com/zhongan/devpilot/integrations/llms/entity/DevPilotChatCompletionRequest.java @@ -3,9 +3,11 @@ import java.util.ArrayList; import java.util.List; +import static com.zhongan.devpilot.constant.DefaultConst.DEFAULT_PROMPT_VERSION; + public class DevPilotChatCompletionRequest { - String version = "V240801"; + String version = DEFAULT_PROMPT_VERSION; String encoding = null; diff --git a/src/main/java/com/zhongan/devpilot/integrations/llms/entity/DevPilotCodePrediction.java b/src/main/java/com/zhongan/devpilot/integrations/llms/entity/DevPilotCodePrediction.java new file mode 100644 index 00000000..c0eb083d --- /dev/null +++ b/src/main/java/com/zhongan/devpilot/integrations/llms/entity/DevPilotCodePrediction.java @@ -0,0 +1,46 @@ +package com.zhongan.devpilot.integrations.llms.entity; + +import java.util.ArrayList; +import java.util.List; + +public class DevPilotCodePrediction { + private List inputArgs = new ArrayList<>(2); + + private List outputArgs = new ArrayList<>(2); + + private List references = new ArrayList<>(); + + private String comments; + + public List getInputArgs() { + return inputArgs; + } + + public void setInputArgs(List inputArgs) { + this.inputArgs = inputArgs; + } + + public List getOutputArgs() { + return outputArgs; + } + + public void setOutputArgs(List outputArgs) { + this.outputArgs = outputArgs; + } + + public List getReferences() { + return references; + } + + public void setReferences(List references) { + this.references = references; + } + + public String getComments() { + return comments; + } + + public void setComments(String comments) { + this.comments = comments; + } +} diff --git a/src/main/java/com/zhongan/devpilot/integrations/llms/trial/TrialServiceProvider.java b/src/main/java/com/zhongan/devpilot/integrations/llms/trial/TrialServiceProvider.java index 2232dec4..17a423ab 100644 --- a/src/main/java/com/zhongan/devpilot/integrations/llms/trial/TrialServiceProvider.java +++ b/src/main/java/com/zhongan/devpilot/integrations/llms/trial/TrialServiceProvider.java @@ -18,12 +18,16 @@ import com.zhongan.devpilot.settings.state.LanguageSettingsState; import com.zhongan.devpilot.util.DevPilotMessageBundle; import com.zhongan.devpilot.util.GatewayRequestUtils; +import com.zhongan.devpilot.util.GatewayRequestV2Utils; import com.zhongan.devpilot.util.LoginUtils; import com.zhongan.devpilot.util.OkhttpUtils; import com.zhongan.devpilot.util.UserAgentUtils; +import com.zhongan.devpilot.webview.model.CodeReferenceModel; import com.zhongan.devpilot.webview.model.MessageModel; +import com.zhongan.devpilot.webview.model.RecallModel; import java.io.IOException; +import java.util.List; import java.util.Objects; import java.util.function.Consumer; @@ -51,7 +55,8 @@ public final class TrialServiceProvider implements LlmProvider { private MessageModel resultModel = new MessageModel(); @Override - public String chatCompletion(Project project, DevPilotChatCompletionRequest chatCompletionRequest, Consumer callback) { + public String chatCompletion(Project project, DevPilotChatCompletionRequest chatCompletionRequest, + Consumer callback, List remoteRefs, List localRefs, int chatType) { var service = project.getService(DevPilotChatToolWindowService.class); this.toolWindowService = service; @@ -62,14 +67,20 @@ public String chatCompletion(Project project, DevPilotChatCompletionRequest chat } try { + var requestBody = GatewayRequestV2Utils.encodeRequest(chatCompletionRequest); + if (requestBody == null) { + service.callErrorInfo("Chat completion failed: request body is null"); + return ""; + } + var request = new Request.Builder() - .url(TRIAL_DEFAULT_HOST + "/v1/chat/completions") + .url(TRIAL_DEFAULT_HOST + "/v2/chat/completions") .header("User-Agent", UserAgentUtils.buildUserAgent()) .header("Auth-Type", "wx") - .post(RequestBody.create(GatewayRequestUtils.chatRequestJson(chatCompletionRequest), MediaType.parse("application/json"))) + .post(RequestBody.create(requestBody, MediaType.parse("application/json"))) .build(); - this.es = this.buildEventSource(request, service, callback); + this.es = this.buildEventSource(request, service, callback, remoteRefs, localRefs, chatType); } catch (Exception e) { service.callErrorInfo("Chat completion failed: " + e.getMessage()); return ""; @@ -87,11 +98,16 @@ public DevPilotChatCompletionResponse chatCompletionSync(DevPilotChatCompletionR Response response; try { + var requestBody = GatewayRequestV2Utils.encodeRequest(chatCompletionRequest); + if (requestBody == null) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: request body is null"); + } + var request = new Request.Builder() - .url(TRIAL_DEFAULT_HOST + "/v1/chat/completions") + .url(TRIAL_DEFAULT_HOST + "/v2/chat/completions") .header("User-Agent", UserAgentUtils.buildUserAgent()) .header("Auth-Type", "wx") - .post(RequestBody.create(GatewayRequestUtils.chatRequestJson(chatCompletionRequest), MediaType.parse("application/json"))) + .post(RequestBody.create(requestBody, MediaType.parse("application/json"))) .build(); var call = OkhttpUtils.getClient().newCall(request); @@ -158,6 +174,11 @@ public void interruptSend() { // remember the broken message if (resultModel != null && !StringUtils.isEmpty(resultModel.getContent())) { resultModel.setStreaming(false); + var recall = resultModel.getRecall(); + if (recall != null) { + var newRecall = RecallModel.createTerminated(3, recall.getRemoteRefs(), recall.getLocalRefs()); + resultModel.setRecall(newRecall); + } toolWindowService.addMessage(resultModel); } @@ -202,4 +223,38 @@ private DevPilotChatCompletionResponse parseResult(DevPilotChatCompletionRequest .getMessage()); } } + + @Override + public DevPilotChatCompletionResponse codePrediction(DevPilotChatCompletionRequest chatCompletionRequest) { + if (!LoginUtils.isLogin()) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: please login Wechat Login"); + } + + Response response; + + try { + var requestBody = GatewayRequestV2Utils.encodeRequest(chatCompletionRequest); + if (requestBody == null) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: request body is null"); + } + + var request = new Request.Builder() + .url(TRIAL_DEFAULT_HOST + "/v2/chat/completions") + .header("User-Agent", UserAgentUtils.buildUserAgent()) + .header("Auth-Type", "wx") + .post(RequestBody.create(requestBody, MediaType.parse("application/json"))) + .build(); + + var call = OkhttpUtils.getClient().newCall(request); + response = call.execute(); + } catch (Exception e) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: " + e.getMessage()); + } + + try { + return parseResult(chatCompletionRequest, response); + } catch (Exception e) { + return DevPilotChatCompletionResponse.failed("Chat completion failed: " + e.getMessage()); + } + } } diff --git a/src/main/java/com/zhongan/devpilot/provider/file/DefaultFileAnalyzeProvider.java b/src/main/java/com/zhongan/devpilot/provider/file/DefaultFileAnalyzeProvider.java new file mode 100644 index 00000000..98a1d568 --- /dev/null +++ b/src/main/java/com/zhongan/devpilot/provider/file/DefaultFileAnalyzeProvider.java @@ -0,0 +1,42 @@ +package com.zhongan.devpilot.provider.file; + +import com.intellij.openapi.editor.Editor; +import com.intellij.openapi.project.Project; +import com.intellij.psi.PsiElement; +import com.zhongan.devpilot.integrations.llms.entity.DevPilotCodePrediction; +import com.zhongan.devpilot.webview.model.CodeReferenceModel; + +import java.util.List; +import java.util.Map; + +public class DefaultFileAnalyzeProvider implements FileAnalyzeProvider { + @Override + public String languageName() { + return "none"; + } + + @Override + public String moduleName() { + return "default"; + } + + @Override + public void buildCodePredictDataMap(Project project, CodeReferenceModel codeReference, Map data) { + // default do nothing + } + + @Override + public void buildChatDataMap(Project project, PsiElement psiElement, CodeReferenceModel codeReference, Map data) { + // default do nothing + } + + @Override + public void buildTestDataMap(Project project, Editor editor, Map data) { + // default do nothing + } + + @Override + public List callLocalRag(Project project, DevPilotCodePrediction codePrediction) { + return List.of(); + } +} diff --git a/src/main/java/com/zhongan/devpilot/provider/file/FileAnalyzeProvider.java b/src/main/java/com/zhongan/devpilot/provider/file/FileAnalyzeProvider.java new file mode 100644 index 00000000..5ba43958 --- /dev/null +++ b/src/main/java/com/zhongan/devpilot/provider/file/FileAnalyzeProvider.java @@ -0,0 +1,24 @@ +package com.zhongan.devpilot.provider.file; + +import com.intellij.openapi.editor.Editor; +import com.intellij.openapi.project.Project; +import com.intellij.psi.PsiElement; +import com.zhongan.devpilot.integrations.llms.entity.DevPilotCodePrediction; +import com.zhongan.devpilot.webview.model.CodeReferenceModel; + +import java.util.List; +import java.util.Map; + +public interface FileAnalyzeProvider { + String languageName(); + + String moduleName(); + + void buildCodePredictDataMap(Project project, CodeReferenceModel codeReference, Map data); + + void buildChatDataMap(Project project, PsiElement psiElement, CodeReferenceModel codeReference, Map data); + + void buildTestDataMap(Project project, Editor editor, Map data); + + List callLocalRag(Project project, DevPilotCodePrediction codePrediction); +} diff --git a/src/main/java/com/zhongan/devpilot/provider/file/FileAnalyzeProviderFactory.java b/src/main/java/com/zhongan/devpilot/provider/file/FileAnalyzeProviderFactory.java new file mode 100644 index 00000000..f135414a --- /dev/null +++ b/src/main/java/com/zhongan/devpilot/provider/file/FileAnalyzeProviderFactory.java @@ -0,0 +1,38 @@ +package com.zhongan.devpilot.provider.file; + +import com.intellij.ide.plugins.PluginManagerCore; +import com.intellij.openapi.extensions.PluginId; +import com.zhongan.devpilot.provider.file.java.JavaFileAnalyzeProvider; + +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class FileAnalyzeProviderFactory { + private static final Map providerMap = new ConcurrentHashMap<>(); + + private static final DefaultFileAnalyzeProvider defaultProvider = new DefaultFileAnalyzeProvider(); + + static { + providerMap.put("java", new JavaFileAnalyzeProvider()); + } + + public static FileAnalyzeProvider getProvider(String language) { + if (language == null) { + return defaultProvider; + } + + language = language.toLowerCase(Locale.ROOT); + var provider = providerMap.get(language); + if (provider == null) { + return defaultProvider; + } + + var model = PluginManagerCore.getPlugin(PluginId.getId(provider.moduleName())); + if (model == null) { + return defaultProvider; + } + + return provider; + } +} diff --git a/src/main/java/com/zhongan/devpilot/provider/file/java/JavaFileAnalyzeProvider.java b/src/main/java/com/zhongan/devpilot/provider/file/java/JavaFileAnalyzeProvider.java new file mode 100644 index 00000000..2f1bf85f --- /dev/null +++ b/src/main/java/com/zhongan/devpilot/provider/file/java/JavaFileAnalyzeProvider.java @@ -0,0 +1,93 @@ +package com.zhongan.devpilot.provider.file.java; + +import com.intellij.openapi.editor.Editor; +import com.intellij.openapi.project.Project; +import com.intellij.psi.PsiElement; +import com.zhongan.devpilot.constant.PromptConst; +import com.zhongan.devpilot.enums.UtFrameTypeEnum; +import com.zhongan.devpilot.integrations.llms.entity.DevPilotCodePrediction; +import com.zhongan.devpilot.provider.file.FileAnalyzeProvider; +import com.zhongan.devpilot.provider.ut.UtFrameworkProvider; +import com.zhongan.devpilot.provider.ut.UtFrameworkProviderFactory; +import com.zhongan.devpilot.util.PsiElementUtils; +import com.zhongan.devpilot.util.PsiFileUtil; +import com.zhongan.devpilot.webview.model.CodeReferenceModel; + +import java.util.List; +import java.util.Map; + +import static com.zhongan.devpilot.constant.PlaceholderConst.ADDITIONAL_MOCK_PROMPT; +import static com.zhongan.devpilot.constant.PlaceholderConst.CLASS_FULL_NAME; +import static com.zhongan.devpilot.constant.PlaceholderConst.MOCK_FRAMEWORK; +import static com.zhongan.devpilot.constant.PlaceholderConst.TEST_FRAMEWORK; + +public class JavaFileAnalyzeProvider implements FileAnalyzeProvider { + @Override + public String languageName() { + return "java"; + } + + @Override + public String moduleName() { + return "com.intellij.java"; + } + + @Override + public void buildCodePredictDataMap(Project project, CodeReferenceModel codeReference, Map data) { + var psiJavaFile = PsiElementUtils.getPsiJavaFileByFilePath(project, codeReference.getFileUrl()); + + if (psiJavaFile != null) { + data.putAll( + Map.of( + "imports", PsiElementUtils.getImportInfo(psiJavaFile), + "package", psiJavaFile.getPackageName(), + "fields", PsiElementUtils.getFieldList(psiJavaFile), + "selectedCode", codeReference.getSourceCode(), + "filePath", codeReference.getFileUrl() + ) + ); + } + } + + @Override + public void buildChatDataMap(Project project, PsiElement psiElement, CodeReferenceModel codeReference, Map data) { + var psiJavaFile = PsiElementUtils.getPsiJavaFileByFilePath(project, codeReference.getFileUrl()); + + if (psiJavaFile != null) { + data.putAll( + Map.of( + "imports", PsiElementUtils.getImportInfo(psiJavaFile), + "package", psiJavaFile.getPackageName(), + "fields", PsiElementUtils.getFieldList(psiJavaFile), + "filePath", codeReference.getFileUrl() + ) + ); + } + + if (psiElement != null) { + var fullClassName = PsiElementUtils.getFullClassName(psiElement); + + if (fullClassName != null) { + data.put(CLASS_FULL_NAME, fullClassName); + } + } + } + + @Override + public void buildTestDataMap(Project project, Editor editor, Map data) { + if (PsiFileUtil.isCaretInWebClass(project, editor)) { + data.put(ADDITIONAL_MOCK_PROMPT, PromptConst.MOCK_WEB_MVC); + } + UtFrameworkProvider utFrameworkProvider = UtFrameworkProviderFactory.create(languageName()); + if (utFrameworkProvider != null) { + UtFrameTypeEnum utFramework = utFrameworkProvider.getUTFramework(project, editor); + data.put(TEST_FRAMEWORK, utFramework.getUtFrameType()); + data.put(MOCK_FRAMEWORK, utFramework.getMockFrameType()); + } + } + + @Override + public List callLocalRag(Project project, DevPilotCodePrediction codePrediction) { + return PsiElementUtils.contextRecall(project, codePrediction); + } +} diff --git a/src/main/java/com/zhongan/devpilot/provider/ut/UtFrameworkProviderFactory.java b/src/main/java/com/zhongan/devpilot/provider/ut/UtFrameworkProviderFactory.java index 35f2183f..978e8c3c 100644 --- a/src/main/java/com/zhongan/devpilot/provider/ut/UtFrameworkProviderFactory.java +++ b/src/main/java/com/zhongan/devpilot/provider/ut/UtFrameworkProviderFactory.java @@ -23,4 +23,20 @@ public static UtFrameworkProvider create(LanguageUtil.Language language) { return null; } + public static UtFrameworkProvider create(String language) { + + if (language == null) { + return null; + } + + switch (language.toLowerCase(Locale.ROOT)) { + case "java": + return JavaUtFrameworkProvider.INSTANCE; + case "go": + case "python": + } + // todo support other languages test. + return null; + } + } diff --git a/src/main/java/com/zhongan/devpilot/settings/actionconfiguration/EditorActionConfigurationState.java b/src/main/java/com/zhongan/devpilot/settings/actionconfiguration/EditorActionConfigurationState.java index 7c3c2f6c..9c91dbcb 100644 --- a/src/main/java/com/zhongan/devpilot/settings/actionconfiguration/EditorActionConfigurationState.java +++ b/src/main/java/com/zhongan/devpilot/settings/actionconfiguration/EditorActionConfigurationState.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.List; +import static com.zhongan.devpilot.enums.EditorActionEnum.COMMENT_METHOD; import static com.zhongan.devpilot.enums.EditorActionEnum.EXPLAIN_CODE; import static com.zhongan.devpilot.enums.EditorActionEnum.FIX_CODE; import static com.zhongan.devpilot.enums.EditorActionEnum.GENERATE_COMMENTS; @@ -28,6 +29,7 @@ public class EditorActionConfigurationState implements PersistentStateComponent< defaultActions.add(FIX_CODE.getLabel()); defaultActions.add(GENERATE_COMMENTS.getLabel()); defaultActions.add(GENERATE_TESTS.getLabel()); + defaultActions.add(COMMENT_METHOD.getLabel()); } public static EditorActionConfigurationState getInstance() { diff --git a/src/main/java/com/zhongan/devpilot/settings/state/ChatShortcutSettingState.java b/src/main/java/com/zhongan/devpilot/settings/state/ChatShortcutSettingState.java index 03db46b3..192650a4 100644 --- a/src/main/java/com/zhongan/devpilot/settings/state/ChatShortcutSettingState.java +++ b/src/main/java/com/zhongan/devpilot/settings/state/ChatShortcutSettingState.java @@ -15,7 +15,8 @@ public class ChatShortcutSettingState implements PersistentStateComponent T decodeRequest(String encodedRequest, Class valueType) throws Exception { + byte[] decodedBytes = Base64.getDecoder().decode(encodedRequest); + + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(decodedBytes); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + try (GZIPInputStream gzipInputStream = new GZIPInputStream(byteArrayInputStream)) { + byte[] buffer = new byte[1024]; + int len; + while ((len = gzipInputStream.read(buffer)) > 0) { + byteArrayOutputStream.write(buffer, 0, len); + } + } + + String jsonString = byteArrayOutputStream.toString(StandardCharsets.UTF_8); + return JsonUtils.fromJson(jsonString, valueType); + } +} diff --git a/src/main/java/com/zhongan/devpilot/util/MessageUtil.java b/src/main/java/com/zhongan/devpilot/util/MessageUtil.java index 7a57128c..ecbbd6e8 100644 --- a/src/main/java/com/zhongan/devpilot/util/MessageUtil.java +++ b/src/main/java/com/zhongan/devpilot/util/MessageUtil.java @@ -24,6 +24,16 @@ public static DevPilotMessage createPromptMessage(String id, String msgType, Map return message; } + public static DevPilotMessage createPromptMessage(String id, String msgType, String content, Map data) { + DevPilotMessage message = new DevPilotMessage(); + message.setId(id); + message.setRole("user"); + message.setPromptData(data); + message.setCommandType(msgType); + message.setContent(content); + return message; + } + public static DevPilotMessage createUserMessage(String content, String msgType, String id) { return createMessage(id, "user", msgType, content); } diff --git a/src/main/java/com/zhongan/devpilot/util/PsiElementUtils.java b/src/main/java/com/zhongan/devpilot/util/PsiElementUtils.java index c0239764..32c0c9d3 100644 --- a/src/main/java/com/zhongan/devpilot/util/PsiElementUtils.java +++ b/src/main/java/com/zhongan/devpilot/util/PsiElementUtils.java @@ -1,21 +1,52 @@ package com.zhongan.devpilot.util; -import com.intellij.lang.jvm.JvmParameter; +import com.intellij.openapi.project.Project; +import com.intellij.openapi.roots.ProjectFileIndex; +import com.intellij.openapi.vfs.LocalFileSystem; +import com.intellij.openapi.vfs.VirtualFile; +import com.intellij.psi.JavaPsiFacade; import com.intellij.psi.PsiClass; +import com.intellij.psi.PsiCodeBlock; +import com.intellij.psi.PsiDocumentManager; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiField; +import com.intellij.psi.PsiFile; +import com.intellij.psi.PsiImportList; +import com.intellij.psi.PsiImportStatement; +import com.intellij.psi.PsiImportStatementBase; +import com.intellij.psi.PsiImportStaticStatement; +import com.intellij.psi.PsiJavaFile; +import com.intellij.psi.PsiManager; import com.intellij.psi.PsiMethod; import com.intellij.psi.PsiTypeParameter; +import com.intellij.psi.impl.compiled.ClsClassImpl; import com.intellij.psi.impl.source.PsiClassReferenceType; +import com.intellij.psi.search.GlobalSearchScope; +import com.intellij.psi.search.searches.ClassInheritorsSearch; +import com.intellij.psi.util.PropertyUtil; +import com.zhongan.devpilot.integrations.llms.entity.DevPilotCodePrediction; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Set; - +import java.util.TreeSet; +import java.util.jar.Attributes; +import java.util.jar.JarFile; +import java.util.jar.Manifest; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.jetbrains.annotations.NotNull; public class PsiElementUtils { + private static final int MAX_LINE_COUNT = 1000; + public static String getFullClassName(@NotNull PsiElement element) { if (element instanceof PsiMethod) { var psiClass = ((PsiMethod) element).getContainingClass(); @@ -29,55 +60,31 @@ public static String getFullClassName(@NotNull PsiElement element) { return null; } - public static String getRelatedClass(@NotNull PsiElement element) { - Set classSet = new HashSet<>(); - - if (element instanceof PsiMethod) { - classSet = getMethodRelatedClass(element); - } else if (element instanceof PsiClass) { - classSet = getClassRelatedClass(element); - } - + public static String transformElementToString(Collection elements) { var result = new StringBuilder(); - for (PsiClass psiClass : classSet) { - if (ignoreClass(psiClass)) { + for (T element : elements) { + if (shouldIgnorePsiElement(element)) { continue; } - result.append(psiClass.getText()).append("\n"); - } - - return result.toString(); - } - - private static Set getClassRelatedClass(@NotNull PsiElement element) { - Set result = new HashSet<>(); - - if (element instanceof PsiClass) { - var psiClass = (PsiClass) element; - var methods = psiClass.getMethods(); - var fields = psiClass.getFields(); - - for (PsiMethod psiMethod : methods) { - result.addAll(getMethodRelatedClass(psiMethod)); + if (element instanceof PsiClass) { + PsiClass psiClass = (PsiClass) element; + result.append("Class: ").append(psiClass.getQualifiedName()).append("\n\n"); } - for (PsiField psiField : fields) { - result.addAll(getFieldTypeClass(psiField)); + if (element instanceof PsiMethod) { + PsiMethod psiMethod = (PsiMethod) element; + if (psiMethod.getContainingClass() != null) { + result.append("Method: ").append(StringUtils.join(psiMethod.getContainingClass().getQualifiedName(), psiMethod.getName(), "#")).append("\n"); + } else { + result.append("Method: ").append("\n"); + } } - } - - return result; - } - - private static Set getMethodRelatedClass(@NotNull PsiElement element) { - var parameterClass = getMethodParameterTypeClass(element); - var returnClass = getMethodReturnTypeClass(element); - var result = new HashSet<>(parameterClass); - result.addAll(returnClass); + result.append(simplifyElement(element)).append("\n\n"); + } - return result; + return result.toString(); } private static List getMethodReturnTypeClass(@NotNull PsiElement element) { @@ -88,33 +95,8 @@ private static List getMethodReturnTypeClass(@NotNull PsiElement eleme if (returnType instanceof PsiClassReferenceType) { var referenceType = (PsiClassReferenceType) returnType; - var returnTypeClass = referenceType.resolve(); - result.addAll(getGenericType(referenceType)); - if (returnTypeClass != null) { - result.add(returnTypeClass); - return result; - } - } - } - - return result; - } - - private static List getMethodParameterTypeClass(@NotNull PsiElement element) { - var result = new ArrayList(); - - if (element instanceof PsiMethod) { - var params = ((PsiMethod) element).getParameterList().getParameters(); - - for (JvmParameter parameter : params) { - if (parameter.getType() instanceof PsiClassReferenceType) { - var referenceType = (PsiClassReferenceType) parameter.getType(); - var psiClass = referenceType.resolve(); - if (psiClass != null) { - result.add(psiClass); - } - result.addAll(getGenericType(referenceType)); - } + result.addAll(getTypeClassAndGenericType(referenceType)); + return result; } } @@ -129,11 +111,7 @@ private static List getFieldTypeClass(@NotNull PsiElement element) { if (field.getType() instanceof PsiClassReferenceType) { var referenceType = (PsiClassReferenceType) field.getType(); - var psiClass = referenceType.resolve(); - if (psiClass != null) { - result.add(psiClass); - } - result.addAll(getGenericType(referenceType)); + result.addAll(getTypeClassAndGenericType(referenceType)); } } @@ -166,6 +144,18 @@ private static List getGenericType(PsiClassReferenceType referenceType return result; } + private static List getTypeClassAndGenericType(PsiClassReferenceType referenceType) { + var result = new ArrayList(); + + var psiClass = referenceType.resolve(); + if (psiClass != null) { + result.add(psiClass); + } + result.addAll(getGenericType(referenceType)); + + return result; + } + private static boolean ignoreClass(PsiClass psiClass) { if (psiClass == null) { return true; @@ -194,4 +184,383 @@ private static boolean ignoreClass(PsiClass psiClass) { return false; } + + private static boolean ignoreMethod(PsiMethod psiMethod) { + var psiClass = psiMethod.getContainingClass(); + return ignoreClass(psiClass); + } + + public static boolean shouldIgnorePsiElement(PsiElement psiElement) { + if (psiElement == null) { + return true; + } + + if (psiElement instanceof PsiMethod) { + return ignoreMethod((PsiMethod) psiElement); + } + + if (psiElement instanceof PsiClass) { + return ignoreClass((PsiClass) psiElement); + } + + return false; + } + + public static Set parseElementsList(Project project, List elements) { + var result = new HashSet(); + + for (String element : elements) { + // element format class#method + var arrays = element.split("#"); + var classFullName = arrays[0]; + String methodName = null; + + if (arrays.length > 1) { + methodName = arrays[1]; + } + + var e = getElementByName(project, classFullName, methodName); + if (e != null) { + result.add(e); + } + } + + return result; + } + + private static PsiElement getElementByName(Project project, String className, String methodName) { + var psiClass = findRealClass(project, className); + + if (psiClass != null) { + if (methodName == null) { + return psiClass; + } else { + var methods = psiClass.getMethods(); + for (var method : methods) { + if (methodName.equals(method.getName())) { + return method; + } + } + } + } + + return null; + } + + // AI model may confuse between inner class and normal class, so we should resolve this situation + private static PsiClass findRealClass(Project project, String className) { + var javaPsiFacade = JavaPsiFacade.getInstance(project); + var factory = javaPsiFacade.getElementFactory(); + + var classType = factory.createTypeByFQClassName(className); + var psiClass = classType.resolve(); + + if (psiClass != null) { + return psiClass; + } + + if (className.contains("$")) { + className = className.replace('$', '.'); + } else { + var lastDot = className.lastIndexOf('.'); + if (lastDot != -1) { + className = className.substring(0, lastDot) + "$" + className.substring(lastDot + 1); + } + } + + classType = factory.createTypeByFQClassName(className); + return classType.resolve(); + } + + public static PsiJavaFile getPsiJavaFileByFilePath(Project project, String filePath) { + var virtualFile = LocalFileSystem.getInstance().findFileByPath(filePath); + if (virtualFile != null) { + var psiFile = PsiManager.getInstance(project).findFile(virtualFile); + if (!(psiFile instanceof PsiJavaFile)) { + return null; + } + return (PsiJavaFile) psiFile; + } + return null; + } + + private static PsiClass getPsiClassByFile(PsiJavaFile psiJavaFile) { + var classes = psiJavaFile.getClasses(); + if (classes.length > 0) { + return classes[0]; + } + return null; + } + + public static String getImportList(PsiJavaFile psiJavaFile) { + var importList = new StringBuilder(); + var importStatements = psiJavaFile.getImportList(); + if (importStatements != null) { + var imports = importStatements.getImportStatements(); + for (PsiImportStatement importStatement : imports) { + importList.append(importStatement.getQualifiedName()).append(";"); + } + + for (PsiImportStaticStatement importStatement : importStatements.getImportStaticStatements()) { + if (importStatement.getImportReference() != null) { + importList.append(importStatement.getImportReference().getQualifiedName()).append(";"); + } + } + } + return importList.toString(); + } + + public static String getFieldList(PsiJavaFile psiJavaFile) { + var fieldList = new StringBuilder(); + + var psiClass = getPsiClassByFile(psiJavaFile); + if (psiClass == null) { + return ""; + } + + var fields = psiClass.getFields(); + for (PsiField field : fields) { + fieldList.append(field.getText()).append(System.lineSeparator()); + } + return fieldList.toString(); + } + + public static String getImportInfo(PsiFile psiFile) { + StringBuilder importedClasses = new StringBuilder(); + if (!(psiFile instanceof PsiJavaFile)) { + return ""; + } + PsiImportList importList = ((PsiJavaFile) psiFile).getImportList(); + if (importList != null) { + PsiImportStatementBase[] importStatements = Arrays.stream(importList.getAllImportStatements()).toArray(PsiImportStatementBase[]::new); + for (PsiImportStatementBase importStatement : importStatements) { + importedClasses.append(importStatement.getText()).append(System.lineSeparator()); + } + } + return importedClasses.toString(); + } + + public static List contextRecall(Project project, DevPilotCodePrediction codePrediction) { + if (codePrediction == null) { + return Collections.emptyList(); + } + List refs = new ArrayList<>(); + refs.addAll(codePrediction.getInputArgs()); + refs.addAll(codePrediction.getOutputArgs()); + refs.addAll(codePrediction.getReferences()); + if (CollectionUtils.isEmpty(refs)) { + return Collections.emptyList(); + } + List finalRefs = removeDuplicates(refs); + return doRecall(project, finalRefs); + } + + public static String filterLargeElement(PsiElement element) { + var lineCount = getLineCount(element); + if (lineCount > MAX_LINE_COUNT) { + return simplifyElement(element); + } else { + return element.getText(); + } + } + + public static String simplifyElement(PsiElement element) { + if (element instanceof PsiClass) { + var psiClass = (PsiClass) element; + return simplifyClass(psiClass); + } else if (element instanceof PsiMethod) { + var psiMethod = (PsiMethod) element; + return simplifyMethod(psiMethod); + } + + return null; + } + + private static String simplifyClass(PsiClass psiClass) { + var children = psiClass.getChildren(); + var result = new StringBuilder(); + + for (PsiElement child : children) { + if (child instanceof PsiMethod) { + result.append(simplifyMethod((PsiMethod) child)); + } else if (child instanceof PsiClass) { + // handle inner class + result.append(simplifyClass((PsiClass) child)); + } else { + result.append(child.getText()); + } + } + + return result.toString(); + } + + private static String simplifyMethod(PsiMethod method) { + var children = method.getChildren(); + var result = new StringBuilder(); + + for (PsiElement child : children) { + if (!(child instanceof PsiCodeBlock)) { + result.append(child.getText()); + } + } + + return result.toString(); + } + + public static int getLineCount(PsiElement element) { + var document = PsiDocumentManager.getInstance(element.getProject()).getDocument(element.getContainingFile()); + + if (document != null) { + int startOffset = element.getTextRange().getStartOffset(); + int endOffset = element.getTextRange().getEndOffset(); + + return document.getLineNumber(endOffset) - document.getLineNumber(startOffset) + 1; + } + + return 0; + } + + private static List removeDuplicates(List refs) { + if (CollectionUtils.isEmpty(refs)) { + return Collections.emptyList(); + } + Set uniqueRefs = new HashSet<>(refs); + TreeSet res = new TreeSet<>(uniqueRefs); + uniqueRefs.forEach(ref -> { + if (StringUtils.contains(ref, "#")) { + String[] split = StringUtils.split(ref, "#"); + String className = split[0]; + if (res.contains(className)) { + res.remove(ref); + } + } + }); + return new ArrayList<>(res); + } + + private static List doRecall(Project project, List references) { + List res = new ArrayList<>(); + if (CollectionUtils.isEmpty(references)) { + return Collections.emptyList(); + } + for (String ref : references) { + if (StringUtils.contains(ref, "#")) { + String[] split = StringUtils.split(ref, "#"); + String className = split[0]; + String methodName = split[1]; + res.add(methodRecall(project, className, methodName)); + } else { + res.add(classRecall(project, ref)); + } + } + return res; + } + + public static PsiElement classRecall(Project project, String clz) { + PsiClass psiClass = findPsiClass(project, clz); + if (psiClass == null) { + return null; + } + + if (psiClass.isInterface()) { + PsiClass first = ClassInheritorsSearch.search(psiClass).findFirst(); + if (first != null) { + psiClass = first; + } + } + + if (isCompiled(psiClass)) { + PsiClass sourceMirror = ((ClsClassImpl) psiClass).getSourceMirrorClass(); + if (sourceMirror != null) { + psiClass = sourceMirror; + } + if (StringUtils.contains(psiClass.getText(), "/* compiled code */")) { + return null; + } + } + return psiClass; + } + + private static PsiElement methodRecall(Project project, String clz, String methodName) { + PsiClass psiClass = findPsiClass(project, clz); + if (psiClass == null) { + return null; + } + if (psiClass.isInterface()) { + PsiClass first = ClassInheritorsSearch.search(psiClass).findFirst(); + if (first != null) { + psiClass = first; + } + } + if (isCompiled(psiClass)) { + PsiClass sourceMirror = ((ClsClassImpl) psiClass).getSourceMirrorClass(); + if (sourceMirror != null) { + psiClass = sourceMirror; + } + } + PsiMethod psiMethod = Arrays.stream(psiClass.findMethodsByName(methodName, true)).max(Comparator.comparingInt(o -> o.getParameters().length)).orElse(null); + if (isValidMethod(psiMethod, psiClass)) { + return psiMethod; + } + return null; + } + + private static PsiClass findPsiClass(Project project, String clz) { + if (project == null || StringUtils.isEmpty(clz)) { + return null; + } + if (StringUtils.contains(clz, "$")) { + clz = clz.replace('$', '.'); + } + + // this method can only find out inner classes in xxx.Innerclass + return JavaPsiFacade.getInstance(project).findClass(clz, GlobalSearchScope.allScope(project)); + } + + private static boolean isValidMethod(PsiMethod psiMethod, PsiClass psiClass) { + if (psiMethod == null) return false; + if (PropertyUtil.isSimpleGetter(psiMethod) || PropertyUtil.isSimpleSetter(psiMethod)) return false; + if (isCompiled(psiClass)) { + return !StringUtils.contains(psiMethod.getText(), "/* compiled code */"); + } else { + return psiMethod.getBody() != null; + } + } + + public static boolean isCompiled(@NotNull PsiClass psiClass) { + return psiClass instanceof ClsClassImpl; + } + + /** + * used in rag case if need. + * + * @param project + * @param psiClass + * @return + */ + public static Pair getJarTitleAndVersion(Project project, PsiClass psiClass) { + PsiFile psiFile = psiClass.getContainingFile(); + VirtualFile virtualFile = psiFile.getVirtualFile(); + ProjectFileIndex fileIndex = ProjectFileIndex.SERVICE.getInstance(project); + String title = "", version = ""; + if (fileIndex.isInLibraryClasses(virtualFile)) { + VirtualFile classRoot = fileIndex.getClassRootForFile(virtualFile); + if (classRoot != null) { + String jarPath = classRoot.getPresentableUrl(); + Manifest manifest; + try (JarFile jarFile = new JarFile(jarPath)) { + manifest = jarFile.getManifest(); + } catch (Exception e) { + return Pair.of(title, version); + } + if (manifest != null) { + Attributes mainAttributes = manifest.getMainAttributes(); + version = mainAttributes.getValue(Attributes.Name.IMPLEMENTATION_VERSION); + title = mainAttributes.getValue(Attributes.Name.IMPLEMENTATION_TITLE); + } + } + } + return Pair.of(title, version); + } + } diff --git a/src/main/java/com/zhongan/devpilot/webview/model/CodeReferenceModel.java b/src/main/java/com/zhongan/devpilot/webview/model/CodeReferenceModel.java index 64c457fe..f516f918 100644 --- a/src/main/java/com/zhongan/devpilot/webview/model/CodeReferenceModel.java +++ b/src/main/java/com/zhongan/devpilot/webview/model/CodeReferenceModel.java @@ -2,9 +2,19 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.intellij.openapi.editor.Document; +import com.intellij.openapi.vfs.VirtualFile; +import com.intellij.psi.PsiDocumentManager; +import com.intellij.psi.PsiElement; import com.zhongan.devpilot.enums.EditorActionEnum; import com.zhongan.devpilot.gui.toolwindows.components.EditorInfo; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static com.zhongan.devpilot.util.PsiElementUtils.shouldIgnorePsiElement; + @JsonIgnoreProperties(ignoreUnknown = true) public class CodeReferenceModel { private String languageId; @@ -38,6 +48,74 @@ public static CodeReferenceModel getCodeRefFromEditor(EditorInfo editorInfo, Edi editorInfo.getSelectedStartColumn(), editorInfo.getSelectedEndLine(), editorInfo.getSelectedEndColumn(), actionEnum); } + public static List getCodeRefListFromPsiElement(Collection list, EditorActionEnum actionEnum) { + if (list == null) { + return null; + } + + var result = new ArrayList(); + + for (PsiElement element : list) { + if (shouldIgnorePsiElement(element)) { + continue; + } + var ref = getCodeRefFromPsiElement(element, actionEnum); + if (ref != null) { + result.add(ref); + } + } + + return result; + } + + public static CodeReferenceModel getCodeRefFromPsiElement(PsiElement element, EditorActionEnum actionEnum) { + if (element == null) { + return null; + } + + var languageId = element.getLanguage().getID(); + var sourceCode = element.getText(); + + var psiFile = element.getContainingFile(); + VirtualFile file = null; + Document document = null; + + if (psiFile != null) { + file = psiFile.getVirtualFile(); + var project = element.getProject(); + document = PsiDocumentManager.getInstance(project).getDocument(psiFile); + } + + String filePath = null; + String fileName = null; + + if (file != null) { + filePath = file.getPath(); + fileName = file.getName(); + } + + Integer startLine = null; + Integer endLine = null; + + Integer startColumn = null; + Integer endColumn = null; + + if (document != null) { + var textRange = element.getTextRange(); + int startOffset = textRange.getStartOffset(); + int endOffset = textRange.getEndOffset(); + + startLine = document.getLineNumber(startOffset); + endLine = document.getLineNumber(endOffset); + + startColumn = startOffset - document.getLineStartOffset(startLine); + endColumn = endOffset - document.getLineStartOffset(endLine); + } + + return new CodeReferenceModel( + languageId, filePath, fileName, sourceCode, startLine, startColumn, endLine, endColumn, actionEnum); + } + public CodeReferenceModel(String languageId, String fileUrl, String fileName, String sourceCode, Integer selectedStartLine, Integer selectedStartColumn, Integer selectedEndLine, Integer selectedEndColumn, EditorActionEnum type) { diff --git a/src/main/java/com/zhongan/devpilot/webview/model/MessageModel.java b/src/main/java/com/zhongan/devpilot/webview/model/MessageModel.java index 70ca24ea..76e211fe 100644 --- a/src/main/java/com/zhongan/devpilot/webview/model/MessageModel.java +++ b/src/main/java/com/zhongan/devpilot/webview/model/MessageModel.java @@ -22,6 +22,8 @@ public class MessageModel { private CodeReferenceModel codeRef; + private RecallModel recall; + public static MessageModel buildCodeMessage(String id, Long time, String content, String username, CodeReferenceModel codeReference) { MessageModel messageModel = new MessageModel(); @@ -41,10 +43,10 @@ public static MessageModel buildUserMessage(String id, Long time, String content } public static MessageModel buildInfoMessage(String content) { - return buildAssistantMessage(UUID.randomUUID().toString(), System.currentTimeMillis(), content, false); + return buildAssistantMessage(UUID.randomUUID().toString(), System.currentTimeMillis(), content, false, null); } - public static MessageModel buildAssistantMessage(String id, Long time, String content, boolean streaming) { + public static MessageModel buildAssistantMessage(String id, Long time, String content, boolean streaming, RecallModel recall) { MessageModel messageModel = new MessageModel(); messageModel.setId(id); messageModel.setTime(time); @@ -54,12 +56,13 @@ public static MessageModel buildAssistantMessage(String id, Long time, String co messageModel.setAvatar(null); messageModel.setStreaming(streaming); messageModel.setCodeRef(null); + messageModel.setRecall(recall); return messageModel; } public static MessageModel buildDividerMessage() { MessageModel messageModel = new MessageModel(); - messageModel.setId("-1"); + messageModel.setId(System.currentTimeMillis() + ""); messageModel.setTime(System.currentTimeMillis()); messageModel.setRole("divider"); messageModel.setContent(null); @@ -71,7 +74,7 @@ public static MessageModel buildDividerMessage() { public static MessageModel buildLoadingMessage() { MessageModel messageModel = new MessageModel(); - messageModel.setId("-1"); + messageModel.setId(System.currentTimeMillis() + ""); messageModel.setTime(System.currentTimeMillis()); messageModel.setRole("assistant"); messageModel.setContent("..."); @@ -144,4 +147,12 @@ public CodeReferenceModel getCodeRef() { public void setCodeRef(CodeReferenceModel codeRef) { this.codeRef = codeRef; } + + public RecallModel getRecall() { + return recall; + } + + public void setRecall(RecallModel recall) { + this.recall = recall; + } } diff --git a/src/main/java/com/zhongan/devpilot/webview/model/RecallModel.java b/src/main/java/com/zhongan/devpilot/webview/model/RecallModel.java new file mode 100644 index 00000000..06c7ed86 --- /dev/null +++ b/src/main/java/com/zhongan/devpilot/webview/model/RecallModel.java @@ -0,0 +1,105 @@ +package com.zhongan.devpilot.webview.model; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; + +import java.util.ArrayList; +import java.util.List; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class RecallModel { + private List steps; + + private List remoteRefs; + + private List localRefs; + + public RecallModel(List steps, List remoteRefs, List localRefs) { + this.steps = steps; + this.remoteRefs = remoteRefs; + this.localRefs = localRefs; + } + + public static RecallModel create(int step) { + return create(step, null, null, 1); + } + + public static RecallModel createTerminated(int step) { + return create(step, null, null, 2); + } + + public static RecallModel createTerminated(int step, List remoteRefs, List localRefs) { + return create(step, remoteRefs, localRefs, 2); + } + + public static RecallModel create(int step, List remoteRefs, List localRefs) { + return create(step, remoteRefs, localRefs, 1); + } + + // 1 - step1 doing; 2 - step2 doing; 3 - step3 doing; 4 - step3 done + // type: 1 - loading; 2 - terminated + private static RecallModel create(int step, List remoteRefs, List localRefs, int type) { + var steps = new ArrayList(); + + for (int i = 1; i <= step - 1; i++) { + steps.add(new Step("done")); + } + + if (step < 4) { + if (type == 2) { + steps.add(new Step("terminated")); + } else { + steps.add(new Step("loading")); + } + } + + if (remoteRefs == null) { + remoteRefs = new ArrayList<>(); + } + + if (localRefs == null) { + localRefs = new ArrayList<>(); + } + + return new RecallModel(steps, remoteRefs, localRefs); + } + + public List getSteps() { + return steps; + } + + public void setSteps(List steps) { + this.steps = steps; + } + + public List getRemoteRefs() { + return remoteRefs; + } + + public void setRemoteRefs(List remoteRefs) { + this.remoteRefs = remoteRefs; + } + + public List getLocalRefs() { + return localRefs; + } + + public void setLocalRefs(List localRefs) { + this.localRefs = localRefs; + } + + static class Step { + private String status; + + Step(String status) { + this.status = status; + } + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } + } +} diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index f0a175df..539b0608 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -197,10 +197,6 @@ id="DevPilot.GenerateGitCommitMessage" class="com.zhongan.devpilot.actions.changesview.GenerateGitCommitMessageAction" /> - - - diff --git a/src/main/resources/messages/devpilot_en.properties b/src/main/resources/messages/devpilot_en.properties index 1ccbc543..d209476b 100644 --- a/src/main/resources/messages/devpilot_en.properties +++ b/src/main/resources/messages/devpilot_en.properties @@ -17,6 +17,7 @@ devpilot.action.new.chat=Open DevPilot Chat devpilot.action.new.chat.desc=New chat with DevPilot devpilot.action.reference.chat=Reference Code in Chat devpilot.action.generate.comments=Generate Comments +devpilot.action.generate.method.comments=Generate Method Comments devpilot.action.generate.tests=Generate Tests devpilot.action.fix=Fix This devpilot.action.review=Code Review diff --git a/src/main/resources/messages/devpilot_zh.properties b/src/main/resources/messages/devpilot_zh.properties index 1e9575ec..30ecd012 100644 --- a/src/main/resources/messages/devpilot_zh.properties +++ b/src/main/resources/messages/devpilot_zh.properties @@ -17,6 +17,7 @@ devpilot.action.new.chat=\u6253\u5f00DevPilot\u4f1A\u8BDD devpilot.action.new.chat.desc=\u6253\u5F00\u65B0\u7684DevPilot\u4F1A\u8BDD devpilot.action.reference.chat=\u5f15\u7528\u4ee3\u7801\u5757 devpilot.action.generate.comments=\u884C\u5185\u6CE8\u91CA +devpilot.action.generate.method.comments=\u751F\u6210\u65B9\u6CD5\u6CE8\u91CA devpilot.action.generate.tests=\u751F\u6210\u5355\u6D4B devpilot.action.fix=\u4FEE\u590D\u4EE3\u7801 devpilot.action.review=\u4EE3\u7801\u5BA1\u67E5 diff --git a/src/main/resources/webview/index.html b/src/main/resources/webview/index.html index 78dcdf3a..9248e5f8 100644 --- a/src/main/resources/webview/index.html +++ b/src/main/resources/webview/index.html @@ -7,85 +7,14 @@ DevPilot