From f2aedb0359278fc90c61bb7af5adf47d23ee686a Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Tue, 19 Dec 2023 19:47:12 +0000 Subject: [PATCH] add semaphore and illegal state exception to chat (#21) Co-authored-by: David Motsonashvili Co-authored-by: Rodrigo Lazo --- .changes/calculator-bag-baby-chair.json | 1 + .../com/google/ai/client/generativeai/Chat.kt | 38 +++++++++++++++---- 2 files changed, 32 insertions(+), 7 deletions(-) create mode 100644 .changes/calculator-bag-baby-chair.json diff --git a/.changes/calculator-bag-baby-chair.json b/.changes/calculator-bag-baby-chair.json new file mode 100644 index 00000000..f3adcea3 --- /dev/null +++ b/.changes/calculator-bag-baby-chair.json @@ -0,0 +1 @@ +{"type":"MINOR","changes":["An instance of Chat will now throw an InvalidStateException if multiple requests are made simultaneously."]} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index dbb76aa0..2abe4b72 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -25,6 +25,7 @@ import com.google.ai.client.generativeai.type.InvalidStateException import com.google.ai.client.generativeai.type.TextPart import com.google.ai.client.generativeai.type.content import java.util.LinkedList +import java.util.concurrent.Semaphore import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onEach @@ -35,10 +36,14 @@ import kotlinx.coroutines.flow.onEach * Handles the capturing and storage of the communication with the model, providing methods for * further interaction. * + * Note: This object is not thread-safe, and calling [sendMessage] multiple times without waiting + * for a response will throw an [InvalidStateException]. + * * @param model the model to use for the interaction * @property history the previous interactions with the model */ class Chat(private val model: GenerativeModel, val history: MutableList = ArrayList()) { + private var lock = Semaphore(1) /** * Generates a response from the backend with the provided [Content], and any previous ones @@ -46,22 +51,26 @@ class Chat(private val model: GenerativeModel, val history: MutableList * * @param prompt A [Content] to send to the model. * @throws InvalidStateException if the prompt is not coming from the 'user' role + * @throws InvalidStateException if the [Chat] instance has an active request. */ suspend fun sendMessage(prompt: Content): GenerateContentResponse { prompt.assertComesFromUser() - - val response = model.generateContent(*history.toTypedArray(), prompt) - - history.add(prompt) - history.add(response.candidates.first().content) - - return response + attemptLock() + try { + val response = model.generateContent(*history.toTypedArray(), prompt) + history.add(prompt) + history.add(response.candidates.first().content) + return response + } finally { + lock.release() + } } /** * Generates a response from the backend with the provided text represented [Content]. * * @param prompt The text to be converted into a single piece of [Content] to send to the model. + * @throws InvalidStateException if the [Chat] instance has an active request. */ suspend fun sendMessage(prompt: String): GenerateContentResponse { val content = content("user") { text(prompt) } @@ -72,6 +81,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList * Generates a response from the backend with the provided image represented [Content]. * * @param prompt The image to be converted into a single piece of [Content] to send to the model. + * @throws InvalidStateException if the [Chat] instance has an active request. */ suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse { val content = content("user") { image(prompt) } @@ -84,9 +94,11 @@ class Chat(private val model: GenerativeModel, val history: MutableList * @param prompt A [Content] to send to the model. * @return A [Flow] which will emit responses as they are returned from the model. * @throws InvalidStateException if the prompt is not coming from the 'user' role + * @throws InvalidStateException if the [Chat] instance has an active request. */ fun sendMessageStream(prompt: Content): Flow { prompt.assertComesFromUser() + attemptLock() val flow = model.generateContentStream(*history.toTypedArray(), prompt) val bitmaps = LinkedList() @@ -109,6 +121,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList } } .onCompletion { + lock.release() if (it == null) { val content = content("model") { @@ -134,6 +147,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList * * @param prompt A [Content] to send to the model. * @return A [Flow] which will emit responses as they are returned from the model. + * @throws InvalidStateException if the [Chat] instance has an active request. */ fun sendMessageStream(prompt: String): Flow { val content = content("user") { text(prompt) } @@ -145,6 +159,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList * * @param prompt A [Content] to send to the model. * @return A [Flow] which will emit responses as they are returned from the model. + * @throws InvalidStateException if the [Chat] instance has an active request. */ fun sendMessageStream(prompt: Bitmap): Flow { val content = content("user") { image(prompt) } @@ -156,4 +171,13 @@ class Chat(private val model: GenerativeModel, val history: MutableList throw InvalidStateException("Chat prompts should come from the 'user' role.") } } + + private fun attemptLock() { + if (!lock.tryAcquire()) { + throw InvalidStateException( + "This chat instance currently has an ongoing request, please wait for it to complete " + + "before sending more messages" + ) + } + } }