From e19644268c2f2970b555732ecda9f7a3b5c506b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oliver=20J=C3=A4gle?= Date: Fri, 17 Jan 2025 23:09:19 +0100 Subject: [PATCH] feat: configure dynamic providers via .env (#1108) * Use backend API route to fetch dynamic models # Conflicts: # app/components/chat/BaseChat.tsx * Override ApiKeys if provided in frontend * Remove obsolete artifact * Transport api keys from client to server in header * Cache static provider information * Restore reading provider settings from cookie * Reload only a single provider on api key change * Transport apiKeys and providerSettings via cookies. While doing this, introduce a simple helper function for cookies --- app/components/chat/BaseChat.tsx | 81 ++++++++++------------------ app/lib/api/cookies.ts | 33 ++++++++++++ app/lib/modules/llm/manager.ts | 2 +- app/routes/api.enhancer.ts | 33 ++---------- app/routes/api.llmcall.ts | 44 ++++++---------- app/routes/api.models.$provider.ts | 2 + app/routes/api.models.ts | 84 ++++++++++++++++++++++++++++-- app/utils/constants.ts | 35 +------------ 8 files changed, 164 insertions(+), 150 deletions(-) create mode 100644 app/lib/api/cookies.ts create mode 100644 app/routes/api.models.$provider.ts diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index bf995a3d9..4bfc038c5 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -3,13 +3,13 @@ * Preventing TS checks with files presented in the video for a better presentation. */ import type { Message } from 'ai'; -import React, { type RefCallback, useCallback, useEffect, useState } from 'react'; +import React, { type RefCallback, useEffect, useState } from 'react'; import { ClientOnly } from 'remix-utils/client-only'; import { Menu } from '~/components/sidebar/Menu.client'; import { IconButton } from '~/components/ui/IconButton'; import { Workbench } from '~/components/workbench/Workbench.client'; import { classNames } from '~/utils/classNames'; -import { MODEL_LIST, PROVIDER_LIST, initializeModelList } from '~/utils/constants'; +import { PROVIDER_LIST } from '~/utils/constants'; import { Messages } from './Messages.client'; import { SendButton } from './SendButton.client'; import { APIKeyManager, getApiKeysFromCookies } from './APIKeyManager'; @@ -25,13 +25,13 @@ import GitCloneButton from './GitCloneButton'; import FilePreview from './FilePreview'; import { ModelSelector } from '~/components/chat/ModelSelector'; import { SpeechRecognitionButton } from '~/components/chat/SpeechRecognition'; -import type { IProviderSetting, ProviderInfo } from '~/types/model'; +import type { ProviderInfo } from '~/types/model'; import { ScreenshotStateManager } from './ScreenshotStateManager'; import { toast } from 'react-toastify'; import StarterTemplates from './StarterTemplates'; import type { ActionAlert } from '~/types/actions'; import ChatAlert from './ChatAlert'; -import { LLMManager } from '~/lib/modules/llm/manager'; +import type { ModelInfo } from '~/lib/modules/llm/types'; const TEXTAREA_MIN_HEIGHT = 76; @@ -102,35 +102,13 @@ export const BaseChat = React.forwardRef( ) => { const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; const [apiKeys, setApiKeys] = useState>(getApiKeysFromCookies()); - const [modelList, setModelList] = useState(MODEL_LIST); + const [modelList, setModelList] = useState([]); const [isModelSettingsCollapsed, setIsModelSettingsCollapsed] = useState(false); const [isListening, setIsListening] = useState(false); const [recognition, setRecognition] = useState(null); const [transcript, setTranscript] = useState(''); const [isModelLoading, setIsModelLoading] = useState('all'); - const getProviderSettings = useCallback(() => { - let providerSettings: Record | undefined = undefined; - - try { - const savedProviderSettings = Cookies.get('providers'); - - if (savedProviderSettings) { - const parsedProviderSettings = JSON.parse(savedProviderSettings); - - if (typeof parsedProviderSettings === 'object' && parsedProviderSettings !== null) { - providerSettings = parsedProviderSettings; - } - } - } catch (error) { - console.error('Error loading Provider Settings from cookies:', error); - - // Clear invalid cookie data - Cookies.remove('providers'); - } - - return providerSettings; - }, []); useEffect(() => { console.log(transcript); }, [transcript]); @@ -169,7 +147,6 @@ export const BaseChat = React.forwardRef( useEffect(() => { if (typeof window !== 'undefined') { - const providerSettings = getProviderSettings(); let parsedApiKeys: Record | undefined = {}; try { @@ -177,17 +154,18 @@ export const BaseChat = React.forwardRef( setApiKeys(parsedApiKeys); } catch (error) { console.error('Error loading API keys from cookies:', error); - - // Clear invalid cookie data Cookies.remove('apiKeys'); } + setIsModelLoading('all'); - initializeModelList({ apiKeys: parsedApiKeys, providerSettings }) - .then((modelList) => { - setModelList(modelList); + fetch('/api/models') + .then((response) => response.json()) + .then((data) => { + const typedData = data as { modelList: ModelInfo[] }; + setModelList(typedData.modelList); }) .catch((error) => { - console.error('Error initializing model list:', error); + console.error('Error fetching model list:', error); }) .finally(() => { setIsModelLoading(undefined); @@ -200,29 +178,24 @@ export const BaseChat = React.forwardRef( setApiKeys(newApiKeys); Cookies.set('apiKeys', JSON.stringify(newApiKeys)); - const provider = LLMManager.getInstance(import.meta.env || process.env || {}).getProvider(providerName); + setIsModelLoading(providerName); - if (provider && provider.getDynamicModels) { - setIsModelLoading(providerName); + let providerModels: ModelInfo[] = []; - try { - const providerSettings = getProviderSettings(); - const staticModels = provider.staticModels; - const dynamicModels = await provider.getDynamicModels( - newApiKeys, - providerSettings, - import.meta.env || process.env || {}, - ); - - setModelList((preModels) => { - const filteredOutPreModels = preModels.filter((x) => x.provider !== providerName); - return [...filteredOutPreModels, ...staticModels, ...dynamicModels]; - }); - } catch (error) { - console.error('Error loading dynamic models:', error); - } - setIsModelLoading(undefined); + try { + const response = await fetch(`/api/models/${encodeURIComponent(providerName)}`); + const data = await response.json(); + providerModels = (data as { modelList: ModelInfo[] }).modelList; + } catch (error) { + console.error('Error loading dynamic models for:', providerName, error); } + + // Only update models for the specific provider + setModelList((prevModels) => { + const otherModels = prevModels.filter((model) => model.provider !== providerName); + return [...otherModels, ...providerModels]; + }); + setIsModelLoading(undefined); }; const startListening = () => { diff --git a/app/lib/api/cookies.ts b/app/lib/api/cookies.ts new file mode 100644 index 000000000..fa6862ea3 --- /dev/null +++ b/app/lib/api/cookies.ts @@ -0,0 +1,33 @@ +export function parseCookies(cookieHeader: string | null) { + const cookies: Record = {}; + + if (!cookieHeader) { + return cookies; + } + + // Split the cookie string by semicolons and spaces + const items = cookieHeader.split(';').map((cookie) => cookie.trim()); + + items.forEach((item) => { + const [name, ...rest] = item.split('='); + + if (name && rest.length > 0) { + // Decode the name and value, and join value parts in case it contains '=' + const decodedName = decodeURIComponent(name.trim()); + const decodedValue = decodeURIComponent(rest.join('=').trim()); + cookies[decodedName] = decodedValue; + } + }); + + return cookies; +} + +export function getApiKeysFromCookie(cookieHeader: string | null): Record { + const cookies = parseCookies(cookieHeader); + return cookies.apiKeys ? JSON.parse(cookies.apiKeys) : {}; +} + +export function getProviderSettingsFromCookie(cookieHeader: string | null): Record { + const cookies = parseCookies(cookieHeader); + return cookies.providers ? JSON.parse(cookies.providers) : {}; +} diff --git a/app/lib/modules/llm/manager.ts b/app/lib/modules/llm/manager.ts index 5b134218d..88ae28c91 100644 --- a/app/lib/modules/llm/manager.ts +++ b/app/lib/modules/llm/manager.ts @@ -83,7 +83,7 @@ export class LLMManager { let enabledProviders = Array.from(this._providers.values()).map((p) => p.name); - if (providerSettings) { + if (providerSettings && Object.keys(providerSettings).length > 0) { enabledProviders = enabledProviders.filter((p) => providerSettings[p].enabled); } diff --git a/app/routes/api.enhancer.ts b/app/routes/api.enhancer.ts index 5d16ac256..5f6db1f03 100644 --- a/app/routes/api.enhancer.ts +++ b/app/routes/api.enhancer.ts @@ -1,34 +1,13 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare'; - -//import { StreamingTextResponse, parseStreamPart } from 'ai'; import { streamText } from '~/lib/.server/llm/stream-text'; import { stripIndents } from '~/utils/stripIndent'; -import type { IProviderSetting, ProviderInfo } from '~/types/model'; +import type { ProviderInfo } from '~/types/model'; +import { getApiKeysFromCookie, getProviderSettingsFromCookie } from '~/lib/api/cookies'; export async function action(args: ActionFunctionArgs) { return enhancerAction(args); } -function parseCookies(cookieHeader: string) { - const cookies: any = {}; - - // Split the cookie string by semicolons and spaces - const items = cookieHeader.split(';').map((cookie) => cookie.trim()); - - items.forEach((item) => { - const [name, ...rest] = item.split('='); - - if (name && rest) { - // Decode the name and value, and join value parts in case it contains '=' - const decodedName = decodeURIComponent(name.trim()); - const decodedValue = decodeURIComponent(rest.join('=').trim()); - cookies[decodedName] = decodedValue; - } - }); - - return cookies; -} - async function enhancerAction({ context, request }: ActionFunctionArgs) { const { message, model, provider } = await request.json<{ message: string; @@ -55,12 +34,8 @@ async function enhancerAction({ context, request }: ActionFunctionArgs) { } const cookieHeader = request.headers.get('Cookie'); - - // Parse the cookie's value (returns an object or null if no cookie exists) - const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}'); - const providerSettings: Record = JSON.parse( - parseCookies(cookieHeader || '').providers || '{}', - ); + const apiKeys = getApiKeysFromCookie(cookieHeader); + const providerSettings = getProviderSettingsFromCookie(cookieHeader); try { const result = await streamText({ diff --git a/app/routes/api.llmcall.ts b/app/routes/api.llmcall.ts index 0fc3c85ec..a4a775519 100644 --- a/app/routes/api.llmcall.ts +++ b/app/routes/api.llmcall.ts @@ -1,34 +1,24 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare'; - -//import { StreamingTextResponse, parseStreamPart } from 'ai'; import { streamText } from '~/lib/.server/llm/stream-text'; import type { IProviderSetting, ProviderInfo } from '~/types/model'; import { generateText } from 'ai'; -import { getModelList, PROVIDER_LIST } from '~/utils/constants'; +import { PROVIDER_LIST } from '~/utils/constants'; import { MAX_TOKENS } from '~/lib/.server/llm/constants'; +import { LLMManager } from '~/lib/modules/llm/manager'; +import type { ModelInfo } from '~/lib/modules/llm/types'; +import { getApiKeysFromCookie, getProviderSettingsFromCookie } from '~/lib/api/cookies'; export async function action(args: ActionFunctionArgs) { return llmCallAction(args); } -function parseCookies(cookieHeader: string) { - const cookies: any = {}; - - // Split the cookie string by semicolons and spaces - const items = cookieHeader.split(';').map((cookie) => cookie.trim()); - - items.forEach((item) => { - const [name, ...rest] = item.split('='); - - if (name && rest) { - // Decode the name and value, and join value parts in case it contains '=' - const decodedName = decodeURIComponent(name.trim()); - const decodedValue = decodeURIComponent(rest.join('=').trim()); - cookies[decodedName] = decodedValue; - } - }); - - return cookies; +async function getModelList(options: { + apiKeys?: Record; + providerSettings?: Record; + serverEnv?: Record; +}) { + const llmManager = LLMManager.getInstance(import.meta.env); + return llmManager.updateModelList(options); } async function llmCallAction({ context, request }: ActionFunctionArgs) { @@ -58,12 +48,8 @@ async function llmCallAction({ context, request }: ActionFunctionArgs) { } const cookieHeader = request.headers.get('Cookie'); - - // Parse the cookie's value (returns an object or null if no cookie exists) - const apiKeys = JSON.parse(parseCookies(cookieHeader || '').apiKeys || '{}'); - const providerSettings: Record = JSON.parse( - parseCookies(cookieHeader || '').providers || '{}', - ); + const apiKeys = getApiKeysFromCookie(cookieHeader); + const providerSettings = getProviderSettingsFromCookie(cookieHeader); if (streamOutput) { try { @@ -105,8 +91,8 @@ async function llmCallAction({ context, request }: ActionFunctionArgs) { } } else { try { - const MODEL_LIST = await getModelList({ apiKeys, providerSettings, serverEnv: context.cloudflare.env as any }); - const modelDetails = MODEL_LIST.find((m) => m.name === model); + const models = await getModelList({ apiKeys, providerSettings, serverEnv: context.cloudflare.env as any }); + const modelDetails = models.find((m: ModelInfo) => m.name === model); if (!modelDetails) { throw new Error('Model not found'); diff --git a/app/routes/api.models.$provider.ts b/app/routes/api.models.$provider.ts new file mode 100644 index 000000000..d60817625 --- /dev/null +++ b/app/routes/api.models.$provider.ts @@ -0,0 +1,2 @@ +import { loader } from './api.models'; +export { loader }; diff --git a/app/routes/api.models.ts b/app/routes/api.models.ts index ace4ef009..13588f900 100644 --- a/app/routes/api.models.ts +++ b/app/routes/api.models.ts @@ -1,6 +1,84 @@ import { json } from '@remix-run/cloudflare'; -import { MODEL_LIST } from '~/utils/constants'; +import { LLMManager } from '~/lib/modules/llm/manager'; +import type { ModelInfo } from '~/lib/modules/llm/types'; +import type { ProviderInfo } from '~/types/model'; +import { getApiKeysFromCookie, getProviderSettingsFromCookie } from '~/lib/api/cookies'; -export async function loader() { - return json(MODEL_LIST); +interface ModelsResponse { + modelList: ModelInfo[]; + providers: ProviderInfo[]; + defaultProvider: ProviderInfo; +} + +let cachedProviders: ProviderInfo[] | null = null; +let cachedDefaultProvider: ProviderInfo | null = null; + +function getProviderInfo(llmManager: LLMManager) { + if (!cachedProviders) { + cachedProviders = llmManager.getAllProviders().map((provider) => ({ + name: provider.name, + staticModels: provider.staticModels, + getApiKeyLink: provider.getApiKeyLink, + labelForGetApiKey: provider.labelForGetApiKey, + icon: provider.icon, + })); + } + + if (!cachedDefaultProvider) { + const defaultProvider = llmManager.getDefaultProvider(); + cachedDefaultProvider = { + name: defaultProvider.name, + staticModels: defaultProvider.staticModels, + getApiKeyLink: defaultProvider.getApiKeyLink, + labelForGetApiKey: defaultProvider.labelForGetApiKey, + icon: defaultProvider.icon, + }; + } + + return { providers: cachedProviders, defaultProvider: cachedDefaultProvider }; +} + +export async function loader({ + request, + params, +}: { + request: Request; + params: { provider?: string }; +}): Promise { + const llmManager = LLMManager.getInstance(import.meta.env); + + // Get client side maintained API keys and provider settings from cookies + const cookieHeader = request.headers.get('Cookie'); + const apiKeys = getApiKeysFromCookie(cookieHeader); + const providerSettings = getProviderSettingsFromCookie(cookieHeader); + + const { providers, defaultProvider } = getProviderInfo(llmManager); + + let modelList: ModelInfo[] = []; + + if (params.provider) { + // Only update models for the specific provider + const provider = llmManager.getProvider(params.provider); + + if (provider) { + const staticModels = provider.staticModels; + const dynamicModels = provider.getDynamicModels + ? await provider.getDynamicModels(apiKeys, providerSettings, import.meta.env) + : []; + modelList = [...staticModels, ...dynamicModels]; + } + } else { + // Update all models + modelList = await llmManager.updateModelList({ + apiKeys, + providerSettings, + serverEnv: import.meta.env, + }); + } + + return json({ + modelList, + providers, + defaultProvider, + }); } diff --git a/app/utils/constants.ts b/app/utils/constants.ts index 31e72b77c..621fcdaf8 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -1,7 +1,4 @@ -import type { IProviderSetting } from '~/types/model'; - import { LLMManager } from '~/lib/modules/llm/manager'; -import type { ModelInfo } from '~/lib/modules/llm/types'; import type { Template } from '~/types/template'; export const WORK_DIR_NAME = 'project'; @@ -17,9 +14,7 @@ const llmManager = LLMManager.getInstance(import.meta.env); export const PROVIDER_LIST = llmManager.getAllProviders(); export const DEFAULT_PROVIDER = llmManager.getDefaultProvider(); -let MODEL_LIST = llmManager.getModelList(); - -const providerBaseUrlEnvKeys: Record = {}; +export const providerBaseUrlEnvKeys: Record = {}; PROVIDER_LIST.forEach((provider) => { providerBaseUrlEnvKeys[provider.name] = { baseUrlKey: provider.config.baseUrlKey, @@ -27,34 +22,6 @@ PROVIDER_LIST.forEach((provider) => { }; }); -// Export the getModelList function using the manager -export async function getModelList(options: { - apiKeys?: Record; - providerSettings?: Record; - serverEnv?: Record; -}) { - return await llmManager.updateModelList(options); -} - -async function initializeModelList(options: { - env?: Record; - providerSettings?: Record; - apiKeys?: Record; -}): Promise { - const { providerSettings, apiKeys, env } = options; - const list = await getModelList({ - apiKeys, - providerSettings, - serverEnv: env, - }); - MODEL_LIST = list || MODEL_LIST; - - return list; -} - -// initializeModelList({}) -export { initializeModelList, providerBaseUrlEnvKeys, MODEL_LIST }; - // starter Templates export const STARTER_TEMPLATES: Template[] = [