diff --git a/.env b/.env index abb84264fdf..6a607b5a0ad 100644 --- a/.env +++ b/.env @@ -16,6 +16,7 @@ ANTHROPIC_API_KEY=#your anthropic api key here CLOUDFLARE_ACCOUNT_ID=#your cloudflare account id here CLOUDFLARE_API_TOKEN=#your cloudflare api token here COHERE_API_TOKEN=#your cohere api token here +GOOGLE_GENAI_API_KEY=#your google genai api token here HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead diff --git a/docs/source/configuration/models/providers/google.md b/docs/source/configuration/models/providers/google.md index 008baf0cdb4..f43c6335cd6 100644 --- a/docs/source/configuration/models/providers/google.md +++ b/docs/source/configuration/models/providers/google.md @@ -52,7 +52,11 @@ MODELS=`[ Or use the Gemini API API provider [from](https://github.com/google-gemini/generative-ai-js#readme): -> Make sure that you have an API key from Google Cloud Platform. To get an API key, follow the instructions [here](https://cloud.google.com/docs/authentication/api-keys). +Make sure that you have an API key from Google Cloud Platform. To get an API key, follow the instructions [here](https://ai.google.dev/gemini-api/docs/api-key). + +You can either specify it in your `.env.local` using the `GOOGLE_GENAI_API_KEY` variable, or set it directly in the endpoint config. + +You can find the list of models available [here](https://ai.google.dev/gemini-api/docs/models/gemini), and experimental models available [here](https://ai.google.dev/gemini-api/docs/models/experimental-models). 
```ini MODELS=`[ @@ -63,12 +67,12 @@ MODELS=`[ "endpoints": [ { "type": "genai", + + // Optional "apiKey": "abc...xyz" + "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE", } ] - - // Optional - "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE", }, { "name": "gemini-1.5-pro", @@ -77,6 +81,8 @@ MODELS=`[ "endpoints": [ { "type": "genai", + + // Optional "apiKey": "abc...xyz" } ] diff --git a/src/lib/server/endpoints/google/endpointGenAI.ts b/src/lib/server/endpoints/google/endpointGenAI.ts index 3a6de6fa675..14667480045 100644 --- a/src/lib/server/endpoints/google/endpointGenAI.ts +++ b/src/lib/server/endpoints/google/endpointGenAI.ts @@ -1,17 +1,18 @@ import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "@google/generative-ai"; -import type { Content, Part, TextPart } from "@google/generative-ai"; +import type { Content, Part, SafetySetting, TextPart } from "@google/generative-ai"; import { z } from "zod"; import type { Message, MessageFile } from "$lib/types/Message"; import type { TextGenerationStreamOutput } from "@huggingface/inference"; import type { Endpoint } from "../endpoints"; import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images"; import type { ImageProcessorOptions } from "../images"; +import { env } from "$env/dynamic/private"; export const endpointGenAIParametersSchema = z.object({ weight: z.number().int().positive().default(1), model: z.any(), type: z.literal("genai"), - apiKey: z.string(), + apiKey: z.string().default(env.GOOGLE_GENAI_API_KEY), safetyThreshold: z .enum([ HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, @@ -40,35 +41,24 @@ export function endpointGenAI(input: z.input cat !== HarmCategory.HARM_CATEGORY_UNSPECIFIED) + .reduce((acc, val) => { + acc.push({ + category: val as HarmCategory, + threshold: safetyThreshold, + }); + return acc; + }, [] as SafetySetting[]) + : undefined; + return async ({ messages, preprompt, generateSettings }) => { const parameters = { ...model.parameters, 
...generateSettings }; const generativeModel = genAI.getGenerativeModel({ model: model.id ?? model.name, - safetySettings: safetyThreshold - ? [ - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold: safetyThreshold, - }, - { - category: HarmCategory.HARM_CATEGORY_HARASSMENT, - threshold: safetyThreshold, - }, - { - category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - threshold: safetyThreshold, - }, - { - category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - threshold: safetyThreshold, - }, - { - category: HarmCategory.HARM_CATEGORY_UNSPECIFIED, - threshold: safetyThreshold, - }, - ] - : undefined, + safetySettings, generationConfig: { maxOutputTokens: parameters?.max_new_tokens ?? 4096, stopSequences: parameters?.stop,