Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anthropic Tool Support #1594

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0c60707
support anthropic PDF beta
evalstate Nov 18, 2024
a0e1b5f
Merge remote-tracking branch 'upstream/main' into feature/anthropic-p…
evalstate Nov 18, 2024
84aac49
upstream merge, remove commented out console log line
evalstate Nov 18, 2024
67d7295
Merge branch 'main' into feature/anthropic-pdf-beta
evalstate Nov 18, 2024
92bf923
Fixing type errors.
evalstate Nov 18, 2024
16637d2
Merge branch 'main' into feature/anthropic-pdf-beta
nsarrazin Nov 18, 2024
3e34b63
changed document processor to async (matching image processor)
evalstate Nov 18, 2024
4c67d1c
Merge remote-tracking branch 'upstream/main' into feature/anthropic-p…
evalstate Nov 22, 2024
36a1cc3
use the beta api types rather than custom extension
evalstate Nov 22, 2024
3cbc5de
Merge remote-tracking branch 'upstream/main' into feature/anthropic-p…
evalstate Nov 25, 2024
786b576
rudimentary tool testing
evalstate Nov 25, 2024
e66ba8d
Merge branch 'main' of https://github.com/huggingface/chat-ui into fe…
evalstate Nov 25, 2024
da22402
interim commit (tool re-passing, file handling)
evalstate Nov 26, 2024
b69d18e
Merge branch 'feature/anthropic-pdf-beta' into feature/anthropic-tool…
evalstate Nov 26, 2024
506ecff
remove merge error
evalstate Nov 26, 2024
3c3d282
Merge branch 'feature/anthropic-pdf-beta' of https://github.com/barre…
evalstate Nov 26, 2024
07764f5
Merge branch 'feature/anthropic-pdf-beta' into feature/anthropic-tool…
evalstate Nov 26, 2024
b233de7
tidy up, isolate beta classes to utils
evalstate Nov 26, 2024
ee6107a
anthropic tool calling support.
evalstate Nov 26, 2024
a8bac58
improve handling of directlyAnswer tool
evalstate Nov 26, 2024
87b57f3
fix streaming
evalstate Nov 26, 2024
0c9abdf
slight tidy up to tools flow handling
evalstate Nov 27, 2024
d3ceb35
Merge remote-tracking branch 'upstream/main' into feature/anthropic-t…
evalstate Nov 27, 2024
06a763a
Merge branch 'main' into feature/anthropic-tool-support
nsarrazin Jan 3, 2025
e0adc68
fix: dont pass tools in final generation, instead deduce tools from t…
nsarrazin Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 128 additions & 18 deletions src/lib/server/endpoints/anthropic/endpointAnthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,19 @@ import type { Endpoint } from "../endpoints";
import { env } from "$env/dynamic/private";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator } from "../images";
import { endpointMessagesToAnthropicMessages } from "./utils";
import { endpointMessagesToAnthropicMessages, addToolResults } from "./utils";
import { createDocumentProcessorOptionsValidator } from "../document";
import type {
Tool,
ToolCall,
ToolInput,
ToolInputFile,
ToolInputFixed,
ToolInputOptional,
} from "$lib/types/Tool";
import type Anthropic from "@anthropic-ai/sdk";
import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";
import directlyAnswer from "$lib/server/tools/directlyAnswer";

export const endpointAnthropicParametersSchema = z.object({
weight: z.number().int().positive().default(1),
Expand Down Expand Up @@ -52,23 +62,41 @@ export async function endpointAnthropic(
defaultQuery,
});

return async ({ messages, preprompt, generateSettings }) => {
return async ({
messages,
preprompt,
generateSettings,
conversationId,
tools = [],
toolResults = [],
}) => {
let system = preprompt;
if (messages?.[0]?.from === "system") {
system = messages[0].content;
}

let tokenId = 0;
if (tools.length === 0 && toolResults.length > 0) {
const toolNames = new Set(toolResults.map((tool) => tool.call.name));
tools = Array.from(toolNames).map((name) => ({
name,
description: "",
inputs: [],
})) as unknown as Tool[];
}

const parameters = { ...model.parameters, ...generateSettings };

return (async function* () {
const stream = anthropic.messages.stream({
model: model.id ?? model.name,
messages: (await endpointMessagesToAnthropicMessages(
messages,
multimodal
)) as MessageParam[],
tools: createAnthropicTools(tools),
tool_choice:
tools.length > 0 ? { type: "auto", disable_parallel_tool_use: false } : undefined,
messages: addToolResults(
await endpointMessagesToAnthropicMessages(messages, multimodal, conversationId),
toolResults
) as MessageParam[],
max_tokens: parameters?.max_new_tokens,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
Expand All @@ -79,21 +107,40 @@ export async function endpointAnthropic(
while (true) {
const result = await Promise.race([stream.emitted("text"), stream.emitted("end")]);

// Stream end
if (result === undefined) {
yield {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: await stream.finalText(),
details: null,
} satisfies TextGenerationStreamOutput;
if ("tool_use" === stream.receivedMessages[0].stop_reason) {
// this should really create a new "Assistant" message with the tool id in it.
const toolCalls: ToolCall[] = stream.receivedMessages[0].content
.filter(
(block): block is Anthropic.Messages.ContentBlock & { type: "tool_use" } =>
block.type === "tool_use"
)
.map((block) => ({
name: block.name,
parameters: block.input as Record<string, string | number | boolean>,
id: block.id,
}));

yield {
token: { id: tokenId, text: "", logprob: 0, special: false, toolCalls },
generated_text: null,
details: null,
};
} else {
yield {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: await stream.finalText(),
details: null,
} satisfies TextGenerationStreamOutput;
}

return;
}

// Text delta
yield {
token: {
Expand All @@ -109,3 +156,66 @@ export async function endpointAnthropic(
})();
};
}

function createAnthropicTools(tools: Tool[]): Anthropic.Messages.Tool[] {
return tools
.filter((tool) => tool.name !== directlyAnswer.name)
.map((tool) => {
const properties = tool.inputs.reduce((acc, input) => {
acc[input.name] = convertToolInputToJSONSchema(input);
return acc;
}, {} as Record<string, unknown>);

const required = tool.inputs
.filter((input) => input.paramType === "required")
.map((input) => input.name);

return {
name: tool.name,
description: tool.description,
input_schema: {
type: "object",
properties,
required: required.length > 0 ? required : undefined,
},
};
});
}

function convertToolInputToJSONSchema(input: ToolInput): Record<string, unknown> {
const baseSchema: Record<string, unknown> = {};
if ("description" in input) {
baseSchema["description"] = input.description || "";
}
switch (input.paramType) {
case "optional":
baseSchema["default"] = (input as ToolInputOptional).default;
break;
case "fixed":
baseSchema["const"] = (input as ToolInputFixed).value;
break;
}

if (input.type === "file") {
baseSchema["type"] = "string";
baseSchema["format"] = "binary";
baseSchema["mimeTypes"] = (input as ToolInputFile).mimeTypes;
} else {
switch (input.type) {
case "str":
baseSchema["type"] = "string";
break;
case "int":
baseSchema["type"] = "integer";
break;
case "float":
baseSchema["type"] = "number";
break;
case "bool":
baseSchema["type"] = "boolean";
break;
}
}

return baseSchema;
}
68 changes: 56 additions & 12 deletions src/lib/server/endpoints/anthropic/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ import type {
BetaMessageParam,
BetaBase64PDFBlock,
} from "@anthropic-ai/sdk/resources/beta/messages/messages.mjs";
import type { ToolResult } from "$lib/types/Tool";
import { downloadFile } from "$lib/server/files/downloadFile";
import type { ObjectId } from "mongodb";

export async function fileToImageBlock(
file: MessageFile,
opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
): Promise<BetaImageBlockParam> {
const processor = makeImageProcessor(opts);

const { image, mime } = await processor(file);

return {
Expand Down Expand Up @@ -48,7 +52,8 @@ export async function endpointMessagesToAnthropicMessages(
multimodal: {
image: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">;
document?: FileProcessorOptions<"application/pdf">;
}
},
conversationId?: ObjectId | undefined
): Promise<BetaMessageParam[]> {
return await Promise.all(
messages
Expand All @@ -57,20 +62,59 @@ export async function endpointMessagesToAnthropicMessages(
return {
role: message.from,
content: [
...(await Promise.all(
(message.files ?? []).map(async (file) => {
if (file.mime.startsWith("image/")) {
return fileToImageBlock(file, multimodal.image);
} else if (file.mime === "application/pdf" && multimodal.document) {
return fileToDocumentBlock(file, multimodal.document);
} else {
throw new Error(`Unsupported file type: ${file.mime}`);
}
})
)),
...(message.from === "user"
? await Promise.all(
(message.files ?? []).map(async (file) => {
if (file.type === "hash" && conversationId) {
file = await downloadFile(file.value, conversationId);
}

if (file.mime.startsWith("image/")) {
return fileToImageBlock(file, multimodal.image);
} else if (file.mime === "application/pdf" && multimodal.document) {
return fileToDocumentBlock(file, multimodal.document);
} else {
throw new Error(`Unsupported file type: ${file.mime}`);
}
})
)
: []),
{ type: "text", text: message.content },
],
};
})
);
}

export function addToolResults(
messages: BetaMessageParam[],
toolResults: ToolResult[]
): BetaMessageParam[] {
const id = crypto.randomUUID();
if (toolResults.length === 0) {
return messages;
}
return [
...messages,
{
role: "assistant",
content: toolResults.map((result, index) => ({
type: "tool_use",
id: `tool_${index}_${id}`,
name: result.call.name,
input: result.call.parameters,
})),
},
{
role: "user",
content: toolResults.map((result, index) => ({
type: "tool_result",
tool_use_id: `tool_${index}_${id}`,
is_error: result.status === "error",
content: JSON.stringify(
result.status === "error" ? result.message : "outputs" in result ? result.outputs : ""
),
})),
},
];
}
6 changes: 4 additions & 2 deletions src/lib/server/textGeneration/generate.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { ToolResult } from "$lib/types/Tool";
import type { ToolResult, Tool } from "$lib/types/Tool";
import {
MessageReasoningUpdateType,
MessageUpdateType,
Expand All @@ -16,7 +16,8 @@ type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: End
export async function* generate(
{ model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext,
toolResults: ToolResult[],
preprompt?: string
preprompt?: string,
tools?: Tool[]
): AsyncIterable<MessageUpdate> {
// reasoning mode is false by default
let reasoning = false;
Expand All @@ -43,6 +44,7 @@ export async function* generate(
preprompt,
continueMessage: isContinue,
generateSettings: assistant?.generateSettings,
tools,
toolResults,
isMultimodal: model.multimodal,
conversationId: conv._id,
Expand Down
11 changes: 7 additions & 4 deletions src/lib/server/textGeneration/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
import type { TextGenerationContext } from "./types";
import type { ToolResult } from "$lib/types/Tool";
import { toolHasName } from "../tools/utils";
import directlyAnswer from "../tools/directlyAnswer";

async function* keepAlive(done: AbortSignal): AsyncGenerator<MessageUpdate, undefined, undefined> {
while (!done.aborted) {
Expand Down Expand Up @@ -73,11 +74,13 @@ async function* textGenerationWithoutTitle(
}

let toolResults: ToolResult[] = [];
let tools = model.tools ? await getTools(toolsPreference, ctx.assistant) : undefined;

if (model.tools) {
const tools = await getTools(toolsPreference, ctx.assistant);
const toolCallsRequired = tools.some((tool) => !toolHasName("directly_answer", tool));
if (toolCallsRequired) toolResults = yield* runTools(ctx, tools, preprompt);
if (tools) {
const toolCallsRequired = tools.some((tool) => !toolHasName(directlyAnswer.name, tool));
if (toolCallsRequired) {
toolResults = yield* runTools(ctx, tools, preprompt);
} else tools = undefined;
}

const processedMessages = await preprocessMessages(messages, webSearchResult, convId);
Expand Down
2 changes: 1 addition & 1 deletion src/lib/server/textGeneration/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ export async function* runTools(
}

// if we dont see a tool call in the first 25 chars, something is going wrong and we abort
if (rawText.length > 25 && !(rawText.includes("```json") || rawText.includes("{"))) {
if (rawText.length > 100 && !(rawText.includes("```json") || rawText.includes("{"))) {
return [];
}

Expand Down
Loading