From a0db29cf9c4ccccac5adf7d0396ee9d95e402b9a Mon Sep 17 00:00:00 2001 From: Blarc Date: Wed, 9 Oct 2024 20:19:10 +0200 Subject: [PATCH] feat(clients): show progress for streaming response --- .../settings/clients/LLMClientService.kt | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt b/src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt index 4d04940..695dd97 100644 --- a/src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt +++ b/src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt @@ -26,13 +26,11 @@ import dev.langchain4j.data.message.UserMessage import dev.langchain4j.model.StreamingResponseHandler import dev.langchain4j.model.chat.ChatLanguageModel import dev.langchain4j.model.chat.StreamingChatLanguageModel +import dev.langchain4j.model.output.Response import git4idea.GitCommit import git4idea.history.GitHistoryUtils import git4idea.repo.GitRepositoryManager -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch -import kotlinx.coroutines.withContext +import kotlinx.coroutines.* abstract class LLMClientService(private val cs: CoroutineScope) { @@ -81,7 +79,7 @@ abstract class LLMClientService(private val cs: Coro fun verifyConfiguration(client: C, label: JBLabel) { label.text = message("settings.verify.running") cs.launch(ModalityState.current().asContextElement()) { - sendRequest(client, "test", onSuccess = { + makeRequest(client, "test", onSuccess = { withContext(Dispatchers.EDT) { label.text = message("settings.verify.valid") label.icon = AllIcons.General.InspectionsOK @@ -99,11 +97,11 @@ abstract class LLMClientService(private val cs: Coro try { if (AppSettings2.instance.useStreamingResponse) { buildStreamingChatModel(client)?.let { streamingChatModel -> - sendStreamingRequest(streamingChatModel, text, onSuccess, onError) + sendStreamingRequest(streamingChatModel, text, onSuccess) return } } - sendRequest(client, text, onSuccess, onError) + sendRequest(client, text, onSuccess) } catch (e: IllegalArgumentException) { onError(message("settings.verify.invalid", e.message ?: message("unknown-error"))) } catch (e: Exception) { @@ -113,8 +111,10 @@ abstract class LLMClientService(private val cs: Coro } } - private suspend fun sendStreamingRequest(streamingModel: StreamingChatLanguageModel, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) { + private suspend fun sendStreamingRequest(streamingModel: StreamingChatLanguageModel, text: String, onSuccess: suspend (r: String) -> Unit) { var response = "" + val completionDeferred = CompletableDeferred() + withContext(Dispatchers.IO) { streamingModel.generate( listOf( @@ -131,18 +131,23 @@ abstract class LLMClientService(private val cs: Coro } } - override fun onError(error: Throwable?) { - response = error?.message.toString() - cs.launch { - onError(response) - } + override fun onError(error: Throwable) { + completionDeferred.completeExceptionally(error) + } + + override fun onComplete(response: Response) { + super.onComplete(response) + completionDeferred.complete(response.content().text()) } } ) + // This throws exception if completionDeferred.completeExceptionally(error) is called + // which is handled by the function calling this function + onSuccess(completionDeferred.await()) } } - private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) { + private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit) { val model = buildChatModel(client) val response = withContext(Dispatchers.IO) { model.generate(