diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ef4141cd..1cc1dcc1 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -60,6 +60,10 @@ pnpm install
 
 ```
 ANTHROPIC_API_KEY=XXX
+PROVIDER=anthropic | gemini | openai | ollama
+MODEL_NAME=XXX
+GOOGLE_GENERATIVE_AI_API_KEY=XXX
+OPEN_AI_API_KEY=XXX
 ```
 
 Optionally, you can set the debug level:
diff --git a/app/lib/.server/llm/api-key.ts b/app/lib/.server/llm/api-key.ts
index 863f7636..769843f7 100644
--- a/app/lib/.server/llm/api-key.ts
+++ b/app/lib/.server/llm/api-key.ts
@@ -5,5 +5,15 @@ export function getAPIKey(cloudflareEnv: Env) {
    * The `cloudflareEnv` is only used when deployed or when previewing locally.
    * In development the environment variables are available through `env`.
    */
+  const provider = cloudflareEnv.PROVIDER || 'anthropic';
+
+  if (provider === 'gemini') {
+    return cloudflareEnv.GOOGLE_GENERATIVE_AI_API_KEY || (env.GOOGLE_GENERATIVE_AI_API_KEY as string);
+  }
+
+  if (provider === 'openai') {
+    return cloudflareEnv.OPEN_AI_API_KEY || (env.OPEN_AI_API_KEY as string);
+  }
+
   return env.ANTHROPIC_API_KEY || cloudflareEnv.ANTHROPIC_API_KEY;
 }
diff --git a/app/lib/.server/llm/get-model.ts b/app/lib/.server/llm/get-model.ts
new file mode 100644
index 00000000..0a066c4a
--- /dev/null
+++ b/app/lib/.server/llm/get-model.ts
@@ -0,0 +1,25 @@
+import type { ModelFactory } from './providers/modelFactory';
+import { AnthropicFactory } from './providers/anthropic';
+import { OpenAiFactory } from './providers/openAi';
+import { GeminiFactory } from './providers/gemini';
+import { OllamaFactory } from './providers/ollama';
+
+export function getModelFactory(provider: string): ModelFactory {
+  switch (provider.toLowerCase()) {
+    case 'anthropic': {
+      return new AnthropicFactory();
+    }
+    case 'openai': {
+      return new OpenAiFactory();
+    }
+    case 'gemini': {
+      return new GeminiFactory();
+    }
+    case 'ollama': {
+      return new OllamaFactory();
+    }
+    default: {
+      throw new Error(`Unsupported provider: ${provider}`);
+    }
+  }
+}
diff --git a/app/lib/.server/llm/model.ts b/app/lib/.server/llm/model.ts
deleted file mode 100644
index f0d695c4..00000000
--- a/app/lib/.server/llm/model.ts
+++ /dev/null
@@ -1,9 +0,0 @@
-import { createAnthropic } from '@ai-sdk/anthropic';
-
-export function getAnthropicModel(apiKey: string) {
-  const anthropic = createAnthropic({
-    apiKey,
-  });
-
-  return anthropic('claude-3-5-sonnet-20240620');
-}
diff --git a/app/lib/.server/llm/providers/anthropic.ts b/app/lib/.server/llm/providers/anthropic.ts
new file mode 100644
index 00000000..6bd89609
--- /dev/null
+++ b/app/lib/.server/llm/providers/anthropic.ts
@@ -0,0 +1,16 @@
+import { createAnthropic } from '@ai-sdk/anthropic';
+import type { ModelFactory } from './modelFactory';
+
+export function getAnthropicModel(apiKey: string, modelName: string = 'claude-3-5-sonnet-20240620') {
+  const anthropic = createAnthropic({
+    apiKey,
+  });
+
+  return anthropic(modelName);
+}
+
+export class AnthropicFactory implements ModelFactory {
+  createModel(apiKey: string, modelName?: string) {
+    return getAnthropicModel(apiKey, modelName);
+  }
+}
diff --git a/app/lib/.server/llm/providers/gemini.ts b/app/lib/.server/llm/providers/gemini.ts
new file mode 100644
index 00000000..8512b570
--- /dev/null
+++ b/app/lib/.server/llm/providers/gemini.ts
@@ -0,0 +1,16 @@
+import { createGoogleGenerativeAI } from '@ai-sdk/google';
+import type { ModelFactory } from './modelFactory';
+
+export function getGeminiModel(apiKey: string, modelName: string = 'gemini-1.5-pro-latest') {
+  const model = createGoogleGenerativeAI({
+    apiKey,
+  });
+
+  return model(modelName);
+}
+
+export class GeminiFactory implements ModelFactory {
+  createModel(apiKey: string, modelName?: string) {
+    return getGeminiModel(apiKey, modelName);
+  }
+}
diff --git a/app/lib/.server/llm/providers/modelFactory.ts b/app/lib/.server/llm/providers/modelFactory.ts
new file mode 100644
index 00000000..1288b916
--- /dev/null
+++ b/app/lib/.server/llm/providers/modelFactory.ts
@@ -0,0 +1,4 @@
+import type { LanguageModel } from 'ai';
+export interface ModelFactory {
+  createModel(apiKey: string, modelName?: string): LanguageModel;
+}
diff --git a/app/lib/.server/llm/providers/ollama.ts b/app/lib/.server/llm/providers/ollama.ts
new file mode 100644
index 00000000..abedb750
--- /dev/null
+++ b/app/lib/.server/llm/providers/ollama.ts
@@ -0,0 +1,15 @@
+import { createOllama } from 'ollama-ai-provider';
+import type { ModelFactory } from './modelFactory';
+
+export function getOllamaModel(_apiKey: string, modelName: string = 'llama3.2:latest') {
+  const model = createOllama({
+    baseURL: 'http://127.0.0.1:11434/api',
+  });
+  return model(modelName);
+}
+
+export class OllamaFactory implements ModelFactory {
+  createModel(apiKey: string, modelName?: string) {
+    return getOllamaModel(apiKey, modelName);
+  }
+}
diff --git a/app/lib/.server/llm/providers/openAi.ts b/app/lib/.server/llm/providers/openAi.ts
new file mode 100644
index 00000000..5937774c
--- /dev/null
+++ b/app/lib/.server/llm/providers/openAi.ts
@@ -0,0 +1,16 @@
+import { createOpenAI } from '@ai-sdk/openai';
+import type { ModelFactory } from './modelFactory';
+
+export function getOpenAiModel(apiKey: string, modelName: string = 'gpt-4o-mini') {
+  const model = createOpenAI({
+    apiKey,
+  });
+
+  return model(modelName);
+}
+
+export class OpenAiFactory implements ModelFactory {
+  createModel(apiKey: string, modelName?: string) {
+    return getOpenAiModel(apiKey, modelName);
+  }
+}
diff --git a/app/lib/.server/llm/stream-text.ts b/app/lib/.server/llm/stream-text.ts
index cf937fd0..b198922a 100644
--- a/app/lib/.server/llm/stream-text.ts
+++ b/app/lib/.server/llm/stream-text.ts
@@ -1,6 +1,6 @@
 import { streamText as _streamText, convertToCoreMessages } from 'ai';
 import { getAPIKey } from '~/lib/.server/llm/api-key';
-import { getAnthropicModel } from '~/lib/.server/llm/model';
+import { getModelFactory } from '~/lib/.server/llm/get-model';
 import { MAX_TOKENS } from './constants';
 import { getSystemPrompt } from './prompts';
 
@@ -22,13 +22,22 @@ export type Messages = Message[];
 export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;
 
 export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
+  const provider = env.PROVIDER || 'anthropic';
+  const modelName = env.MODEL_NAME || undefined;
+  const factory = getModelFactory(provider);
+
+  const model = factory.createModel(getAPIKey(env), modelName);
+
   return _streamText({
-    model: getAnthropicModel(getAPIKey(env)),
+    model,
     system: getSystemPrompt(),
     maxTokens: MAX_TOKENS,
-    headers: {
-      'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
-    },
+    headers:
+      provider === 'anthropic'
+        ? {
+            'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
+          }
+        : undefined,
     messages: convertToCoreMessages(messages),
     ...options,
   });
diff --git a/package.json b/package.json
index 55834556..399e9ff7 100644
--- a/package.json
+++ b/package.json
@@ -23,7 +23,9 @@
     "node": ">=18.18.0"
   },
   "dependencies": {
-    "@ai-sdk/anthropic": "^0.0.39",
+    "@ai-sdk/anthropic": "^0.0.51",
+    "@ai-sdk/google": "^0.0.51",
+    "@ai-sdk/openai": "^0.0.66",
     "@codemirror/autocomplete": "^6.17.0",
     "@codemirror/commands": "^6.6.0",
     "@codemirror/lang-cpp": "^6.0.2",
@@ -54,7 +56,7 @@
     "@xterm/addon-fit": "^0.10.0",
     "@xterm/addon-web-links": "^0.11.0",
     "@xterm/xterm": "^5.5.0",
-    "ai": "^3.3.4",
+    "ai": "^3.4.9",
     "date-fns": "^3.6.0",
     "diff": "^5.2.0",
     "framer-motion": "^11.2.12",
@@ -62,6 +64,8 @@
     "istextorbinary": "^9.5.0",
     "jose": "^5.6.3",
     "nanostores": "^0.10.3",
+    "ollama": "^0.5.9",
+    "ollama-ai-provider": "^0.15.1",
     "react": "^18.2.0",
     "react-dom": "^18.2.0",
     "react-hotkeys-hook": "^4.5.0",
diff --git a/worker-configuration.d.ts b/worker-configuration.d.ts
index 606a4e52..46b88dbe 100644
--- a/worker-configuration.d.ts
+++ b/worker-configuration.d.ts
@@ -1,3 +1,7 @@
 interface Env {
   ANTHROPIC_API_KEY: string;
+  PROVIDER: string;
+  MODEL_NAME: string;
+  GOOGLE_GENERATIVE_AI_API_KEY: string;
+  OPEN_AI_API_KEY: string;
 }