Skip to content

Commit

Permalink
feat(prompts): support amend commits (#230)
Browse files Browse the repository at this point in the history
Move the prompt-building logic from AICommitAction to LLMClientService so that the code can run in coroutines.

Closes #230
  • Loading branch information
Blarc committed Sep 13, 2024
1 parent 1189812 commit 1077f09
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Added

- Option to choose prompt per project.
- Amending commits now adds the changes from the previous commit to the prompt.

### Fixed

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package com.github.blarc.ai.commits.intellij.plugin

import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.commonBranch
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.computeDiff
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.constructPrompt
import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings2
Expand All @@ -17,42 +14,23 @@ class AICommitAction : AnAction(), DumbAware {
override fun actionPerformed(e: AnActionEvent) {
val llmClient = AppSettings2.instance.getActiveLLMClientConfiguration()
if (llmClient == null) {
Notification.clientNotSet()
sendNotification(Notification.clientNotSet())
return
}
val project = e.project ?: return

val commitWorkflowHandler = e.getData(VcsDataKeys.COMMIT_WORKFLOW_HANDLER) as AbstractCommitWorkflowHandler<*, *>?
if (commitWorkflowHandler == null) {
sendNotification(Notification.noCommitMessage())
return
}

val includedChanges = commitWorkflowHandler.ui.getIncludedChanges()
val commitMessage = VcsDataKeys.COMMIT_MESSAGE_CONTROL.getData(e.dataContext) as CommitMessage?

val diff = computeDiff(includedChanges, false, project)
if (diff.isBlank()) {
sendNotification(Notification.emptyDiff())
return
}

val branch = commonBranch(includedChanges, project)
val hint = commitMessage?.text

val prompt = constructPrompt(AppSettings2.instance.activePrompt.content, diff, branch, hint, project)

// TODO @Blarc: add support for different clients
// if (isPromptTooLarge(prompt)) {
// sendNotification(Notification.promptTooLarge())
// return@runBackgroundableTask
// }

if (commitMessage == null) {
sendNotification(Notification.noCommitMessage())
return
}

llmClient.generateCommitMessage(prompt, project, commitMessage)
val project = e.project ?: return
llmClient.generateCommitMessage(commitWorkflowHandler, commitMessage, project)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.ui.ComboBox
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.util.xmlb.annotations.Attribute
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import java.util.*
import javax.swing.Icon

Expand Down Expand Up @@ -38,7 +39,7 @@ abstract class LLMClientConfiguration(
getSharedState().modelIds.add(modelId)
}

abstract fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage)
abstract fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project)

abstract fun getRefreshModelsFunction(): ((ComboBox<String>) -> Unit)?

Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,62 @@
package com.github.blarc.ai.commits.intellij.plugin.settings.clients

import com.github.blarc.ai.commits.intellij.plugin.AICommitsBundle.message
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.commonBranch
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.computeDiff
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.constructPrompt
import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings2
import com.github.blarc.ai.commits.intellij.plugin.wrap
import com.intellij.icons.AllIcons
import com.intellij.openapi.application.EDT
import com.intellij.openapi.application.ModalityState
import com.intellij.openapi.application.asContextElement
import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.changes.Change
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.platform.ide.progress.withBackgroundProgress
import com.intellij.ui.components.JBLabel
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import com.intellij.vcs.commit.isAmendCommitMode
import dev.langchain4j.data.message.UserMessage
import dev.langchain4j.model.chat.ChatLanguageModel
import git4idea.GitCommit
import git4idea.history.GitHistoryUtils
import git4idea.repo.GitRepositoryManager
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: CoroutineScope) {
abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {

abstract suspend fun buildChatModel(client: T): ChatLanguageModel
abstract suspend fun buildChatModel(client: C): ChatLanguageModel

fun generateCommitMessage(client: T, prompt: String, project: Project, commitMessage: CommitMessage) {
cs.launch(Dispatchers.IO + ModalityState.current().asContextElement()) {
fun generateCommitMessage(clientConfiguration: C, commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {

val commitContext = commitWorkflowHandler.workflow.commitContext
val includedChanges = commitWorkflowHandler.ui.getIncludedChanges().toMutableList()

cs.launch(ModalityState.current().asContextElement()) {
withBackgroundProgress(project, message("action.background")) {
sendRequest(client, prompt, onSuccess = {

if (commitContext.isAmendCommitMode) {
includedChanges += getLastCommitChanges(project)
}

val diff = computeDiff(includedChanges, false, project)
if (diff.isBlank()) {
withContext(Dispatchers.EDT) {
sendNotification(Notification.emptyDiff())
}
return@withBackgroundProgress
}

val branch = commonBranch(includedChanges, project)
val prompt = constructPrompt(AppSettings2.instance.activePrompt.content, diff, branch, commitMessage.text, project)

sendRequest(clientConfiguration, prompt, onSuccess = {
withContext(Dispatchers.EDT) {
commitMessage.setCommitMessage(it)
}
Expand All @@ -39,9 +70,9 @@ abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: Coro
}
}

fun verifyConfiguration(client: T, label: JBLabel) {
fun verifyConfiguration(client: C, label: JBLabel) {
label.text = message("settings.verify.running")
cs.launch(Dispatchers.IO + ModalityState.current().asContextElement()) {
cs.launch(ModalityState.current().asContextElement()) {
sendRequest(client, "test", onSuccess = {
withContext(Dispatchers.EDT) {
label.text = message("settings.verify.valid")
Expand All @@ -56,7 +87,7 @@ abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: Coro
}
}

private suspend fun sendRequest(client: T, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
try {
val model = buildChatModel(client)
val response = withContext(Dispatchers.IO) {
Expand All @@ -78,4 +109,16 @@ abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: Coro
throw e
}
}

/**
 * Collects the changes contained in the most recent commit of every Git repository
 * in [project]. Repositories whose history lookup yields no commits contribute
 * nothing to the result. History queries run on [Dispatchers.IO].
 */
private suspend fun getLastCommitChanges(project: Project): List<Change> {
    return withContext(Dispatchers.IO) {
        GitRepositoryManager.getInstance(project).repositories
            .mapNotNull { repository ->
                // "--max-count=1" restricts the log to the repository's latest commit.
                GitHistoryUtils.history(project, repository.root, "--max-count=1").firstOrNull()
            }
            .flatMap { latestCommit -> (latestCommit as GitCommit).changes }
    }
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.github.blarc.ai.commits.intellij.plugin.settings.clients.anthropic;
package com.github.blarc.ai.commits.intellij.plugin.settings.clients.anthropic

import com.github.blarc.ai.commits.intellij.plugin.Icons
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientConfiguration
Expand All @@ -7,6 +7,7 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.util.xmlb.annotations.Attribute
import com.intellij.util.xmlb.annotations.Transient
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import dev.langchain4j.model.anthropic.AnthropicChatModelName
import javax.swing.Icon

Expand Down Expand Up @@ -44,8 +45,8 @@ class AnthropicClientConfiguration : LLMClientConfiguration(
return AnthropicClientSharedState.getInstance()
}

override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
return AnthropicClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
/**
 * Delegates commit message generation to the shared [AnthropicClientService],
 * passing this configuration so the service can build its chat model from it.
 */
override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) =
    AnthropicClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)

override fun getRefreshModelsFunction() = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientSha
import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.util.xmlb.annotations.Attribute
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import javax.swing.Icon

class GeminiClientConfiguration : LLMClientConfiguration(
Expand Down Expand Up @@ -34,8 +35,8 @@ class GeminiClientConfiguration : LLMClientConfiguration(
return GeminiClientSharedState.getInstance()
}

override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
return GeminiClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
/**
 * Delegates commit message generation to the shared [GeminiClientService],
 * passing this configuration so the service can build its chat model from it.
 */
override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) =
    GeminiClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)

// Model names are hard-coded and do not need to be refreshed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.ui.ComboBox
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.util.xmlb.annotations.Attribute
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import javax.swing.Icon

class OllamaClientConfiguration : LLMClientConfiguration(
Expand Down Expand Up @@ -36,8 +37,8 @@ class OllamaClientConfiguration : LLMClientConfiguration(
return OllamaClientSharedState.getInstance()
}

override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
return OllamaClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
/**
 * Delegates commit message generation to the shared [OllamaClientService],
 * passing this configuration so the service can build its chat model from it.
 */
override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) =
    OllamaClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)

override fun getRefreshModelsFunction() = fun (cb: ComboBox<String>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.util.xmlb.annotations.Attribute
import com.intellij.util.xmlb.annotations.Transient
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import javax.swing.Icon

class OpenAiClientConfiguration : LLMClientConfiguration(
Expand Down Expand Up @@ -43,8 +44,8 @@ class OpenAiClientConfiguration : LLMClientConfiguration(
return OpenAiClientSharedState.getInstance()
}

override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
return OpenAiClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
/**
 * Delegates commit message generation to the shared [OpenAiClientService],
 * passing this configuration so the service can build its chat model from it.
 */
override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) =
    OpenAiClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)

// Model names are retrieved from Enum and do not need to be refreshed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.util.xmlb.annotations.Attribute
import com.intellij.util.xmlb.annotations.Transient
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import dev.langchain4j.model.qianfan.QianfanChatModelNameEnum
import javax.swing.Icon

Expand Down Expand Up @@ -41,8 +42,8 @@ class QianfanClientConfiguration : LLMClientConfiguration(
return QianfanClientSharedState.getInstance()
}

override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
return QianfanClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
/**
 * Delegates commit message generation to the shared [QianfanClientService],
 * passing this configuration so the service can build its chat model from it.
 */
override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) =
    QianfanClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)

// Model names are retrieved from Enum and do not need to be refreshed.
Expand Down

0 comments on commit 1077f09

Please sign in to comment.