feat(clients): show progress for streaming response
Blarc committed Oct 9, 2024
1 parent 0772d43 commit a0db29c
Showing 1 changed file with 19 additions and 14 deletions.
@@ -26,13 +26,11 @@ import dev.langchain4j.data.message.UserMessage
 import dev.langchain4j.model.StreamingResponseHandler
 import dev.langchain4j.model.chat.ChatLanguageModel
 import dev.langchain4j.model.chat.StreamingChatLanguageModel
+import dev.langchain4j.model.output.Response
 import git4idea.GitCommit
 import git4idea.history.GitHistoryUtils
 import git4idea.repo.GitRepositoryManager
-import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.launch
-import kotlinx.coroutines.withContext
+import kotlinx.coroutines.*
 
 abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {
 
@@ -81,7 +79,7 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {
     fun verifyConfiguration(client: C, label: JBLabel) {
         label.text = message("settings.verify.running")
         cs.launch(ModalityState.current().asContextElement()) {
-            sendRequest(client, "test", onSuccess = {
+            makeRequest(client, "test", onSuccess = {
                 withContext(Dispatchers.EDT) {
                     label.text = message("settings.verify.valid")
                     label.icon = AllIcons.General.InspectionsOK
@@ -99,11 +97,11 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {
         try {
             if (AppSettings2.instance.useStreamingResponse) {
                 buildStreamingChatModel(client)?.let { streamingChatModel ->
-                    sendStreamingRequest(streamingChatModel, text, onSuccess, onError)
+                    sendStreamingRequest(streamingChatModel, text, onSuccess)
                     return
                 }
             }
-            sendRequest(client, text, onSuccess, onError)
+            sendRequest(client, text, onSuccess)
         } catch (e: IllegalArgumentException) {
             onError(message("settings.verify.invalid", e.message ?: message("unknown-error")))
         } catch (e: Exception) {
@@ -113,8 +111,10 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {
         }
     }
 
-    private suspend fun sendStreamingRequest(streamingModel: StreamingChatLanguageModel, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
+    private suspend fun sendStreamingRequest(streamingModel: StreamingChatLanguageModel, text: String, onSuccess: suspend (r: String) -> Unit) {
         var response = ""
+        val completionDeferred = CompletableDeferred<String>()
+
         withContext(Dispatchers.IO) {
             streamingModel.generate(
                 listOf(
@@ -131,18 +131,23 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {
                         }
                     }
 
-                    override fun onError(error: Throwable?) {
-                        response = error?.message.toString()
-                        cs.launch {
-                            onError(response)
-                        }
+                    override fun onError(error: Throwable) {
+                        completionDeferred.completeExceptionally(error)
                     }
+
+                    override fun onComplete(response: Response<AiMessage>) {
+                        super.onComplete(response)
+                        completionDeferred.complete(response.content().text())
+                    }
                 }
             )
+            // This throws an exception if completionDeferred.completeExceptionally(error) was called,
+            // which is handled by the function that calls this one.
+            onSuccess(completionDeferred.await())
         }
     }
 
-    private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
+    private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit) {
        val model = buildChatModel(client)
        val response = withContext(Dispatchers.IO) {
            model.generate(
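The core of this change is the CompletableDeferred bridge: langchain4j's callback-based streaming API is adapted into a suspend function, so a streaming failure resurfaces as an exception from await() inside the caller's existing try/catch rather than through a separate onError callback, while each onNext token still drives a progress update. A minimal, self-contained sketch of the pattern follows; the function name streamCompletion and the onProgress parameter are illustrative stand-ins assuming langchain4j's StreamingChatLanguageModel API, not identifiers from this commit.

import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.data.message.UserMessage
import dev.langchain4j.model.StreamingResponseHandler
import dev.langchain4j.model.chat.StreamingChatLanguageModel
import dev.langchain4j.model.output.Response
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

// Hypothetical adapter, not from the commit: bridges the callback-based
// streaming API into a suspend function using CompletableDeferred.
suspend fun streamCompletion(
    model: StreamingChatLanguageModel,
    prompt: String,
    onProgress: (String) -> Unit
): String {
    val completion = CompletableDeferred<String>()
    var partial = ""
    withContext(Dispatchers.IO) {
        model.generate(
            listOf(UserMessage.from(prompt)),
            object : StreamingResponseHandler<AiMessage> {
                override fun onNext(token: String) {
                    // Called on the model's streaming thread for every token.
                    partial += token
                    onProgress(partial)
                }

                override fun onError(error: Throwable) {
                    // await() below rethrows this in the calling coroutine.
                    completion.completeExceptionally(error)
                }

                override fun onComplete(response: Response<AiMessage>) {
                    completion.complete(response.content().text())
                }
            }
        )
    }
    // Suspends until the stream finishes; throws if onError fired.
    return completion.await()
}

Structured this way, the streaming and non-streaming code paths share a single error path: anything that goes wrong ends up as an exception in the caller's try/catch, which is what lets the onError parameter be dropped from sendRequest and sendStreamingRequest.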
