Skip to content

Commit

Permalink
Simplify tools parameter parsing and add support for passing parameters in model
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurGoupil committed Apr 25, 2024
1 parent c16f551 commit 91f8de1
Showing 1 changed file with 6 additions and 74 deletions.
80 changes: 6 additions & 74 deletions src/lib/server/endpoints/google/endpointVertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,82 +4,12 @@ import {
HarmBlockThreshold,
type Content,
type TextPart,
FunctionDeclarationSchemaType,
type FunctionDeclarationSchema,
type FunctionDeclarationSchemaProperty,
type Tool,
type FunctionDeclarationsTool,
type FunctionDeclaration,
type GoogleSearchRetrievalTool,
type GoogleSearchRetrieval,
type RetrievalTool,
type Retrieval,
type VertexAISearch,
} from "@google-cloud/vertexai";
import type { Endpoint } from "../endpoints";
import { z } from "zod";
import type { Message } from "$lib/types/Message";
import type { TextGenerationStreamOutput } from "@huggingface/inference";

/** Zod validator for a `VertexAISearch` value: an object with a required string `datastore`. */
const vertexAISearchSchema: z.ZodType<VertexAISearch> = z.object({ datastore: z.string() });

/**
 * Zod validator for a `Retrieval` value.
 *
 * Both keys may be omitted: `vertexAiSearch` is checked by `vertexAISearchSchema`
 * and `disableAttribution` must be a boolean when present.
 */
const retrievalSchema: z.ZodType<Retrieval> = z.object({
	vertexAiSearch: z.optional(vertexAISearchSchema),
	disableAttribution: z.optional(z.boolean()),
});

/** Zod validator for a `RetrievalTool`: an object whose optional `retrieval` key is checked by `retrievalSchema`. */
const retrievalToolSchema: z.ZodType<RetrievalTool> = z.object({
	retrieval: z.optional(retrievalSchema),
});

/** Zod validator for a `GoogleSearchRetrieval` value: an object with an optional boolean `disableAttribution`. */
const googleSearchRetrievalSchema: z.ZodType<GoogleSearchRetrieval> = z.object({
	disableAttribution: z.optional(z.boolean()),
});

/**
 * Zod validator for a `GoogleSearchRetrievalTool`: an object whose optional
 * `googleSearchRetrieval` key is checked by `googleSearchRetrievalSchema`.
 */
const googleSearchRetrievalToolSchema: z.ZodType<GoogleSearchRetrievalTool> = z.object({
	googleSearchRetrieval: z.optional(googleSearchRetrievalSchema),
});

// Accepts only values of the imported `FunctionDeclarationSchemaType` enum.
const functionDeclarationSchemaTypeSchema = z.nativeEnum(FunctionDeclarationSchemaType);

/**
 * Zod validator for a single `FunctionDeclarationSchemaProperty`.
 *
 * Every key is optional, so a leaf property such as `{ type: ... }` validates
 * without having to supply nested `items` or `properties`.
 *
 * `items` and `properties` recurse into `functionDeclarationSchemaSchema`, which is
 * declared after this constant; `z.lazy` defers that reference until validation time.
 */
const functionDeclarationSchemaPropertySchema: z.ZodType<FunctionDeclarationSchemaProperty> =
	z.object({
		type: functionDeclarationSchemaTypeSchema.optional(),
		format: z.string().optional(),
		description: z.string().optional(),
		nullable: z.boolean().optional(),
		// Optional: previously required, which rejected every property lacking a nested schema.
		items: z.lazy(() => functionDeclarationSchemaSchema).optional(),
		enum: z.array(z.string()).optional(),
		// Optional for the same reason as `items`.
		properties: z.lazy(() => z.record(z.string(), functionDeclarationSchemaSchema)).optional(),
		required: z.array(z.string()).optional(),
		example: z.unknown().optional(),
	});

/**
 * Zod validator for a `FunctionDeclarationSchema`.
 *
 * `type` and `properties` are required; `description` and `required` may be omitted.
 * Each entry of `properties` is checked by `functionDeclarationSchemaPropertySchema`,
 * resolved lazily at validation time.
 */
const functionDeclarationSchemaSchema: z.ZodType<FunctionDeclarationSchema> = z.object({
	type: functionDeclarationSchemaTypeSchema,
	properties: z.lazy(() => z.record(z.string(), functionDeclarationSchemaPropertySchema)),
	description: z.optional(z.string()),
	required: z.optional(z.string().array()),
});

/**
 * Zod validator for a `FunctionDeclaration`: a required string `name`, plus an optional
 * string `description` and optional `parameters` checked by `functionDeclarationSchemaSchema`.
 */
const functionDeclarationSchema: z.ZodType<FunctionDeclaration> = z.object({
	name: z.string(),
	description: z.optional(z.string()),
	parameters: z.optional(functionDeclarationSchemaSchema),
});

/**
 * Zod validator for a `FunctionDeclarationsTool`: an object whose optional
 * `functionDeclarations` key is an array of entries checked by `functionDeclarationSchema`.
 */
const functionDeclarationsToolSchema: z.ZodType<FunctionDeclarationsTool> = z.object({
	functionDeclarations: z.optional(functionDeclarationSchema.array()),
});

// Zod validator for a `Tool`: a value is accepted when it matches any one of the
// three tool schemas listed in the union below.
const toolSchema: z.ZodType<Tool> = z.union([
functionDeclarationsToolSchema,
retrievalToolSchema,
googleSearchRetrievalToolSchema,
]);

export const endpointVertexParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(), // allow optional and validate against emptiness
Expand All @@ -96,7 +26,7 @@ export const endpointVertexParametersSchema = z.object({
HarmBlockThreshold.BLOCK_ONLY_HIGH,
])
.optional(),
tools: toolSchema.array().optional(),
tools: z.array(z.any()),
});

export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
Expand All @@ -110,6 +40,8 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
});

return async ({ messages, preprompt, generateSettings }) => {
const parameters = { ...model.parameters, ...generateSettings };

const generativeModel = vertex_ai.getGenerativeModel({
model: model.id ?? model.name,
safetySettings: safetyThreshold
Expand Down Expand Up @@ -137,9 +69,9 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
]
: undefined,
generationConfig: {
maxOutputTokens: generateSettings?.max_new_tokens ?? 4096,
stopSequences: generateSettings?.stop,
temperature: generateSettings?.temperature ?? 1,
maxOutputTokens: parameters?.max_new_tokens ?? 4096,
stopSequences: parameters?.stop,
temperature: parameters?.temperature ?? 1,
},
tools,
});
Expand Down

0 comments on commit 91f8de1

Please sign in to comment.