Skip to content

Commit

Permalink
Allow @Schema on Tool requests and read description from annotati…
Browse files Browse the repository at this point in the history
…on when available. (#771)
  • Loading branch information
raulraja authored Aug 20, 2024
1 parent 8f36c04 commit 7295e6c
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.xebia.functional.xef.conversation

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialInfo

/** Schema for a tool request */
@OptIn(ExperimentalSerializationApi::class)
@SerialInfo
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.CLASS)
annotation class Schema(val value: String)
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import com.xebia.functional.openai.generated.api.Chat
import com.xebia.functional.openai.generated.model.*
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.AIEvent
import com.xebia.functional.xef.Config
import com.xebia.functional.xef.Tool
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.conversation.Schema
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.PromptBuilder.Companion.tool
Expand All @@ -21,17 +24,31 @@ import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.json.*

@OptIn(ExperimentalSerializationApi::class)
fun chatFunction(descriptor: SerialDescriptor): FunctionObject {
val fnName = descriptor.serialName.substringAfterLast(".")
return chatFunction(fnName, buildJsonSchema(descriptor))
val functionName = functionName(descriptor)
return FunctionObject(
name = functionName,
description = functionDescription(descriptor, functionName),
parameters = functionSchema(descriptor)
)
}

fun chatFunctions(descriptors: List<SerialDescriptor>): List<FunctionObject> =
descriptors.map(::chatFunction)
@OptIn(ExperimentalSerializationApi::class)
fun functionSchema(descriptor: SerialDescriptor): JsonObject =
descriptor.annotations.filterIsInstance<Schema>().firstOrNull()?.value?.let {
Config.DEFAULT.json.decodeFromString(JsonObject.serializer(), it)
} ?: buildJsonSchema(descriptor)

fun chatFunction(fnName: String, schema: JsonObject): FunctionObject =
FunctionObject(fnName, "Generated function for $fnName", schema)
@OptIn(ExperimentalSerializationApi::class)
fun functionDescription(descriptor: SerialDescriptor, fnName: String): String =
(descriptor.annotations.filterIsInstance<Description>().firstOrNull()?.value
?: defaultFunctionDescription(fnName))

fun defaultFunctionDescription(fnName: String): String = "Generated function for $fnName"

@OptIn(ExperimentalSerializationApi::class)
fun functionName(descriptor: SerialDescriptor): String =
descriptor.serialName.substringAfterLast(".")

data class UsageTracker(
var llmCalls: Int = 0,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.xebia.functional.xef.functions

import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.conversation.Schema
import com.xebia.functional.xef.llm.chatFunction
import com.xebia.functional.xef.llm.defaultFunctionDescription
import com.xebia.functional.xef.llm.functionName
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

class FunctionSchemaTests :
StringSpec({
"Request has default description" {
val descriptor = Request.serializer().descriptor
val function = chatFunction(descriptor)
val fnName = functionName(descriptor)
function.description shouldBe defaultFunctionDescription(fnName)
}

"Description can be set on request" {
val descriptor = RequestWithDescription.serializer().descriptor
val function = chatFunction(descriptor)
function.description shouldBe "Request With Description"
}

"Schema can be generated on request" {
val descriptor = Request.serializer().descriptor
val function = chatFunction(descriptor)
function.parameters shouldBe buildJsonSchema(descriptor)
}

"Schema can be set on request" {
val descriptor = RequestWithSchema.serializer().descriptor
val function = chatFunction(descriptor)
function.parameters shouldBe JsonObject(emptyMap())
}
}) {

@Serializable data class Request(val input: String)

@Serializable
@Description("Request With Description")
data class RequestWithDescription(val input: String)

@Serializable
@Description("Request with schema")
@Schema("""
{
}
""")
data class RequestWithSchema(val input: String)
}

0 comments on commit 7295e6c

Please sign in to comment.