diff --git a/core/build.gradle.kts b/core/build.gradle.kts
index de1530459..bc3d922f3 100644
--- a/core/build.gradle.kts
+++ b/core/build.gradle.kts
@@ -106,7 +106,10 @@ kotlin {
     }
     val jvmTest by getting {
       dependencies {
-        implementation(libs.kotest.junit5)
+        implementation(libs.ollama.testcontainers)
+        implementation(libs.junit.jupiter.api)
+        implementation(libs.junit.jupiter.engine)
+        implementation(libs.logback)
       }
     }
     val linuxX64Main by getting {
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
index 96355222d..1ccea9382 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
@@ -7,6 +7,7 @@ import com.xebia.functional.xef.conversation.AiDsl
 import com.xebia.functional.xef.conversation.Conversation
 import com.xebia.functional.xef.conversation.Description
 import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.prompt.configuration.PromptConfiguration
 import kotlin.coroutines.cancellation.CancellationException
 import kotlin.reflect.KClass
 import kotlin.reflect.KType
@@ -193,7 +194,19 @@
       config: Config = Config(),
       api: Chat = OpenAI(config).chat,
       conversation: Conversation = Conversation()
-    ): A = chat(Prompt(model, prompt), target, config, api, conversation)
+    ): A =
+      chat(
+        prompt =
+          Prompt(
+            model = model,
+            value = prompt,
+            configuration = PromptConfiguration { supportsLogitBias = config.supportsLogitBias }
+          ),
+        target = target,
+        config = config,
+        api = api,
+        conversation = conversation
+      )

     @AiDsl
     suspend inline operator fun <reified A : Any> invoke(
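A quick sketch of what this wiring gives callers: setting `supportsLogitBias = false` once on `Config` is enough for enum targets to take the new fallback path, because the overload above copies the flag into the `Prompt`'s `PromptConfiguration`. Function name, prompt text, and model id here are illustrative; `Sentiment` is the test enum added later in this diff:

```kotlin
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.AI
import com.xebia.functional.xef.Config
import com.xebia.functional.xef.OpenAI

suspend fun classifyWithoutLogitBias(): Sentiment {
  // One flag at config level; the invoke overload forwards it into
  // PromptConfiguration, which DefaultAI consults when picking a strategy.
  val config = Config(supportsLogitBias = false)
  val sentiment: Sentiment = // the expected type drives the reified target
    AI(
      prompt = "The build is green again, finally!",
      model = CreateChatCompletionRequestModel.Custom("gemma:2b"), // illustrative
      config = config,
      api = OpenAI(config, logRequests = false).chat,
    )
  return sentiment
}
```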
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
index 4645fbd4f..f96a6b75e 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
@@ -1,6 +1,5 @@
 package com.xebia.functional.xef

-import arrow.core.nonEmptyListOf
 import com.xebia.functional.openai.Config as OpenAIConfig
 import com.xebia.functional.openai.generated.api.OpenAI
 import com.xebia.functional.xef.env.getenv
@@ -26,7 +25,8 @@ data class Config(
       classDiscriminator = "_type_"
     },
   val streamingPrefix: String = "data:",
-  val streamingDelimiter: String = "data: [DONE]"
+  val streamingDelimiter: String = "data: [DONE]",
+  val supportsLogitBias: Boolean = true,
 ) {
   companion object {
     val DEFAULT = Config()
@@ -47,10 +47,8 @@ fun OpenAI(
   httpClientConfig: ((HttpClientConfig<*>) -> Unit)? = null,
   logRequests: Boolean = false
 ): OpenAI {
-  val token =
-    config.token
-      ?: getenv(KEY_ENV_VAR)
-      ?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
+  val token = config.token ?: getenv(KEY_ENV_VAR) ?: ""
+  // ?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
   val clientConfig: HttpClientConfig<*>.() -> Unit = {
     install(ContentNegotiation) { json(config.json) }
     install(HttpTimeout) {
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt
index 4d3b1b9cd..5b65d0c39 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt
@@ -9,6 +9,8 @@ import com.xebia.functional.xef.llm.models.modelType
 import com.xebia.functional.xef.llm.prompt
 import com.xebia.functional.xef.llm.promptStreaming
 import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.prompt.PromptBuilder.Companion.user
+import com.xebia.functional.xef.prompt.contentAsString
 import kotlin.reflect.KClass
 import kotlin.reflect.KType
 import kotlin.reflect.typeOf
@@ -50,7 +52,7 @@ data class DefaultAI<A : Any>(
     val serializer = serializer()
     return when (serializer.descriptor.kind) {
       SerialKind.ENUM -> {
-        runWithEnumSingleTokenSerializer(serializer, prompt)
+        runWithEnumSerializer(serializer, prompt)
       }
       // else -> runWithSerializer(prompt, serializer)
       PolymorphicKind.OPEN ->
@@ -82,11 +84,61 @@ data class DefaultAI<A : Any>(
     }
   }

-  @OptIn(ExperimentalSerializationApi::class)
-  suspend fun runWithEnumSingleTokenSerializer(serializer: KSerializer<A>, prompt: Prompt): A {
+  private suspend fun runWithEnumSerializer(serializer: KSerializer<A>, prompt: Prompt): A =
+    if (prompt.configuration.supportsLogitBias) {
+      runWithEnumSingleTokenSerializer(serializer, prompt)
+    } else {
+      runWithEnumRegexResponseSerializer(serializer, prompt)
+    }
+
+  private suspend fun runWithEnumRegexResponseSerializer(
+    serializer: KSerializer<A>,
+    prompt: Prompt
+  ): A {
+    val cases = casesFromEnumSerializer(serializer)
+    val classificationMessage =
+      user(
+        """
+
+        You are an AI, expected to classify the `context` into one of the `cases`:
+
+        ${prompt.messages.joinToString("\n") { it.contentAsString() }}
+
+        ${cases.map { s -> "<case>$s</case>" }.joinToString("\n")}
+
+        Select the `case` corresponding to the `context`.
+        IMPORTANT. Reply exclusively with the selected `case`.
+
+        """
+          .trimIndent()
+      )
+    val result =
+      api.createChatCompletion(
+        CreateChatCompletionRequest(
+          messages = prompt.messages + classificationMessage,
+          model = model,
+          maxTokens = prompt.configuration.maxTokens,
+          temperature = 0.0
+        )
+      )
+    val casesRegexes = cases.map { ".*$it.*" }
+    val responseContent = result.choices[0].message.content ?: ""
+    val choice =
+      casesRegexes
+        .zip(cases)
+        .firstOrNull {
+          Regex(it.first, RegexOption.IGNORE_CASE).containsMatchIn(responseContent.trim())
+        }
+        ?.second
+    return serializeWithEnumSerializer(choice, enumSerializer)
+  }
+
+  private suspend fun runWithEnumSingleTokenSerializer(
+    serializer: KSerializer<A>,
+    prompt: Prompt
+  ): A {
     val encoding = model.modelType(forFunctions = false).encoding
-    val cases =
-      serializer.descriptor.elementDescriptors.map { it.serialName.substringAfterLast(".") }
+    val cases = casesFromEnumSerializer(serializer)
     val logitBias =
       cases
         .flatMap {
@@ -108,7 +160,17 @@
       )
     )
     val choice = result.choices[0].message.content
-    val enumSerializer = enumSerializer
+    return serializeWithEnumSerializer(choice, enumSerializer)
+  }
+
+  @OptIn(ExperimentalSerializationApi::class)
+  private fun casesFromEnumSerializer(serializer: KSerializer<A>): List<String> =
+    serializer.descriptor.elementDescriptors.map { it.serialName.substringAfterLast(".") }
+
+  private fun serializeWithEnumSerializer(
+    choice: String?,
+    enumSerializer: ((case: String) -> A)?
+  ): A {
     return if (choice != null && enumSerializer != null) {
       enumSerializer(choice)
     } else {
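The selection rule in `runWithEnumRegexResponseSerializer` reduces to "first declared case whose name appears anywhere in the reply, ignoring case". A self-contained sketch of just that step (the function name is ours, not part of the diff):

```kotlin
// Mirrors the diff's matching logic: wrap each case in `.*case.*`,
// scan the model's reply case-insensitively, take the first hit.
fun pickCase(cases: List<String>, response: String): String? =
  cases
    .map { ".*$it.*" }
    .zip(cases)
    .firstOrNull { (pattern, _) ->
      Regex(pattern, RegexOption.IGNORE_CASE).containsMatchIn(response.trim())
    }
    ?.second

fun main() {
  val cases = listOf("POSITIVE", "NEGATIVE")
  println(pickCase(cases, "The sentiment is clearly Positive.")) // POSITIVE
  println(pickCase(cases, "I cannot classify this.")) // null, handled by serializeWithEnumSerializer's else branch
}
```

Since matching is substring based, a case name contained in another (hypothetically, POSITIVE inside VERY_POSITIVE) resolves to whichever case is declared first, so declaration order matters for overlapping names.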
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/configuration/PromptConfiguration.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/configuration/PromptConfiguration.kt
index 3a8f05e53..4fc84a2cc 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/configuration/PromptConfiguration.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/configuration/PromptConfiguration.kt
@@ -17,6 +17,7 @@ constructor(
   var maxTokens: Int = 500,
   var messagePolicy: MessagePolicy = MessagePolicy(),
   var seed: Int? = null,
+  var supportsLogitBias: Boolean = true,
 ) {

   fun messagePolicy(block: MessagePolicy.() -> Unit) = messagePolicy.apply { block() }
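Besides the `Config`-level default, the flag can also be set per prompt through the existing builder; a sketch with illustrative values, using the same named parameters the `AI.invoke` overload above uses:

```kotlin
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.configuration.PromptConfiguration

// Opt a single prompt out of logit_bias, e.g. when targeting a server
// that rejects the parameter, while the rest of the app keeps the default.
val classification =
  Prompt(
    model = CreateChatCompletionRequestModel.Custom("gemma:2b"),
    value = "Classify the sentiment of this review.",
    configuration = PromptConfiguration { supportsLogitBias = false }
  )
```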
diff --git a/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/EnumClassificationTest.kt b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/EnumClassificationTest.kt
new file mode 100644
index 000000000..100c32374
--- /dev/null
+++ b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/EnumClassificationTest.kt
@@ -0,0 +1,33 @@
+package com.xebia.functional.xef.ollama.tests
+
+import com.xebia.functional.xef.ollama.tests.models.OllamaModels
+import com.xebia.functional.xef.ollama.tests.models.Sentiment
+import kotlinx.coroutines.runBlocking
+import org.junit.jupiter.api.Test
+
+class EnumClassificationTest : OllamaTests() {
+
+  @Test
+  fun `positive sentiment`() {
+    runBlocking {
+      val sentiment =
+        ollama<Sentiment>(
+          model = OllamaModels.Gemma2B,
+          prompt = "The context of the situation is very positive.",
+        )
+      assert(sentiment == Sentiment.POSITIVE) { "Expected POSITIVE but got $sentiment" }
+    }
+  }
+
+  @Test
+  fun `negative sentiment`() {
+    runBlocking {
+      val sentiment =
+        ollama<Sentiment>(
+          model = OllamaModels.LLama3_8B,
+          prompt = "The context of the situation is very negative.",
+        )
+      assert(sentiment == Sentiment.NEGATIVE) { "Expected NEGATIVE but got $sentiment" }
+    }
+  }
+}
diff --git a/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/OllamaTests.kt b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/OllamaTests.kt
new file mode 100644
index 000000000..b3f18f7f1
--- /dev/null
+++ b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/OllamaTests.kt
@@ -0,0 +1,100 @@
+package com.xebia.functional.xef.ollama.tests
+
+import com.github.dockerjava.api.model.Image
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.xef.AI
+import com.xebia.functional.xef.Config
+import com.xebia.functional.xef.OpenAI
+import io.github.oshai.kotlinlogging.KotlinLogging
+import java.util.concurrent.ConcurrentHashMap
+import org.junit.jupiter.api.AfterAll
+import org.testcontainers.DockerClientFactory
+import org.testcontainers.ollama.OllamaContainer
+import org.testcontainers.utility.DockerImageName
+
+abstract class OllamaTests {
+
+  val logger = KotlinLogging.logger {}
+
+  companion object {
+    private const val OLLAMA_IMAGE = "ollama/ollama:0.1.26"
+
+    private val registeredContainers: MutableMap<String, OllamaContainer> = ConcurrentHashMap()
+
+    @PublishedApi
+    internal fun useModel(model: String): OllamaContainer =
+      if (registeredContainers.containsKey(model)) {
+        registeredContainers[model]!!
+      } else {
+        ollamaContainer(model)
+      }
+
+    private fun ollamaContainer(model: String, imageName: String = model): OllamaContainer {
+      if (registeredContainers.containsKey(model)) {
+        return registeredContainers[model]!!
+      }
+      // create the new image if it is not already a docker image
+      val listImagesCmd: List<Image> =
+        DockerClientFactory.lazyClient().listImagesCmd().withImageNameFilter(imageName).exec()
+
+      val ollama =
+        if (listImagesCmd.isEmpty()) {
+          // ship container emoji: 🚢
+          println("🐳 Creating a new Ollama container with $model image...")
+          val ollama = OllamaContainer(OLLAMA_IMAGE)
+          ollama.start()
+          println("🐳 Pulling $model image...")
+          ollama.execInContainer("ollama", "pull", model)
+          println("🐳 Committing $model image...")
+          ollama.commitToImage(imageName)
+          ollama.withReuse(true)
+        } else {
+          println("🐳 Using existing Ollama container with $model image...")
+          // Substitute the default Ollama image with our model variant
+          val ollama =
+            OllamaContainer(
+              DockerImageName.parse(imageName).asCompatibleSubstituteFor("ollama/ollama")
+            )
+              .withReuse(true)
+          ollama.start()
+          ollama
+        }
+      println("🐳 Starting Ollama container with $model image...")
+      registeredContainers[model] = ollama
+      ollama.execInContainer("ollama", "run", model)
+      return ollama
+    }
+
+    @AfterAll
+    @JvmStatic
+    fun teardown() {
+      registeredContainers.forEach { (model, container) ->
+        println("🐳 Stopping Ollama container for model $model")
+        container.stop()
+      }
+    }
+  }
+
+  protected suspend inline fun <reified A : Any> ollama(
+    model: String,
+    prompt: String,
+  ): A {
+    useModel(model)
+    val config = Config(supportsLogitBias = false, baseUrl = ollamaBaseUrl(model))
+    val api = OpenAI(config = config, logRequests = true).chat
+    val result: A =
+      AI(
+        prompt = prompt,
+        config = config,
+        api = api,
+        model = CreateChatCompletionRequestModel.Custom(model),
+      )
+    logger.info { "🚀 Inference on model $model: $result" }
+    return result
+  }

+  fun ollamaBaseUrl(model: String): String {
+    val ollama = registeredContainers[model]!!
+    return "http://${ollama.host}:${ollama.getMappedPort(ollama.exposedPorts.first())}/v1/"
+  }
+}
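Adding coverage for another model from the catalog is then mechanical; a sketch of a hypothetical extra test case (model choice and prompt are ours, the shape matches the two tests above):

```kotlin
@Test
fun `positive sentiment on qwen`() {
  runBlocking {
    val sentiment =
      ollama<Sentiment>(
        model = OllamaModels.Qwen0_5B,
        prompt = "What a wonderful, uplifting day this is!",
      )
    assert(sentiment == Sentiment.POSITIVE) { "Expected POSITIVE but got $sentiment" }
  }
}
```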
diff --git a/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/OllamaModels.kt b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/OllamaModels.kt
new file mode 100644
index 000000000..76b5a8fca
--- /dev/null
+++ b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/OllamaModels.kt
@@ -0,0 +1,8 @@
+package com.xebia.functional.xef.ollama.tests.models
+
+object OllamaModels {
+  const val Gemma2B = "gemma:2b"
+  const val Phi3Latest = "phi3:latest"
+  const val LLama3_8B = "llama3:8b"
+  const val Qwen0_5B = "qwen:0.5b"
+}
diff --git a/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/Sentiment.kt b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/Sentiment.kt
new file mode 100644
index 000000000..f1e3374d5
--- /dev/null
+++ b/core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/Sentiment.kt
@@ -0,0 +1,9 @@
+package com.xebia.functional.xef.ollama.tests.models
+
+import kotlinx.serialization.Serializable
+
+@Serializable
+enum class Sentiment {
+  POSITIVE,
+  NEGATIVE,
+}
diff --git a/core/src/jvmTest/resources/logback.xml b/core/src/jvmTest/resources/logback.xml
new file mode 100644
index 000000000..fb058d664
--- /dev/null
+++ b/core/src/jvmTest/resources/logback.xml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<configuration>
+
+    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
+        <encoder>
+            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -%kvp- %msg%n</pattern>
+        </encoder>
+    </appender>
+
+    <root level="info">
+        <appender-ref ref="STDOUT"/>
+    </root>
+
+</configuration>
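For reference, the case names the classifier works with come from the enum's serial descriptor: `casesFromEnumSerializer` takes each element's serial name after the last dot, so the `Sentiment` enum above yields `POSITIVE` and `NEGATIVE`. A standalone sketch of that derivation (re-declaring a local `Sentiment` so it runs on its own):

```kotlin
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.elementDescriptors
import kotlinx.serialization.serializer

@Serializable enum class Sentiment { POSITIVE, NEGATIVE }

@OptIn(ExperimentalSerializationApi::class)
fun main() {
  // Same expression DefaultAI.casesFromEnumSerializer uses on the target's serializer.
  val cases =
    serializer<Sentiment>().descriptor.elementDescriptors.map {
      it.serialName.substringAfterLast(".")
    }
  println(cases) // [POSITIVE, NEGATIVE]
}
```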
diff --git a/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/EnumOllama.kt b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/EnumOllama.kt
new file mode 100644
index 000000000..eaf647c06
--- /dev/null
+++ b/examples/src/main/kotlin/com/xebia/functional/xef/dsl/chat/EnumOllama.kt
@@ -0,0 +1,18 @@
+package com.xebia.functional.xef.dsl.chat
+
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.xef.AI
+import com.xebia.functional.xef.Config
+import com.xebia.functional.xef.OpenAI
+
+suspend fun main() {
+  val config = Config(baseUrl = "http://localhost:11434/v1/", supportsLogitBias = false)
+  val sentiment =
+    AI<Sentiment>(
+      prompt = "I love Xef!",
+      model = CreateChatCompletionRequestModel.Custom("orca-mini:3b"),
+      config = config,
+      api = OpenAI(config, logRequests = true).chat,
+    )
+  println(sentiment) // positive
+}
diff --git a/examples/src/main/resources/logback.xml b/examples/src/main/resources/logback.xml
index fb058d664..9a90533b5 100644
--- a/examples/src/main/resources/logback.xml
+++ b/examples/src/main/resources/logback.xml
@@ -7,7 +7,7 @@
         </encoder>
     </appender>

-    <root level="info">
+    <root level="debug">
         <appender-ref ref="STDOUT"/>
     </root>

diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 7d4ce680f..af0c6af96 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -15,7 +15,7 @@ kotest-arrow = "1.4.0"
 klogging = "6.0.9"
 uuid = "0.0.22"
 postgresql = "42.7.3"
-testcontainers = "1.19.5"
+testcontainers = "1.19.7"
 hikari = "5.1.0"
 dokka = "1.9.20"
 logback = "1.5.5"
@@ -115,6 +115,7 @@ opentelemetry-extension-kotlin = { module = "io.opentelemetry:opentelemetry-exte
 progressbar = { module = "me.tongfei:progressbar", version.ref = "progressbar" }
 jmf = { module = "javax.media:jmf", version.ref = "jmf" }
 mp3-wav-converter = { module = "com.sipgate:mp3-wav", version.ref = "mp3-wav-converter" }
+ollama-testcontainers = { module = "org.testcontainers:ollama", version.ref = "testcontainers" }
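The new `org.testcontainers:ollama` coordinate resolves to the `OllamaContainer` used in `OllamaTests` above. For reviewers unfamiliar with the module, minimal standalone usage looks roughly like this, a sketch restricted to the calls this diff already exercises (image tag and model are illustrative):

```kotlin
import org.testcontainers.ollama.OllamaContainer

fun main() {
  // GenericContainer is AutoCloseable, so `use` stops the container on exit.
  OllamaContainer("ollama/ollama:0.1.26").use { ollama ->
    ollama.start()
    ollama.execInContainer("ollama", "pull", "gemma:2b")
    // OpenAI-compatible endpoint, derived the same way as OllamaTests.ollamaBaseUrl:
    val baseUrl = "http://${ollama.host}:${ollama.getMappedPort(ollama.exposedPorts.first())}/v1/"
    println(baseUrl)
  }
}
```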