From 1038d746312ad52b77676a011e807888e56a7044 Mon Sep 17 00:00:00 2001 From: goupilew Date: Mon, 9 Sep 2024 14:58:07 +0200 Subject: [PATCH] feat: add support for multimodal in Vertex (#1338) * feat: add support for multimodal in Vertex * Nit changes and remove tools if multimodal * revert model name change * Fix tools/multimodal condition * chores(lint): fix formatting --------- Co-authored-by: Thomas Co-authored-by: Nathan Sarrazin --- README.md | 14 ++++- .../server/endpoints/google/endpointVertex.ts | 60 +++++++++++++++---- 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 2d70651a098..fed7aa91304 100644 --- a/README.md +++ b/README.md @@ -775,21 +775,29 @@ MODELS=`[ { "name": "gemini-1.5-pro", "displayName": "Vertex Gemini Pro 1.5", + "multimodal": true, "endpoints" : [{ "type": "vertex", "project": "abc-xyz", "location": "europe-west3", "model": "gemini-1.5-pro-preview-0409", // model-name - // Optional "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE", "apiEndpoint": "", // alternative api endpoint url, - // Optional "tools": [{ "googleSearchRetrieval": { "disableAttribution": true } - }] + }], + "multimodal": { + "image": { + "supportedMimeTypes": ["image/png", "image/jpeg", "image/webp"], + "preferredMimeType": "image/png", + "maxSizeInMB": 5, + "maxWidth": 2000, + "maxHeight": 1000 + } + } }] }, ]` diff --git a/src/lib/server/endpoints/google/endpointVertex.ts b/src/lib/server/endpoints/google/endpointVertex.ts index 4ffee0ec232..ed70f1451b5 100644 --- a/src/lib/server/endpoints/google/endpointVertex.ts +++ b/src/lib/server/endpoints/google/endpointVertex.ts @@ -9,6 +9,7 @@ import type { Endpoint } from "../endpoints"; import { z } from "zod"; import type { Message } from "$lib/types/Message"; import type { TextGenerationStreamOutput } from "@huggingface/inference"; +import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images"; export const endpointVertexParametersSchema = z.object({ weight: 
z.number().int().positive().default(1), @@ -27,10 +28,28 @@ export const endpointVertexParametersSchema = z.object({ ]) .optional(), tools: z.array(z.any()).optional(), + multimodal: z + .object({ + image: createImageProcessorOptionsValidator({ + supportedMimeTypes: [ + "image/png", + "image/jpeg", + "image/webp", + "image/avif", + "image/tiff", + "image/gif", + ], + preferredMimeType: "image/webp", + maxSizeInMB: Infinity, + maxWidth: 4096, + maxHeight: 4096, + }), + }) + .default({}), }); export function endpointVertex(input: z.input): Endpoint { - const { project, location, model, apiEndpoint, safetyThreshold, tools } = + const { project, location, model, apiEndpoint, safetyThreshold, tools, multimodal } = endpointVertexParametersSchema.parse(input); const vertex_ai = new VertexAI({ @@ -42,6 +61,8 @@ export function endpointVertex(input: z.input { const parameters = { ...model.parameters, ...generateSettings }; + const hasFiles = messages.some((message) => message.files && message.files.length > 0); + const generativeModel = vertex_ai.getGenerativeModel({ model: model.id ?? model.name, safetySettings: safetyThreshold @@ -73,7 +94,8 @@ export function endpointVertex(input: z.input): Content => { - return { - role: from === "user" ? "user" : "model", - parts: [ - { - text: content, - }, - ], - }; - }); + const vertexMessages = await Promise.all( + messages.map(async ({ from, content, files }: Omit): Promise => { + const imageProcessor = makeImageProcessor(multimodal.image); + const processedFiles = + files && files.length > 0 + ? await Promise.all(files.map(async (file) => imageProcessor(file))) + : []; + + return { + role: from === "user" ? "user" : "model", + parts: [ + ...processedFiles.map((processedFile) => ({ + inlineData: { + data: processedFile.image.toString("base64"), + mimeType: processedFile.mime, + }, + })), + { + text: content, + }, + ], + }; + }) + ); const result = await generativeModel.generateContentStream({ contents: vertexMessages,