From 0b053250d8de6814d20ebf42eccb44c90ad642c0 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 18 Dec 2024 14:11:25 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=A4=96=20feat:=20Support=20new=20`o1`=20m?= =?UTF-8?q?odel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/clients/OpenAIClient.js | 11 ++++++++--- api/utils/tokens.js | 2 +- api/utils/tokens.spec.js | 26 ++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 9cde8b56e98..f81a4bc92c7 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -107,7 +107,8 @@ class OpenAIClient extends BaseClient { this.checkVisionRequest(this.options.attachments); } - this.isO1Model = /\bo1\b/i.test(this.modelOptions.model); + const o1Pattern = /\bo1\b/i; + this.isO1Model = o1Pattern.test(this.modelOptions.model); const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {}; if (OPENROUTER_API_KEY && !this.azure) { @@ -147,7 +148,7 @@ class OpenAIClient extends BaseClient { const { model } = this.modelOptions; this.isChatCompletion = - /\bo1\b/i.test(model) || model.includes('gpt') || this.useOpenRouter || !!reverseProxy; + o1Pattern.test(model) || model.includes('gpt') || this.useOpenRouter || !!reverseProxy; this.isChatGptModel = this.isChatCompletion; if ( model.includes('text-davinci') || @@ -1325,7 +1326,11 @@ ${convo} /** @type {(value: void | PromiseLike) => void} */ let streamResolve; - if (this.isO1Model === true && this.azure && modelOptions.stream) { + if ( + this.isO1Model === true && + (this.azure || /o1(?!-(?:mini|preview)).*$/.test(modelOptions.model)) && + modelOptions.stream + ) { delete modelOptions.stream; delete modelOptions.stop; } diff --git a/api/utils/tokens.js b/api/utils/tokens.js index b2c9cedf2f6..ac123c9dd95 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -2,7 +2,7 @@ const z = require('zod'); const { EModelEndpoint } = require('librechat-data-provider'); const openAIModels = { - o1: 127500, // -500 from max + o1: 195000, // -5000 from max 'o1-mini': 127500, // -500 from max 'o1-preview': 127500, // -500 from max 'gpt-4': 8187, // -5 from max diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index cacf72cb4a8..b1f37bb1f46 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -248,6 +248,32 @@ describe('getModelMaxTokens', () => { test('should return undefined for a model when using an unsupported endpoint', () => { expect(getModelMaxTokens('azure-gpt-3', 'unsupportedEndpoint')).toBeUndefined(); }); + + test('should return correct max context tokens for o1-series models', () => { + // Standard o1 variations + const o1Tokens = maxTokensMap[EModelEndpoint.openAI]['o1']; + expect(getModelMaxTokens('o1')).toBe(o1Tokens); + expect(getModelMaxTokens('o1-latest')).toBe(o1Tokens); + expect(getModelMaxTokens('o1-2024-12-17')).toBe(o1Tokens); + expect(getModelMaxTokens('o1-something-else')).toBe(o1Tokens); + expect(getModelMaxTokens('openai/o1-something-else')).toBe(o1Tokens); + + // Mini variations + const o1MiniTokens = maxTokensMap[EModelEndpoint.openAI]['o1-mini']; + expect(getModelMaxTokens('o1-mini')).toBe(o1MiniTokens); + expect(getModelMaxTokens('o1-mini-latest')).toBe(o1MiniTokens); + expect(getModelMaxTokens('o1-mini-2024-09-12')).toBe(o1MiniTokens); + expect(getModelMaxTokens('o1-mini-something')).toBe(o1MiniTokens); + expect(getModelMaxTokens('openai/o1-mini-something')).toBe(o1MiniTokens); + + // Preview variations + const o1PreviewTokens = maxTokensMap[EModelEndpoint.openAI]['o1-preview']; + expect(getModelMaxTokens('o1-preview')).toBe(o1PreviewTokens); + expect(getModelMaxTokens('o1-preview-latest')).toBe(o1PreviewTokens); + expect(getModelMaxTokens('o1-preview-2024-09-12')).toBe(o1PreviewTokens); + expect(getModelMaxTokens('o1-preview-something')).toBe(o1PreviewTokens); + expect(getModelMaxTokens('openai/o1-preview-something')).toBe(o1PreviewTokens); + }); }); describe('matchModelName', () => {