From 478633834269372bd28c21ea9580bd6d523b1044 Mon Sep 17 00:00:00 2001 From: Daymon Date: Tue, 19 Dec 2023 13:25:54 -0600 Subject: [PATCH] Migrate model name to APIController --- .../ai/client/generativeai/GenerativeModel.kt | 16 +--------------- .../generativeai/internal/api/APIController.kt | 12 +++++++++++- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index e9b205d7..15064aa7 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -61,13 +61,7 @@ internal constructor( apiKey: String, generationConfig: GenerationConfig? = null, safetySettings: List? = null, - ) : this( - modelName, - apiKey, - generationConfig, - safetySettings, - APIController(apiKey, fullModelName(modelName)) - ) + ) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName)) /** * Generates a response from the backend with the provided [Content]s. @@ -188,11 +182,3 @@ internal constructor( ?.let { throw ResponseStoppedException(this) } } } - -/** - * Ensures the model name provided has a `models/` prefix - * - * Models must be prepended with the `models/` prefix when communicating with the backend. - */ -private fun fullModelName(name: String): String = - name.takeIf { it.startsWith("models/") } ?: "models/$name" diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt index 499fae40..e46016ca 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt @@ -63,9 +63,11 @@ internal val JSON = Json { */ internal class APIController( private val key: String, - private val model: String, + model: String, httpEngine: HttpClientEngine = OkHttp.create() ) { + private val model = fullModelName(model) + private val client = HttpClient(httpEngine) { install(HttpTimeout) { @@ -106,6 +108,14 @@ internal class APIController( } } +/** + * Ensures the model name provided has a `models/` prefix + * + * Models must be prepended with the `models/` prefix when communicating with the backend. + */ +private fun fullModelName(name: String): String = + name.takeIf { it.startsWith("models/") } ?: "models/$name" + /** * Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects * of type [R]. The response is expected to be a stream of JSON objects that are parsed in real-time