From 91f8de1d896bec01f488b52d6f297b4c6e3b4ed2 Mon Sep 17 00:00:00 2001 From: Goupil Date: Thu, 25 Apr 2024 11:49:50 +0200 Subject: [PATCH] Simplify tools parameter parsing and add support for passing parameters in model --- .../server/endpoints/google/endpointVertex.ts | 80 ++----------------- 1 file changed, 6 insertions(+), 74 deletions(-) diff --git a/src/lib/server/endpoints/google/endpointVertex.ts b/src/lib/server/endpoints/google/endpointVertex.ts index 60cf856f38d..6392b653e1b 100644 --- a/src/lib/server/endpoints/google/endpointVertex.ts +++ b/src/lib/server/endpoints/google/endpointVertex.ts @@ -4,82 +4,12 @@ import { HarmBlockThreshold, type Content, type TextPart, - FunctionDeclarationSchemaType, - type FunctionDeclarationSchema, - type FunctionDeclarationSchemaProperty, - type Tool, - type FunctionDeclarationsTool, - type FunctionDeclaration, - type GoogleSearchRetrievalTool, - type GoogleSearchRetrieval, - type RetrievalTool, - type Retrieval, - type VertexAISearch, } from "@google-cloud/vertexai"; import type { Endpoint } from "../endpoints"; import { z } from "zod"; import type { Message } from "$lib/types/Message"; import type { TextGenerationStreamOutput } from "@huggingface/inference"; -const vertexAISearchSchema: z.ZodType = z.object({ - datastore: z.string(), -}); - -const retrievalSchema: z.ZodType = z.object({ - vertexAiSearch: vertexAISearchSchema.optional(), - disableAttribution: z.boolean().optional(), -}); - -const retrievalToolSchema: z.ZodType = z.object({ - retrieval: retrievalSchema.optional(), -}); - -const googleSearchRetrievalSchema: z.ZodType = z.object({ - disableAttribution: z.boolean().optional(), -}); - -const googleSearchRetrievalToolSchema: z.ZodType = z.object({ - googleSearchRetrieval: googleSearchRetrievalSchema.optional(), -}); - -const functionDeclarationSchemaTypeSchema = z.nativeEnum(FunctionDeclarationSchemaType); - -const functionDeclarationSchemaPropertySchema: z.ZodType = - z.object({ - type: functionDeclarationSchemaTypeSchema.optional(), - format: z.string().optional(), - description: z.string().optional(), - nullable: z.boolean().optional(), - items: z.lazy(() => functionDeclarationSchemaSchema), - enum: z.array(z.string()).optional(), - properties: z.lazy(() => z.record(z.string(), functionDeclarationSchemaSchema)), - required: z.array(z.string()).optional(), - example: z.unknown().optional(), - }); - -const functionDeclarationSchemaSchema: z.ZodType = z.object({ - type: functionDeclarationSchemaTypeSchema, - properties: z.lazy(() => z.record(z.string(), functionDeclarationSchemaPropertySchema)), - description: z.string().optional(), - required: z.array(z.string()).optional(), -}); - -const functionDeclarationSchema: z.ZodType = z.object({ - name: z.string(), - description: z.string().optional(), - parameters: functionDeclarationSchemaSchema.optional(), -}); - -const functionDeclarationsToolSchema: z.ZodType = z.object({ - functionDeclarations: z.array(functionDeclarationSchema).optional(), -}); - -const toolSchema: z.ZodType = z.union([ - functionDeclarationsToolSchema, - retrievalToolSchema, - googleSearchRetrievalToolSchema, -]); - export const endpointVertexParametersSchema = z.object({ weight: z.number().int().positive().default(1), model: z.any(), // allow optional and validate against emptiness @@ -96,7 +26,7 @@ export const endpointVertexParametersSchema = z.object({ HarmBlockThreshold.BLOCK_ONLY_HIGH, ]) .optional(), - tools: toolSchema.array().optional(), + tools: z.array(z.any()), }); export function endpointVertex(input: z.input): Endpoint { @@ -110,6 +40,8 @@ export function endpointVertex(input: z.input { + const parameters = { ...model.parameters, ...generateSettings }; + const generativeModel = vertex_ai.getGenerativeModel({ model: model.id ?? model.name, safetySettings: safetyThreshold @@ -137,9 +69,9 @@ export function endpointVertex(input: z.input