Enum/Classification support for models that do not support logitBias #746

Open · wants to merge 15 commits into base: main
5 changes: 4 additions & 1 deletion core/build.gradle.kts
@@ -106,7 +106,10 @@ kotlin {
}
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
implementation(libs.ollama.testcontainers)
implementation(libs.junit.jupiter.api)
implementation(libs.junit.jupiter.engine)
implementation(libs.logback)
}
}
val linuxX64Main by getting {
15 changes: 14 additions & 1 deletion core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
@@ -7,6 +7,7 @@ import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.configuration.PromptConfiguration
import kotlin.coroutines.cancellation.CancellationException
import kotlin.reflect.KClass
import kotlin.reflect.KType
@@ -193,7 +194,19 @@ sealed interface AI {
config: Config = Config(),
api: Chat = OpenAI(config).chat,
conversation: Conversation = Conversation()
): A = chat(Prompt(model, prompt), target, config, api, conversation)
): A =
chat(
prompt =
Prompt(
model = model,
value = prompt,
configuration = PromptConfiguration { supportsLogitBias = config.supportsLogitBias }
),
target = target,
config = config,
api = api,
conversation = conversation
)

@AiDsl
suspend inline operator fun <reified A : Any> invoke(
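With this change, the supportsLogitBias flag set on Config is propagated into the Prompt through its PromptConfiguration. A minimal sketch of setting the flag per prompt instead (hypothetical model and prompt values; the Prompt and PromptConfiguration signatures are the ones visible in this diff):

val prompt =
  Prompt(
    model = CreateChatCompletionRequestModel.Custom("gemma:2b"),
    value = "Classify the sentiment of this review.",
    // opt out of logit bias for backends that ignore or reject it
    configuration = PromptConfiguration { supportsLogitBias = false }
  )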
10 changes: 4 additions & 6 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
@@ -1,6 +1,5 @@
package com.xebia.functional.xef

import arrow.core.nonEmptyListOf
import com.xebia.functional.openai.Config as OpenAIConfig
import com.xebia.functional.openai.generated.api.OpenAI
import com.xebia.functional.xef.env.getenv
@@ -26,7 +25,8 @@ data class Config(
classDiscriminator = "_type_"
},
val streamingPrefix: String = "data:",
val streamingDelimiter: String = "data: [DONE]"
val streamingDelimiter: String = "data: [DONE]",
val supportsLogitBias: Boolean = true,
) {
companion object {
val DEFAULT = Config()
@@ -47,10 +47,8 @@ fun OpenAI(
httpClientConfig: ((HttpClientConfig<*>) -> Unit)? = null,
logRequests: Boolean = false
): OpenAI {
val token =
config.token
?: getenv(KEY_ENV_VAR)
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
val token = config.token ?: getenv(KEY_ENV_VAR) ?: "<not-provided>"
val clientConfig: HttpClientConfig<*>.() -> Unit = {
install(ContentNegotiation) { json(config.json) }
install(HttpTimeout) {
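Note that the token lookup above no longer fails fast when $KEY_ENV_VAR is unset: a "<not-provided>" placeholder is used instead. Presumably this is what lets key-less OpenAI-compatible endpoints, such as the local Ollama server exercised by the new tests, work without a real API key.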
74 changes: 68 additions & 6 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/DefaultAI.kt
@@ -9,6 +9,8 @@ import com.xebia.functional.xef.llm.models.modelType
import com.xebia.functional.xef.llm.prompt
import com.xebia.functional.xef.llm.promptStreaming
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.PromptBuilder.Companion.user
import com.xebia.functional.xef.prompt.contentAsString
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf
Expand Down Expand Up @@ -50,7 +52,7 @@ data class DefaultAI<A : Any>(
val serializer = serializer()
return when (serializer.descriptor.kind) {
SerialKind.ENUM -> {
runWithEnumSingleTokenSerializer(serializer, prompt)
runWithEnumSerializer(serializer, prompt)
}
// else -> runWithSerializer(prompt, serializer)
PolymorphicKind.OPEN ->
@@ -82,11 +84,61 @@
}
}

@OptIn(ExperimentalSerializationApi::class)
suspend fun runWithEnumSingleTokenSerializer(serializer: KSerializer<A>, prompt: Prompt): A {
private suspend fun runWithEnumSerializer(serializer: KSerializer<A>, prompt: Prompt): A =
if (prompt.configuration.supportsLogitBias) {
runWithEnumSingleTokenSerializer(serializer, prompt)
} else {
runWithEnumRegexResponseSerializer(serializer, prompt)
}

private suspend fun runWithEnumRegexResponseSerializer(
serializer: KSerializer<A>,
prompt: Prompt
): A {
val cases = casesFromEnumSerializer(serializer)
val classificationMessage =
user(
"""
<instructions>
You are an AI, expected to classify the `context` into one of the `cases`:
<context>
${prompt.messages.joinToString("\n") { it.contentAsString() }}
<cases>
${cases.map { s -> "<case>$s</case>" }.joinToString("\n")}
</cases>
Select the `case` corresponding to the `context`.
IMPORTANT. Reply exclusively with the selected `case`.
</instructions>
"""
.trimIndent()
)
val result =
api.createChatCompletion(
CreateChatCompletionRequest(
messages = prompt.messages + classificationMessage,
model = model,
maxTokens = prompt.configuration.maxTokens,
temperature = 0.0
)
)
val casesRegexes = cases.map { ".*$it.*" }
val responseContent = result.choices[0].message.content ?: ""
val choice =
casesRegexes
.zip(cases)
.firstOrNull {
Regex(it.first, RegexOption.IGNORE_CASE).containsMatchIn(responseContent.trim())
}
?.second
return serializeWithEnumSerializer(choice, enumSerializer)
}

private suspend fun runWithEnumSingleTokenSerializer(
serializer: KSerializer<A>,
prompt: Prompt
): A {
val encoding = model.modelType(forFunctions = false).encoding
val cases =
serializer.descriptor.elementDescriptors.map { it.serialName.substringAfterLast(".") }
val cases = casesFromEnumSerializer(serializer)
val logitBias =
cases
.flatMap {
@@ -108,7 +160,17 @@
)
)
val choice = result.choices[0].message.content
val enumSerializer = enumSerializer
return serializeWithEnumSerializer(choice, enumSerializer)
}

@OptIn(ExperimentalSerializationApi::class)
private fun casesFromEnumSerializer(serializer: KSerializer<A>): List<String> =
serializer.descriptor.elementDescriptors.map { it.serialName.substringAfterLast(".") }

private fun serializeWithEnumSerializer(
choice: String?,
enumSerializer: ((case: String) -> A)?
): A {
return if (choice != null && enumSerializer != null) {
enumSerializer(choice)
} else {
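The regex fallback added above boils down to three steps: append a classification instruction to the conversation, let the model answer freely at temperature 0.0, then match the reply against the enum case names. A self-contained sketch of just the matching step (matchCase is a hypothetical helper, not part of this PR; the pattern construction mirrors runWithEnumRegexResponseSerializer):

enum class Sentiment {
  POSITIVE,
  NEGATIVE,
}

// Wrap each case name in ".*case.*" and return the first case whose pattern
// matches the model's reply, ignoring casing and any surrounding prose.
fun matchCase(response: String, cases: List<String>): String? =
  cases.firstOrNull { case ->
    Regex(".*$case.*", RegexOption.IGNORE_CASE).containsMatchIn(response.trim())
  }

fun main() {
  val cases = Sentiment.entries.map { it.name }
  println(matchCase("The case is: Positive.", cases)) // POSITIVE
  println(matchCase("I cannot classify this.", cases)) // null
}

When nothing matches, the real implementation passes null on to serializeWithEnumSerializer (its else branch is collapsed in the diff above).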
1 change: 1 addition & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/prompt/configuration/PromptConfiguration.kt
@@ -17,6 +17,7 @@ constructor(
var maxTokens: Int = 500,
var messagePolicy: MessagePolicy = MessagePolicy(),
var seed: Int? = null,
var supportsLogitBias: Boolean = true,
) {

fun messagePolicy(block: MessagePolicy.() -> Unit) = messagePolicy.apply { block() }
33 changes: 33 additions & 0 deletions core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/EnumClassificationTest.kt
@@ -0,0 +1,33 @@
package com.xebia.functional.xef.ollama.tests

import com.xebia.functional.xef.ollama.tests.models.OllamaModels
import com.xebia.functional.xef.ollama.tests.models.Sentiment
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test

class EnumClassificationTest : OllamaTests() {

@Test
fun `positive sentiment`() {
runBlocking {
val sentiment =
ollama<Sentiment>(
model = OllamaModels.Gemma2B,
prompt = "The context of the situation is very positive.",
)
assert(sentiment == Sentiment.POSITIVE) { "Expected POSITIVE but got $sentiment" }
}
}

@Test
fun `negative sentiment`() {
runBlocking {
val sentiment =
ollama<Sentiment>(
model = OllamaModels.LLama3_8B,
prompt = "The context of the situation is very negative.",
)
assert(sentiment == Sentiment.NEGATIVE) { "Expected NEGATIVE but got $sentiment" }
}
}
}
100 changes: 100 additions & 0 deletions core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/OllamaTests.kt
@@ -0,0 +1,100 @@
package com.xebia.functional.xef.ollama.tests

import com.github.dockerjava.api.model.Image
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.AI
import com.xebia.functional.xef.Config
import com.xebia.functional.xef.OpenAI
import io.github.oshai.kotlinlogging.KotlinLogging
import java.util.concurrent.ConcurrentHashMap
import org.junit.jupiter.api.AfterAll
import org.testcontainers.DockerClientFactory
import org.testcontainers.ollama.OllamaContainer
import org.testcontainers.utility.DockerImageName

abstract class OllamaTests {

val logger = KotlinLogging.logger {}

companion object {
private const val OLLAMA_IMAGE = "ollama/ollama:0.1.26"

private val registeredContainers: MutableMap<String, OllamaContainer> = ConcurrentHashMap()

@PublishedApi
internal fun useModel(model: String): OllamaContainer =
if (registeredContainers.containsKey(model)) {
registeredContainers[model]!!
} else {
ollamaContainer(model)
}

private fun ollamaContainer(model: String, imageName: String = model): OllamaContainer {
if (registeredContainers.containsKey(model)) {
return registeredContainers[model]!!
}
// build and commit a model image only if one is not already present as a local docker image
val listImagesCmd: List<Image> =
DockerClientFactory.lazyClient().listImagesCmd().withImageNameFilter(imageName).exec()

val ollama =
if (listImagesCmd.isEmpty()) {
// pull the model into a fresh container and commit it as a reusable image
println("🐳 Creating a new Ollama container with $model image...")
val ollama = OllamaContainer(OLLAMA_IMAGE)
ollama.start()
println("🐳 Pulling $model image...")
ollama.execInContainer("ollama", "pull", model)
println("🐳 Committing $model image...")
ollama.commitToImage(imageName)
ollama.withReuse(true)
} else {
println("🐳 Using existing Ollama container with $model image...")
// Substitute the default Ollama image with our model variant
val ollama =
OllamaContainer(
DockerImageName.parse(imageName).asCompatibleSubstituteFor("ollama/ollama")
)
.withReuse(true)
ollama.start()
ollama
}
println("🐳 Starting Ollama container with $model image...")
registeredContainers[model] = ollama
ollama.execInContainer("ollama", "run", model)
return ollama
}

@AfterAll
@JvmStatic
fun teardown() {
registeredContainers.forEach { (model, container) ->
println("🐳 Stopping Ollama container for model $model")
container.stop()
}
}
}

protected suspend inline fun <reified A> ollama(
model: String,
prompt: String,
): A {
useModel(model)
val config = Config(supportsLogitBias = false, baseUrl = ollamaBaseUrl(model))
val api = OpenAI(config = config, logRequests = true).chat
val result: A =
AI(
prompt = prompt,
config = config,
api = api,
model = CreateChatCompletionRequestModel.Custom(model),
)
logger.info { "🚀 Inference on model $model: $result" }
return result
}

fun ollamaBaseUrl(model: String): String {
val ollama = registeredContainers[model]!!
return "http://${ollama.host}:${ollama.getMappedPort(ollama.exposedPorts.first())}/v1/"
}
}
8 changes: 8 additions & 0 deletions core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/OllamaModels.kt
@@ -0,0 +1,8 @@
package com.xebia.functional.xef.ollama.tests.models

object OllamaModels {
const val Gemma2B = "gemma:2b"
const val Phi3Latest = "phi3:latest"
const val LLama3_8B = "llama3:8b"
const val Qwen0_5B = "qwen:0.5b"
}
9 changes: 9 additions & 0 deletions core/src/jvmTest/kotlin/com/xebia/functional/xef/ollama/tests/models/Sentiment.kt
@@ -0,0 +1,9 @@
package com.xebia.functional.xef.ollama.tests.models

import kotlinx.serialization.Serializable

@Serializable
enum class Sentiment {
POSITIVE,
NEGATIVE,
}
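Both classification paths derive their case names from the enum's serial descriptor. A minimal kotlinx.serialization sketch mirroring casesFromEnumSerializer from DefaultAI.kt (standalone and runnable; uses only the public kotlinx.serialization API):

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.elementDescriptors
import kotlinx.serialization.serializer

@Serializable
enum class Sentiment {
  POSITIVE,
  NEGATIVE,
}

@OptIn(ExperimentalSerializationApi::class)
fun main() {
  // Each element's serial name looks like "Sentiment.POSITIVE";
  // substringAfterLast(".") keeps just the case name.
  val cases =
    serializer<Sentiment>().descriptor.elementDescriptors.map {
      it.serialName.substringAfterLast(".")
    }
  println(cases) // [POSITIVE, NEGATIVE]
}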
26 changes: 26 additions & 0 deletions core/src/jvmTest/resources/logback.xml
@@ -0,0 +1,26 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE configuration>

<configuration>
<statusListener class="ch.qos.logback.core.status.NopStatusListener" />

<appender name="NOOP" class="ch.qos.logback.core.helpers.NOPAppender" />

<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} -%kvp- %msg%n</pattern>
</encoder>
</appender>

<root level="info">
<appender-ref ref="STDOUT"/>
</root>

<logger name="com.xebia.functional.xef" level="debug">
<appender-ref ref="STDOUT" />
</logger>

<logger name="com.gargoylesoftware.htmlunit" level="off">
<appender-ref ref="STDOUT" />
</logger>
</configuration>
@@ -0,0 +1,18 @@
package com.xebia.functional.xef.dsl.chat

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.AI
import com.xebia.functional.xef.Config
import com.xebia.functional.xef.OpenAI

suspend fun main() {
val config = Config(baseUrl = "http://localhost:11434/v1/", supportsLogitBias = false)
val sentiment =
AI<Sentiment>(
prompt = "I love Xef!",
model = CreateChatCompletionRequestModel.Custom("orca-mini:3b"),
config = config,
api = OpenAI(config, logRequests = true).chat,
)
println(sentiment) // positive
}
2 changes: 1 addition & 1 deletion examples/src/main/resources/logback.xml
@@ -12,7 +12,7 @@
</encoder>
</appender>

<root level="info">
<root level="trace">
<appender-ref ref="STDOUT"/>
</root>
