diff --git a/src/lib/server/endpoints/anthropic/endpointAnthropic.ts b/src/lib/server/endpoints/anthropic/endpointAnthropic.ts index 50b422f9500..fc085baecd0 100644 --- a/src/lib/server/endpoints/anthropic/endpointAnthropic.ts +++ b/src/lib/server/endpoints/anthropic/endpointAnthropic.ts @@ -3,9 +3,19 @@ import type { Endpoint } from "../endpoints"; import { env } from "$env/dynamic/private"; import type { TextGenerationStreamOutput } from "@huggingface/inference"; import { createImageProcessorOptionsValidator } from "../images"; -import { endpointMessagesToAnthropicMessages } from "./utils"; +import { endpointMessagesToAnthropicMessages, addToolResults } from "./utils"; import { createDocumentProcessorOptionsValidator } from "../document"; +import type { + Tool, + ToolCall, + ToolInput, + ToolInputFile, + ToolInputFixed, + ToolInputOptional, +} from "$lib/types/Tool"; +import type Anthropic from "@anthropic-ai/sdk"; import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs"; +import directlyAnswer from "$lib/server/tools/directlyAnswer"; export const endpointAnthropicParametersSchema = z.object({ weight: z.number().int().positive().default(1), @@ -52,23 +62,41 @@ export async function endpointAnthropic( defaultQuery, }); - return async ({ messages, preprompt, generateSettings }) => { + return async ({ + messages, + preprompt, + generateSettings, + conversationId, + tools = [], + toolResults = [], + }) => { let system = preprompt; if (messages?.[0]?.from === "system") { system = messages[0].content; } let tokenId = 0; + if (tools.length === 0 && toolResults.length > 0) { + const toolNames = new Set(toolResults.map((tool) => tool.call.name)); + tools = Array.from(toolNames).map((name) => ({ + name, + description: "", + inputs: [], + })) as unknown as Tool[]; + } const parameters = { ...model.parameters, ...generateSettings }; return (async function* () { const stream = anthropic.messages.stream({ model: model.id ?? 
model.name, - messages: (await endpointMessagesToAnthropicMessages( - messages, - multimodal - )) as MessageParam[], + tools: createAnthropicTools(tools), + tool_choice: + tools.length > 0 ? { type: "auto", disable_parallel_tool_use: false } : undefined, + messages: addToolResults( + await endpointMessagesToAnthropicMessages(messages, multimodal, conversationId), + toolResults + ) as MessageParam[], max_tokens: parameters?.max_new_tokens, temperature: parameters?.temperature, top_p: parameters?.top_p, @@ -79,21 +107,40 @@ export async function endpointAnthropic( while (true) { const result = await Promise.race([stream.emitted("text"), stream.emitted("end")]); - // Stream end if (result === undefined) { - yield { - token: { - id: tokenId++, - text: "", - logprob: 0, - special: true, - }, - generated_text: await stream.finalText(), - details: null, - } satisfies TextGenerationStreamOutput; + if ("tool_use" === stream.receivedMessages[0].stop_reason) { + // this should really create a new "Assistant" message with the tool id in it. 
+ const toolCalls: ToolCall[] = stream.receivedMessages[0].content + .filter( + (block): block is Anthropic.Messages.ContentBlock & { type: "tool_use" } => + block.type === "tool_use" + ) + .map((block) => ({ + name: block.name, + parameters: block.input as Record<string, string | number | boolean>, + id: block.id, + })); + + yield { + token: { id: tokenId, text: "", logprob: 0, special: false, toolCalls }, + generated_text: null, + details: null, + }; + } else { + yield { + token: { + id: tokenId++, + text: "", + logprob: 0, + special: true, + }, + generated_text: await stream.finalText(), + details: null, + } satisfies TextGenerationStreamOutput; + } + return; } - // Text delta yield { token: { @@ -109,3 +156,66 @@ export async function endpointAnthropic( })(); }; } + +function createAnthropicTools(tools: Tool[]): Anthropic.Messages.Tool[] { + return tools + .filter((tool) => tool.name !== directlyAnswer.name) + .map((tool) => { + const properties = tool.inputs.reduce((acc, input) => { + acc[input.name] = convertToolInputToJSONSchema(input); + return acc; + }, {} as Record<string, unknown>); + + const required = tool.inputs + .filter((input) => input.paramType === "required") + .map((input) => input.name); + + return { + name: tool.name, + description: tool.description, + input_schema: { + type: "object", + properties, + required: required.length > 0 ? 
required : undefined, + }, + }; + }); +} + +function convertToolInputToJSONSchema(input: ToolInput): Record<string, unknown> { + const baseSchema: Record<string, unknown> = {}; + if ("description" in input) { + baseSchema["description"] = input.description || ""; + } + switch (input.paramType) { + case "optional": + baseSchema["default"] = (input as ToolInputOptional).default; + break; + case "fixed": + baseSchema["const"] = (input as ToolInputFixed).value; + break; + } + + if (input.type === "file") { + baseSchema["type"] = "string"; + baseSchema["format"] = "binary"; + baseSchema["mimeTypes"] = (input as ToolInputFile).mimeTypes; + } else { + switch (input.type) { + case "str": + baseSchema["type"] = "string"; + break; + case "int": + baseSchema["type"] = "integer"; + break; + case "float": + baseSchema["type"] = "number"; + break; + case "bool": + baseSchema["type"] = "boolean"; + break; + } + } + + return baseSchema; +} diff --git a/src/lib/server/endpoints/anthropic/utils.ts b/src/lib/server/endpoints/anthropic/utils.ts index 0239e426300..c935ea43d95 100644 --- a/src/lib/server/endpoints/anthropic/utils.ts +++ b/src/lib/server/endpoints/anthropic/utils.ts @@ -7,12 +7,16 @@ import type { BetaMessageParam, BetaBase64PDFBlock, } from "@anthropic-ai/sdk/resources/beta/messages/messages.mjs"; +import type { ToolResult } from "$lib/types/Tool"; +import { downloadFile } from "$lib/server/files/downloadFile"; +import type { ObjectId } from "mongodb"; export async function fileToImageBlock( file: MessageFile, opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp"> ): Promise<BetaImageBlockParam> { const processor = makeImageProcessor(opts); + const { image, mime } = await processor(file); return { @@ -48,7 +52,8 @@ export async function endpointMessagesToAnthropicMessages( multimodal: { image: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">; document?: FileProcessorOptions<"application/pdf">; - } + }, + conversationId?: ObjectId | undefined ): Promise<BetaMessageParam[]> { return await Promise.all( messages @@ 
-57,20 +62,59 @@ export async function endpointMessagesToAnthropicMessages( return { role: message.from, content: [ - ...(await Promise.all( - (message.files ?? []).map(async (file) => { - if (file.mime.startsWith("image/")) { - return fileToImageBlock(file, multimodal.image); - } else if (file.mime === "application/pdf" && multimodal.document) { - return fileToDocumentBlock(file, multimodal.document); - } else { - throw new Error(`Unsupported file type: ${file.mime}`); - } - }) - )), + ...(message.from === "user" + ? await Promise.all( + (message.files ?? []).map(async (file) => { + if (file.type === "hash" && conversationId) { + file = await downloadFile(file.value, conversationId); + } + + if (file.mime.startsWith("image/")) { + return fileToImageBlock(file, multimodal.image); + } else if (file.mime === "application/pdf" && multimodal.document) { + return fileToDocumentBlock(file, multimodal.document); + } else { + throw new Error(`Unsupported file type: ${file.mime}`); + } + }) + ) + : []), { type: "text", text: message.content }, ], }; }) ); } + +export function addToolResults( + messages: BetaMessageParam[], + toolResults: ToolResult[] +): BetaMessageParam[] { + const id = crypto.randomUUID(); + if (toolResults.length === 0) { + return messages; + } + return [ + ...messages, + { + role: "assistant", + content: toolResults.map((result, index) => ({ + type: "tool_use", + id: `tool_${index}_${id}`, + name: result.call.name, + input: result.call.parameters, + })), + }, + { + role: "user", + content: toolResults.map((result, index) => ({ + type: "tool_result", + tool_use_id: `tool_${index}_${id}`, + is_error: result.status === "error", + content: JSON.stringify( + result.status === "error" ? result.message : "outputs" in result ? 
result.outputs : "" + ), + })), + }, + ]; +} diff --git a/src/lib/server/textGeneration/generate.ts b/src/lib/server/textGeneration/generate.ts index 36a5d6deac8..560e97bcaf9 100644 --- a/src/lib/server/textGeneration/generate.ts +++ b/src/lib/server/textGeneration/generate.ts @@ -1,4 +1,4 @@ -import type { ToolResult } from "$lib/types/Tool"; +import type { ToolResult, Tool } from "$lib/types/Tool"; import { MessageReasoningUpdateType, MessageUpdateType, @@ -16,7 +16,8 @@ type GenerateContext = Omit & { messages: End export async function* generate( { model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext, toolResults: ToolResult[], - preprompt?: string + preprompt?: string, + tools?: Tool[] ): AsyncIterable { // reasoning mode is false by default let reasoning = false; @@ -43,6 +44,7 @@ export async function* generate( preprompt, continueMessage: isContinue, generateSettings: assistant?.generateSettings, + tools, toolResults, isMultimodal: model.multimodal, conversationId: conv._id, diff --git a/src/lib/server/textGeneration/index.ts b/src/lib/server/textGeneration/index.ts index bef84a283f3..0142acfbb52 100644 --- a/src/lib/server/textGeneration/index.ts +++ b/src/lib/server/textGeneration/index.ts @@ -20,6 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators"; import type { TextGenerationContext } from "./types"; import type { ToolResult } from "$lib/types/Tool"; import { toolHasName } from "../tools/utils"; +import directlyAnswer from "../tools/directlyAnswer"; async function* keepAlive(done: AbortSignal): AsyncGenerator { while (!done.aborted) { @@ -73,11 +74,13 @@ async function* textGenerationWithoutTitle( } let toolResults: ToolResult[] = []; + let tools = model.tools ? 
await getTools(toolsPreference, ctx.assistant) : undefined; - if (model.tools) { - const tools = await getTools(toolsPreference, ctx.assistant); - const toolCallsRequired = tools.some((tool) => !toolHasName("directly_answer", tool)); - if (toolCallsRequired) toolResults = yield* runTools(ctx, tools, preprompt); + if (tools) { + const toolCallsRequired = tools.some((tool) => !toolHasName(directlyAnswer.name, tool)); + if (toolCallsRequired) { + toolResults = yield* runTools(ctx, tools, preprompt); + } else tools = undefined; } const processedMessages = await preprocessMessages(messages, webSearchResult, convId); diff --git a/src/lib/server/textGeneration/tools.ts b/src/lib/server/textGeneration/tools.ts index 2046bc89dfb..bc36b082360 100644 --- a/src/lib/server/textGeneration/tools.ts +++ b/src/lib/server/textGeneration/tools.ts @@ -213,7 +213,7 @@ export async function* runTools( } // if we dont see a tool call in the first 25 chars, something is going wrong and we abort - if (rawText.length > 25 && !(rawText.includes("```json") || rawText.includes("{"))) { + if (rawText.length > 100 && !(rawText.includes("```json") || rawText.includes("{"))) { return []; }