feat: ✨ add audio transcription support with Whisper-1
pelikhan committed Jan 9, 2025
1 parent 8c9456e commit 1206e96
Showing 9 changed files with 189 additions and 31 deletions.
11 changes: 11 additions & 0 deletions packages/core/src/chat.ts
@@ -140,11 +140,22 @@ export type PullModelFunction = (
options: TraceOptions & CancellationOptions
) => Promise<{ ok: boolean; error?: SerializedError }>

export type CreateTranscriptionRequest = {
file: BufferLike
} & TranscriptionOptions

export type TranscribeFunction = (
req: CreateTranscriptionRequest,
cfg: LanguageModelConfiguration,
options: TraceOptions & CancellationOptions
) => Promise<TranscriptionResult>

export interface LanguageModel {
id: string
completer: ChatCompletionHandler
listModels?: ListModelsFunction
pullModel?: PullModelFunction
transcribe?: TranscribeFunction
}

async function runToolCalls(
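
For context, a provider-side implementation of TranscribeFunction could wrap an OpenAI-style /audio/transcriptions endpoint. The sketch below is illustrative only: it assumes `req.file` has already been normalized to a Blob and that `cfg.base` and `cfg.token` carry an OpenAI-compatible endpoint and API key; it is not the driver shipped by this commit.

import type { TranscribeFunction } from "./chat"
import { checkCancelled } from "./cancellation"

// Illustrative sketch only; not the driver shipped in this commit.
// Assumes req.file was normalized to a Blob upstream and that cfg.base /
// cfg.token carry an OpenAI-compatible endpoint and API key.
const openAITranscribe: TranscribeFunction = async (req, cfg, options) => {
    checkCancelled(options.cancellationToken)
    const route = req.translate ? "translations" : "transcriptions"
    const body = new FormData()
    body.append("file", req.file as Blob, "audio.wav")
    body.append("model", cfg.model || "whisper-1")
    body.append("response_format", "verbose_json") // include segments
    if (req.language && !req.translate) body.append("language", req.language)
    const res = await fetch(`${cfg.base}/audio/${route}`, {
        method: "POST",
        headers: { Authorization: `Bearer ${cfg.token}` },
        body,
    })
    // the caller (see runpromptcontext.ts below) catches and serializes errors
    if (!res.ok) throw new Error(`transcription failed: ${res.statusText}`)
    const data = await res.json()
    return { text: data.text, segments: data.segments }
}
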
12 changes: 10 additions & 2 deletions packages/core/src/host.ts
@@ -2,7 +2,12 @@ import { CancellationOptions, CancellationToken } from "./cancellation"
import { LanguageModel } from "./chat"
import { Progress } from "./progress"
import { MarkdownTrace, TraceOptions } from "./trace"
import { AzureCredentialsType, LanguageModelConfiguration, Project, ResponseStatus } from "./server/messages"
import {
AzureCredentialsType,
LanguageModelConfiguration,
Project,
ResponseStatus,
} from "./server/messages"
import { HostConfiguration } from "./hostconfiguration"

// this is typically an instance of TextDecoder
@@ -135,7 +140,10 @@ export interface RuntimeHost extends Host {
azureToken: AzureTokenResolver
modelAliases: Readonly<ModelConfigurations>

pullModel(model: string, options?: TraceOptions): Promise<ResponseStatus>
pullModel(
model: string,
options?: TraceOptions & CancellationOptions
): Promise<ResponseStatus>

setModelAlias(
source: "env" | "cli" | "config" | "script",
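
The widened pullModel signature lets a host abort a long model download. A rough sketch of the token plumbing, with startPull and pollPull declared as hypothetical stand-ins for a provider-specific pull API:

import { CancellationOptions } from "./cancellation"
import { TraceOptions } from "./trace"
import type { ResponseStatus } from "./server/messages"

// Hypothetical provider API; declared only so the sketch type-checks.
declare function startPull(model: string): Promise<{ id: string }>
declare function pollPull(job: { id: string }): Promise<{ done: boolean }>

// Sketch: check the token between polls so a pull can be cancelled.
async function pullModelImpl(
    model: string,
    options?: TraceOptions & CancellationOptions
): Promise<ResponseStatus> {
    const { trace, cancellationToken } = options || {}
    const job = await startPull(model)
    while (!(await pollPull(job)).done) {
        if (cancellationToken?.isCancellationRequested)
            return { ok: false, error: { message: `pull of ${model} cancelled` } }
        await new Promise((r) => setTimeout(r, 500)) // back off between polls
    }
    trace?.itemValue("pulled", model)
    return { ok: true }
}
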
3 changes: 2 additions & 1 deletion packages/core/src/llms.json
@@ -11,7 +11,8 @@
"vision": "gpt-4o",
"embeddings": "text-embedding-3-small",
"reasoning": "o1",
"reasoning_small": "o1-mini"
"reasoning_small": "o1-mini",
"transcribe": "whisper-1"
}
},
{
3 changes: 2 additions & 1 deletion packages/core/src/promptdom.ts
@@ -37,6 +37,7 @@ import { startMcpServer } from "./mcp"
import { tryZodToJsonSchema } from "./zod"
import { GROQEvaluate } from "./groq"
import { trimNewlines } from "./unwrappers"
import { CancellationOptions } from "./cancellation"

// Definition of the PromptNode interface which is an essential part of the code structure.
export interface PromptNode extends ContextExpansionOptions {
@@ -1140,7 +1141,7 @@ async function deduplicatePromptNode(trace: MarkdownTrace, root: PromptNode) {
export async function renderPromptNode(
modelId: string,
node: PromptNode,
options?: ModelTemplateOptions & TraceOptions
options?: ModelTemplateOptions & TraceOptions & CancellationOptions
): Promise<PromptNodeRender> {
const { trace, flexTokens } = options || {}
const { encode: encoder } = await resolveTokenEncoder(modelId)
92 changes: 74 additions & 18 deletions packages/core/src/runpromptcontext.ts
@@ -625,6 +625,65 @@ export function createChatGenerationContext(
return p
}

const transcribe = async (
file: BufferLike,
options?: TranscriptionOptions
): Promise<TranscriptionResult> => {
const transcriptionTrace = trace.startTraceDetails("🎤 transcribe")
try {
const conn: ModelConnectionOptions = {
model: options?.model || "transcribe",
}
const { info, configuration } = await resolveModelConnectionInfo(
conn,
{
trace: transcriptionTrace,
cancellationToken,
token: true,
}
)
if (info.error) throw new Error(info.error)
if (!configuration) throw new Error("model configuration not found")
checkCancelled(cancellationToken)
const { ok } = await runtimeHost.pullModel(conn.model, {
trace: transcriptionTrace,
cancellationToken,
})
if (!ok) throw new Error(`failed to pull model ${conn.model}`)
checkCancelled(cancellationToken)
const { transcribe } = await resolveLanguageModel(
configuration.provider
)
if (!transcribe)
throw new Error("model driver not found for " + info.model)
const res = await transcribe(
{
file,
language: options?.language,
translate: options?.translate,
},
configuration,
{
trace: transcriptionTrace,
cancellationToken,
}
)
transcriptionTrace.fence(res.text, "markdown")
if (res.error) transcriptionTrace.error(errorMessage(res.error))
if (res.segments) transcriptionTrace.fence(res.segments, "yaml")
return res
} catch (e) {
logError(e)
transcriptionTrace.error(e)
return {
text: undefined,
error: serializeError(e),
} satisfies TranscriptionResult
} finally {
transcriptionTrace.endDetails()
}
}

const runPrompt = async (
generator: string | PromptGenerator,
runOptions?: PromptGeneratorOptions
@@ -639,11 +698,15 @@
genOptions.fallbackTools = undefined
genOptions.inner = true
genOptions.trace = runTrace
const { info } = await resolveModelConnectionInfo(genOptions, {
trace,
token: true,
})
const { info, configuration } = await resolveModelConnectionInfo(
genOptions,
{
trace: runTrace,
token: true,
}
)
if (info.error) throw new Error(info.error)
if (!configuration) throw new Error("model configuration not found")
genOptions.model = info.model
genOptions.stats = genOptions.stats.createChild(
genOptions.model,
Expand All @@ -652,6 +715,7 @@ export function createChatGenerationContext(

const { ok } = await runtimeHost.pullModel(genOptions.model, {
trace: runTrace,
cancellationToken,
})
if (!ok) throw new Error(`failed to pull model ${genOptions.model}`)

@@ -700,6 +764,7 @@
flexTokens: genOptions.flexTokens,
fenceFormat: genOptions.fenceFormat,
trace: runTrace,
cancellationToken,
})

schemas = scs
@@ -796,22 +861,12 @@
messages.push(toChatCompletionUserMessage("", images))

finalizeMessages(messages, { fileOutputs })
const connection = await resolveModelConnectionInfo(genOptions, {
trace: runTrace,
token: true,
})
checkCancelled(cancellationToken)
if (!connection.configuration)
throw new Error(
"missing model connection information for " +
genOptions.model
)
const { completer } = await resolveLanguageModel(
connection.configuration.provider
configuration.provider
)
checkCancelled(cancellationToken)
if (!completer)
throw new Error("model driver not found for " + connection.info)
throw new Error("model driver not found for " + info.model)
checkCancelled(cancellationToken)

const modelConcurrency =
options.modelConcurrency?.[genOptions.model] ??
@@ -822,7 +877,7 @@
)
const resp = await modelLimit(() =>
executeChatSession(
connection.configuration,
configuration,
cancellationToken,
messages,
tools,
@@ -875,6 +930,7 @@
defFileMerge,
prompt,
runPrompt,
transcribe,
})

return ctx
9 changes: 7 additions & 2 deletions packages/core/src/testhost.ts
@@ -29,8 +29,13 @@ import {
} from "node:path"
import { LanguageModel } from "./chat"
import { NotSupportedError } from "./error"
import { LanguageModelConfiguration, Project, ResponseStatus } from "./server/messages"
import {
LanguageModelConfiguration,
Project,
ResponseStatus,
} from "./server/messages"
import { defaultModelConfigurations } from "./llms"
import { CancellationOptions } from "./cancellation"

// Function to create a frozen object representing Node.js path methods
// This object provides utility methods for path manipulations
@@ -69,7 +74,7 @@ export class TestHost implements RuntimeHost {
}
async pullModel(
model: string,
options?: TraceOptions
options?: TraceOptions & CancellationOptions
): Promise<ResponseStatus> {
return { ok: true }
}
67 changes: 66 additions & 1 deletion packages/core/src/types/prompt_template.d.ts
@@ -195,6 +195,8 @@ type ModelVisionType = OptionsOrString<
"openai:gpt-4o" | "github:gpt-4o" | "azure:gpt-4o" | "azure:gpt-4o-mini"
>

type ModelTranscriptionType = OptionsOrString<"openai:whisper-1">

type ModelProviderType =
| "openai"
| "azure"
@@ -2679,14 +2681,76 @@ type McpServersConfig = Record<string, Omit<McpServerConfig, "id" | "options">>

type ZodTypeLike = { _def: any; safeParse: any; refine: any }

type BufferLike = string | WorkspaceFile | Buffer | Blob | ArrayBuffer | ReadableStream

interface TranscriptionOptions {
/**
* Model to use for transcription. By default uses the `transcribe` alias.
*/
model?: ModelTranscriptionType

/**
* Translate to English.
*/
translate?: boolean

/**
* Input language in ISO 639-1 format.
* @see https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes
*/
language?: string

/**
* The sampling temperature, between 0 and 1.
* Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
*/
temperature?: number
}

interface TranscriptionResult {
/**
* Complete transcription text
*/
text: string
/**
* Error if any
*/
error?: SerializedError
/**
* Individual segments
*/
segments?: {
/**
* The start time of the segment
*/
start: number
/**
* The transcribed text.
*/
text: string
/**
* Seek offset of the segment
*/
seek?: number
/**
* End time in seconds
*/
end?: number
/**
* Temperature used for the generation of the segment
*/
temperature?: number
}[]
}
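
Because each segment carries a start time in seconds, callers can build timestamped output directly from a result. A small illustrative helper (not part of this commit; the function name is made up):

// Render a TranscriptionResult's segments as "[mm:ss] text" lines.
function toTimestampedLines(res: TranscriptionResult): string {
    const fmt = (s: number) => {
        const mm = String(Math.floor(s / 60)).padStart(2, "0")
        const ss = String(Math.floor(s % 60)).padStart(2, "0")
        return `${mm}:${ss}`
    }
    return (res.segments || [])
        .map((seg) => `[${fmt(seg.start)}] ${seg.text.trim()}`)
        .join("\n")
}
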

interface ChatGenerationContext extends ChatTurnGenerationContext {
defSchema(
name: string,
schema: JSONSchema | ZodTypeLike,
options?: DefSchemaOptions
): string
defImages(
files: ElementOrArray<string | WorkspaceFile | Buffer | Blob | ArrayBuffer | ReadableStream>,
files: ElementOrArray<BufferLike>,
options?: DefImagesOptions
): void
defTool(
@@ -2729,6 +2793,7 @@ interface ChatGenerationContext extends ChatTurnGenerationContext {
): RunPromptResultPromiseWithOptions
defFileMerge(fn: FileMergeHandler): void
defOutputProcessor(fn: PromptOutputProcessorHandler): void
transcribe(audio: BufferLike, options?: TranscriptionOptions): Promise<TranscriptionResult>
}

interface GenerationOutput {
16 changes: 12 additions & 4 deletions packages/core/src/types/prompt_type.d.ts
@@ -39,7 +39,7 @@ declare function writeText(
): void

/**
* Append given string to the prompt as an assistant mesage.
* Append given string to the prompt as an assistant message.
*/
declare function assistant(
text: Awaitable<string>,
@@ -259,9 +259,7 @@ declare function defSchema(
* @param options
*/
declare function defImages(
files: ElementOrArray<
string | WorkspaceFile | Buffer | Blob | ArrayBuffer | ReadableStream
>,
files: ElementOrArray<BufferLike>,
options?: DefImagesOptions
): void

@@ -328,3 +326,13 @@ declare function defChatParticipant(
participant: ChatParticipantHandler,
options?: ChatParticipantOptions
): void

/**
* Transcribes audio to text.
* @param audio An audio file to transcribe.
* @param options
*/
declare function transcribe(
audio: BufferLike,
options?: TranscriptionOptions
): Promise<TranscriptionResult>
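
A script could then use the new global as follows. The file path and options are illustrative, and by default the model resolves through the `transcribe` alias (whisper-1 for the OpenAI provider, per llms.json above); failures surface on the result's error field rather than as a thrown exception:

// Hypothetical usage inside a GenAIScript script.
const { text, error } = await transcribe("audio/meeting.mp3", {
    language: "fr", // source audio language, ISO 639-1
})
if (error) throw new Error(error.message)
def("TRANSCRIPT", text) // make the transcript available to the prompt
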