From 1206e96334e8dde1eedf055df14457192d328353 Mon Sep 17 00:00:00 2001
From: Peli de Halleux
Date: Thu, 9 Jan 2025 16:42:02 +0000
Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20add=20audio=20transcription?=
 =?UTF-8?q?=20support=20with=20Whisper-1?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 packages/core/src/chat.ts                    | 11 +++
 packages/core/src/host.ts                    | 12 ++-
 packages/core/src/llms.json                  |  3 +-
 packages/core/src/promptdom.ts               |  3 +-
 packages/core/src/runpromptcontext.ts        | 92 ++++++++++++++++----
 packages/core/src/testhost.ts                |  9 +-
 packages/core/src/types/prompt_template.d.ts | 67 +++++++++++++-
 packages/core/src/types/prompt_type.d.ts     | 16 +++-
 packages/core/src/vectorsearch.ts            |  7 +-
 9 files changed, 189 insertions(+), 31 deletions(-)

diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts
index 95ada20f5..7465c48f7 100644
--- a/packages/core/src/chat.ts
+++ b/packages/core/src/chat.ts
@@ -140,11 +140,22 @@ export type PullModelFunction = (
     options: TraceOptions & CancellationOptions
 ) => Promise<{ ok: boolean; error?: SerializedError }>
 
+export type CreateTranscriptionRequest = {
+    file: BufferLike
+} & TranscriptionOptions
+
+export type TranscribeFunction = (
+    req: CreateTranscriptionRequest,
+    cfg: LanguageModelConfiguration,
+    options: TraceOptions & CancellationOptions
+) => Promise<TranscriptionResult>
+
 export interface LanguageModel {
     id: string
     completer: ChatCompletionHandler
     listModels?: ListModelsFunction
     pullModel?: PullModelFunction
+    transcribe?: TranscribeFunction
 }
 
 async function runToolCalls(
diff --git a/packages/core/src/host.ts b/packages/core/src/host.ts
index a9c460ac8..26a74f9e4 100644
--- a/packages/core/src/host.ts
+++ b/packages/core/src/host.ts
@@ -2,7 +2,12 @@ import { CancellationOptions, CancellationToken } from "./cancellation"
 import { LanguageModel } from "./chat"
 import { Progress } from "./progress"
 import { MarkdownTrace, TraceOptions } from "./trace"
-import { AzureCredentialsType, LanguageModelConfiguration, Project, ResponseStatus } from "./server/messages"
+import {
+    AzureCredentialsType,
+    LanguageModelConfiguration,
+    Project,
+    ResponseStatus,
+} from "./server/messages"
 import { HostConfiguration } from "./hostconfiguration"
 
 // this is typically an instance of TextDecoder
@@ -135,7 +140,10 @@ export interface RuntimeHost extends Host {
     azureToken: AzureTokenResolver
     modelAliases: Readonly<ModelConfigurations>
 
-    pullModel(model: string, options?: TraceOptions): Promise<ResponseStatus>
+    pullModel(
+        model: string,
+        options?: TraceOptions & CancellationOptions
+    ): Promise<ResponseStatus>
 
     setModelAlias(
         source: "env" | "cli" | "config" | "script",
diff --git a/packages/core/src/llms.json b/packages/core/src/llms.json
index 3a8707d30..cc3ef14ce 100644
--- a/packages/core/src/llms.json
+++ b/packages/core/src/llms.json
@@ -11,7 +11,8 @@
             "vision": "gpt-4o",
             "embeddings": "text-embedding-3-small",
             "reasoning": "o1",
-            "reasoning_small": "o1-mini"
+            "reasoning_small": "o1-mini",
+            "transcribe": "whisper-1"
         }
     },
     {
diff --git a/packages/core/src/promptdom.ts b/packages/core/src/promptdom.ts
index ef267d796..2ddad36e4 100644
--- a/packages/core/src/promptdom.ts
+++ b/packages/core/src/promptdom.ts
@@ -37,6 +37,7 @@ import { startMcpServer } from "./mcp"
 import { tryZodToJsonSchema } from "./zod"
 import { GROQEvaluate } from "./groq"
 import { trimNewlines } from "./unwrappers"
+import { CancellationOptions } from "./cancellation"
 
 // Definition of the PromptNode interface which is an essential part of the code structure.
 export interface PromptNode extends ContextExpansionOptions {
@@ -1140,7 +1141,7 @@ async function deduplicatePromptNode(trace: MarkdownTrace, root: PromptNode) {
 export async function renderPromptNode(
     modelId: string,
     node: PromptNode,
-    options?: ModelTemplateOptions & TraceOptions
+    options?: ModelTemplateOptions & TraceOptions & CancellationOptions
 ): Promise<PromptNodeRender> {
     const { trace, flexTokens } = options || {}
     const { encode: encoder } = await resolveTokenEncoder(modelId)
diff --git a/packages/core/src/runpromptcontext.ts b/packages/core/src/runpromptcontext.ts
index 018c30c11..678815cab 100644
--- a/packages/core/src/runpromptcontext.ts
+++ b/packages/core/src/runpromptcontext.ts
@@ -625,6 +625,65 @@ export function createChatGenerationContext(
         return p
     }
 
+    const transcribe = async (
+        file: BufferLike,
+        options?: TranscriptionOptions
+    ): Promise<TranscriptionResult> => {
+        const transcriptionTrace = trace.startTraceDetails("🎤 transcribe")
+        try {
+            const conn: ModelConnectionOptions = {
+                model: options?.model || "transcribe",
+            }
+            const { info, configuration } = await resolveModelConnectionInfo(
+                conn,
+                {
+                    trace: transcriptionTrace,
+                    cancellationToken,
+                    token: true,
+                }
+            )
+            if (info.error) throw new Error(info.error)
+            if (!configuration) throw new Error("model configuration not found")
+            checkCancelled(cancellationToken)
+            const { ok } = await runtimeHost.pullModel(conn.model, {
+                trace: transcriptionTrace,
+                cancellationToken,
+            })
+            if (!ok) throw new Error(`failed to pull model ${conn.model}`)
+            checkCancelled(cancellationToken)
+            const { transcribe } = await resolveLanguageModel(
+                configuration.provider
+            )
+            if (!transcribe)
+                throw new Error("model driver not found for " + info.model)
+            const res = await transcribe(
+                {
+                    file,
+                    language: options?.language,
+                    translate: options?.translate,
+                },
+                configuration,
+                {
+                    trace: transcriptionTrace,
+                    cancellationToken,
+                }
+            )
+            transcriptionTrace.fence(res.text, "markdown")
+            if (res.error) transcriptionTrace.error(errorMessage(res.error))
+            if (res.segments) transcriptionTrace.fence(res.segments, "yaml")
+            return res
+        } catch (e) {
+            logError(e)
+            transcriptionTrace.error(e)
+            return {
+                text: undefined,
+                error: serializeError(e),
+            } satisfies TranscriptionResult
+        } finally {
+            transcriptionTrace.endDetails()
+        }
+    }
+
     const runPrompt = async (
         generator: string | PromptGenerator,
         runOptions?: PromptGeneratorOptions
@@ -639,11 +698,15 @@ export function createChatGenerationContext(
         genOptions.fallbackTools = undefined
         genOptions.inner = true
         genOptions.trace = runTrace
-        const { info } = await resolveModelConnectionInfo(genOptions, {
-            trace,
-            token: true,
-        })
+        const { info, configuration } = await resolveModelConnectionInfo(
+            genOptions,
+            {
+                trace: runTrace,
+                token: true,
+            }
+        )
         if (info.error) throw new Error(info.error)
+        if (!configuration) throw new Error("model configuration not found")
         genOptions.model = info.model
         genOptions.stats = genOptions.stats.createChild(
             genOptions.model,
@@ -652,6 +715,7 @@
         const { ok } = await runtimeHost.pullModel(genOptions.model, {
             trace: runTrace,
+            cancellationToken,
         })
         if (!ok) throw new Error(`failed to pull model ${genOptions.model}`)
 
@@ -700,6 +764,7 @@
                 flexTokens: genOptions.flexTokens,
                 fenceFormat: genOptions.fenceFormat,
                 trace: runTrace,
+                cancellationToken,
             })
             schemas = scs
@@ -796,22 +861,12 @@
             messages.push(toChatCompletionUserMessage("", images))
         finalizeMessages(messages, { fileOutputs })
 
-        const connection = await resolveModelConnectionInfo(genOptions, {
-            trace: runTrace,
-            token: true,
-        })
-        checkCancelled(cancellationToken)
-        if (!connection.configuration)
-            throw new Error(
-                "missing model connection information for " +
-                    genOptions.model
-            )
         const { completer } = await resolveLanguageModel(
-            connection.configuration.provider
+            configuration.provider
         )
-        checkCancelled(cancellationToken)
         if (!completer)
-            throw new Error("model driver not found for " + connection.info)
+            throw new Error("model driver not found for " + info.model)
+        checkCancelled(cancellationToken)
         const modelConcurrency =
             options.modelConcurrency?.[genOptions.model] ??
@@ -822,7 +877,7 @@
         )
         const resp = await modelLimit(() =>
             executeChatSession(
-                connection.configuration,
+                configuration,
                 cancellationToken,
                 messages,
                 tools,
@@ -875,6 +930,7 @@
         defFileMerge,
         prompt,
         runPrompt,
+        transcribe,
     })
 
     return ctx
diff --git a/packages/core/src/testhost.ts b/packages/core/src/testhost.ts
index 501ed8185..53a4588ed 100644
--- a/packages/core/src/testhost.ts
+++ b/packages/core/src/testhost.ts
@@ -29,8 +29,13 @@ import {
 } from "node:path"
 import { LanguageModel } from "./chat"
 import { NotSupportedError } from "./error"
-import { LanguageModelConfiguration, Project, ResponseStatus } from "./server/messages"
+import {
+    LanguageModelConfiguration,
+    Project,
+    ResponseStatus,
+} from "./server/messages"
 import { defaultModelConfigurations } from "./llms"
+import { CancellationOptions } from "./cancellation"
 
 // Function to create a frozen object representing Node.js path methods
 // This object provides utility methods for path manipulations
@@ -69,7 +74,7 @@ export class TestHost implements RuntimeHost {
     }
     async pullModel(
         model: string,
-        options?: TraceOptions
+        options?: TraceOptions & CancellationOptions
     ): Promise<ResponseStatus> {
         return { ok: true }
     }
diff --git a/packages/core/src/types/prompt_template.d.ts b/packages/core/src/types/prompt_template.d.ts
index 4981dda54..ca6704b11 100644
--- a/packages/core/src/types/prompt_template.d.ts
+++ b/packages/core/src/types/prompt_template.d.ts
@@ -195,6 +195,8 @@ type ModelVisionType = OptionsOrString<
     "openai:gpt-4o" | "github:gpt-4o" | "azure:gpt-4o" | "azure:gpt-4o-mini"
 >
 
+type ModelTranscriptionType = OptionsOrString<"openai:whisper-1">
+
 type ModelProviderType =
     | "openai"
     | "azure"
@@ -2679,6 +2681,68 @@ type McpServersConfig = Record>
 
 type ZodTypeLike = { _def: any; safeParse: any; refine: any }
 
+type BufferLike = string | WorkspaceFile | Buffer | Blob | ArrayBuffer | ReadableStream
+
+interface TranscriptionOptions {
+    /**
+     * Model to use for transcription. By default uses the `transcribe` alias.
+     */
+    model?: ModelTranscriptionType
+
+    /**
+     * Translate to English.
+     */
+    translate?: boolean
+
+    /**
+     * Input language in ISO-639-1 format.
+     * @see https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes
+     */
+    language?: string
+
+    /**
+     * The sampling temperature, between 0 and 1.
+     * Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
+     */
+    temperature?: number
+}
+
+interface TranscriptionResult {
+    /**
+     * Complete transcription text
+     */
+    text: string
+    /**
+     * Error if any
+     */
+    error?: SerializedError
+    /**
+     * Individual segments
+     */
+    segments?: {
+        /**
+         * The start time of the segment
+         */
+        start: number
+        /**
+         * The transcribed text.
+         */
+        text: string
+        /**
+         * Seek offset of the segment
+         */
+        seek?: number
+        /**
+         * End time in seconds
+         */
+        end?: number
+        /**
+         * Temperature used for the generation of the segment
+         */
+        temperature?: number
+    }[]
+}
+
 interface ChatGenerationContext extends ChatTurnGenerationContext {
     defSchema(
         name: string,
@@ -2686,7 +2750,7 @@ interface ChatGenerationContext extends ChatTurnGenerationContext {
         options?: DefSchemaOptions
     ): string
     defImages(
-        files: ElementOrArray<string | WorkspaceFile | Buffer | Blob | ArrayBuffer | ReadableStream>,
+        files: ElementOrArray<BufferLike>,
         options?: DefImagesOptions
     ): void
     defTool(
@@ -2729,6 +2793,7 @@
     ): RunPromptResultPromiseWithOptions
     defFileMerge(fn: FileMergeHandler): void
     defOutputProcessor(fn: PromptOutputProcessorHandler): void
+    transcribe(audio: BufferLike, options?: TranscriptionOptions): Promise<TranscriptionResult>
 }
 
 interface GenerationOutput {
diff --git a/packages/core/src/types/prompt_type.d.ts b/packages/core/src/types/prompt_type.d.ts
index adf147588..b582f1ef2 100644
--- a/packages/core/src/types/prompt_type.d.ts
+++ b/packages/core/src/types/prompt_type.d.ts
@@ -39,7 +39,7 @@ declare function writeText(
 ): void
 
 /**
- * Append given string to the prompt as an assistant mesage.
+ * Append given string to the prompt as an assistant message.
 */
 declare function assistant(
     text: Awaitable<string>,
@@ -259,9 +259,7 @@ declare function defSchema(
  * @param options
  */
 declare function defImages(
-    files: ElementOrArray<
-        string | WorkspaceFile | Buffer | Blob | ArrayBuffer | ReadableStream
-    >,
+    files: ElementOrArray<BufferLike>,
     options?: DefImagesOptions
 ): void
 
@@ -328,3 +326,13 @@ declare function defChatParticipant(
     participant: ChatParticipantHandler,
     options?: ChatParticipantOptions
 ): void
+
+/**
+ * Transcribes audio to text.
+ * @param audio An audio file to transcribe.
+ * @param options transcription options, including the transcription model and input language
+ */
+declare function transcribe(
+    audio: BufferLike,
+    options?: TranscriptionOptions
+): Promise<TranscriptionResult>
diff --git a/packages/core/src/vectorsearch.ts b/packages/core/src/vectorsearch.ts
index 188a1e663..05a361e64 100644
--- a/packages/core/src/vectorsearch.ts
+++ b/packages/core/src/vectorsearch.ts
@@ -20,6 +20,7 @@ import { LanguageModelConfiguration } from "./server/messages"
 import { getConfigHeaders } from "./openai"
 import { logVerbose, trimTrailingSlash } from "./util"
 import { TraceOptions } from "./trace"
+import { CancellationOptions } from "./cancellation"
 
 /**
  * Represents the cache key for embeddings.
@@ -169,7 +170,8 @@ class OpenAIEmbeddings implements EmbeddingsModel {
 export async function vectorSearch(
     query: string,
     files: WorkspaceFile[],
-    options: VectorSearchOptions & { folderPath: string } & TraceOptions
+    options: VectorSearchOptions & { folderPath: string } & TraceOptions &
+        CancellationOptions
 ): Promise<WorkspaceFileWithScore[]> {
     const {
         topK,
@@ -177,6 +179,7 @@
         embeddingsModel = runtimeHost.modelAliases.embeddings.model,
         minScore = 0,
         trace,
+        cancellationToken,
     } = options
 
     trace?.startDetails(`🔍 embeddings`)
@@ -203,7 +206,7 @@
         throw new Error("No configuration found for vector search")
 
     // Pull the model
-    await runtimeHost.pullModel(info.model, { trace })
+    await runtimeHost.pullModel(info.model, { trace, cancellationToken })
     const embeddings = new OpenAIEmbeddings(info, configuration, { trace })
 
     // Create a local document index
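
Usage sketch (not part of the patch): a GenAIScript script could exercise the
new `transcribe` global as below. The audio file name and the `language` hint
are hypothetical placeholders; `script`, `def`, and `$` are the ambient globals
already declared in prompt_type.d.ts, and the `transcribe` model alias resolves
to openai:whisper-1 per llms.json.

    script({ title: "summarize audio" })

    // transcribe an audio file using the default "transcribe" model alias
    const { text, error, segments } = await transcribe("speech.mp3", {
        language: "en", // optional ISO-639-1 hint
    })
    if (error) throw new Error(error.message)

    // feed the transcript back into the prompt as context
    def("TRANSCRIPT", text)
    $`Summarize TRANSCRIPT in three bullet points.`

The `segments` array, when the provider returns it, carries per-segment start
times, which could be used for timestamped summaries.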