From 875c1bca2ddfd4af30eb29db6699c62e2985489d Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Fri, 20 Dec 2024 11:42:00 +0200 Subject: [PATCH] stream synthesis (still slightly broken) --- plugins/cartesia/src/tts.test.ts | 2 +- plugins/cartesia/src/tts.ts | 295 ++++++++++++------------------- plugins/test/src/tts.ts | 2 +- 3 files changed, 115 insertions(+), 184 deletions(-) diff --git a/plugins/cartesia/src/tts.test.ts b/plugins/cartesia/src/tts.test.ts index d811790e..d7238045 100644 --- a/plugins/cartesia/src/tts.test.ts +++ b/plugins/cartesia/src/tts.test.ts @@ -7,5 +7,5 @@ import { describe } from 'vitest'; import { TTS } from './tts.js'; describe('Cartesia', async () => { - await tts(new TTS(), new STT(), { streaming: false }); + await tts(new TTS(), new STT()); }); diff --git a/plugins/cartesia/src/tts.ts b/plugins/cartesia/src/tts.ts index bbe6140f..83fee1ef 100644 --- a/plugins/cartesia/src/tts.ts +++ b/plugins/cartesia/src/tts.ts @@ -1,12 +1,11 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { AsyncIterableQueue, AudioByteStream, log, tokenize, tts } from '@livekit/agents'; -import { AudioFrame } from '@livekit/rtc-node'; +import { AudioByteStream, log, tokenize, tts } from '@livekit/agents'; +import type { AudioFrame } from '@livekit/rtc-node'; import { randomUUID } from 'node:crypto'; import { request } from 'node:https'; -import { URL } from 'node:url'; -import { type RawData, WebSocket } from 'ws'; +import { WebSocket } from 'ws'; import { TTSDefaultVoiceId, type TTSEncoding, @@ -69,8 +68,7 @@ export class TTS extends tts.TTS { } stream(): tts.SynthesizeStream { - throw new Error(); - // return new SynthesizeStream(this, this.#opts); + return new SynthesizeStream(this, this.#opts); } } @@ -134,182 +132,115 @@ export class ChunkedStream extends tts.ChunkedStream { } } -// export class SynthesizeStream extends tts.SynthesizeStream { -// #opts: TTSOptions; -// #logger = log(); -// label = 'cartesia.SynthesizeStream'; -// readonly streamURL: URL; - -// constructor(tts: TTS, opts: TTSOptions) { -// super(tts); -// this.#opts = opts; -// this.closed = false; - -// // add trailing slash to URL if needed -// const baseURL = opts.baseURL + (opts.baseURL.endsWith('/') ? '' : '/'); - -// this.streamURL = new URL(`text-to-speech/${opts.voice.id}/stream-input`, baseURL); -// const params = { -// model_id: opts.modelID, -// output_format: opts.encoding, -// optimize_streaming_latency: `${opts.streamingLatency}`, -// enable_ssml_parsing: `${opts.enableSsmlParsing}`, -// }; -// Object.entries(params).forEach(([k, v]) => this.streamURL.searchParams.append(k, v)); -// this.streamURL.protocol = this.streamURL.protocol.replace('http', 'ws'); - -// this.#run(); -// } - -// async #run() { -// const segments = new AsyncIterableQueue(); - -// const tokenizeInput = async () => { -// let stream: tokenize.WordStream | null = null; -// for await (const text of this.input) { -// if (text === SynthesizeStream.FLUSH_SENTINEL) { -// stream?.endInput(); -// stream = null; -// } else { -// if (!stream) { -// stream = this.#opts.wordTokenizer.stream(); -// segments.put(stream); -// } -// stream.pushText(text); -// } -// } -// segments.close(); -// }; - -// const runStream = async () => { -// for await (const stream of segments) { -// await this.#runWS(stream); -// this.queue.put(SynthesizeStream.END_OF_STREAM); -// } -// }; - -// await Promise.all([tokenizeInput(), runStream()]); -// this.close(); -// } - -// async #runWS(stream: tokenize.WordStream, maxRetry = 3) { -// let retries = 0; -// let ws: WebSocket; -// while (true) { -// ws = new WebSocket(this.streamURL, { -// headers: { [AUTHORIZATION_HEADER]: this.#opts.apiKey }, -// }); - -// try { -// await new Promise((resolve, reject) => { -// ws.on('open', resolve); -// ws.on('error', (error) => reject(error)); -// ws.on('close', (code) => reject(`WebSocket returned ${code}`)); -// }); -// break; -// } catch (e) { -// if (retries >= maxRetry) { -// throw new Error(`failed to connect to ElevenLabs after ${retries} attempts: ${e}`); -// } - -// const delay = Math.min(retries * 5, 5); -// retries++; - -// this.#logger.warn( -// `failed to connect to ElevenLabs, retrying in ${delay} seconds: ${e} (${retries}/${maxRetry})`, -// ); -// await new Promise((resolve) => setTimeout(resolve, delay * 1000)); -// } -// } - -// const requestId = randomUUID(); -// const segmentId = randomUUID(); - -// ws.send( -// JSON.stringify({ -// text: ' ', -// voice_settings: this.#opts.voice.settings, -// try_trigger_generation: true, -// chunk_length_schedule: this.#opts.chunkLengthSchedule, -// }), -// ); -// let eosSent = false; - -// const sendTask = async () => { -// let xmlContent: string[] = []; -// for await (const data of stream) { -// let text = data.token; - -// if ((this.#opts.enableSsmlParsing && text.startsWith('') !== -1) { -// text = xmlContent.join(' '); -// xmlContent = []; -// } else { -// continue; -// } -// } - -// ws.send(JSON.stringify({ text: text + ' ', try_trigger_generation: false })); -// } - -// if (xmlContent.length) { -// this.#logger.warn('ElevenLabs stream ended with incomplete XML content'); -// } - -// ws.send(JSON.stringify({ text: '' })); -// eosSent = true; -// }; - -// let lastFrame: AudioFrame | undefined; -// const sendLastFrame = (segmentId: string, final: boolean) => { -// if (lastFrame) { -// this.queue.put({ requestId, segmentId, frame: lastFrame, final }); -// lastFrame = undefined; -// } -// }; - -// const listenTask = async () => { -// while (!this.closed) { -// try { -// await new Promise((resolve, reject) => { -// ws.removeAllListeners(); -// ws.on('message', (data) => resolve(data)); -// ws.on('close', (code, reason) => { -// if (!eosSent) { -// this.#logger.error(`WebSocket closed with code ${code}: ${reason}`); -// } -// reject(); -// }); -// }).then((msg) => { -// const json = JSON.parse(msg.toString()); -// if ('audio' in json) { -// const data = new Int16Array(Buffer.from(json.audio, 'base64').buffer); -// const frame = new AudioFrame( -// data, -// sampleRateFromFormat(this.#opts.encoding), -// 1, -// data.length, -// ); -// sendLastFrame(segmentId, false); -// lastFrame = frame; -// } else if ('isFinal' in json) { -// sendLastFrame(segmentId, true); -// } -// }); -// } catch { -// break; -// } -// } -// }; - -// await Promise.all([sendTask(), listenTask()]); -// } -// } - -const sampleRateFromFormat = (encoding: TTSEncoding): number => { - return Number(encoding.split('_')[1]); -}; +export class SynthesizeStream extends tts.SynthesizeStream { + #opts: TTSOptions; + #logger = log(); + #tokenizer = new tokenize.basic.SentenceTokenizer(undefined, BUFFERED_WORDS_COUNT).stream(); + label = 'cartesia.SynthesizeStream'; + + constructor(tts: TTS, opts: TTSOptions) { + super(tts); + this.#opts = opts; + this.#run(); + } + + async #run() { + const requestId = randomUUID(); + let closing = false; + + const sentenceStreamTask = async (ws: WebSocket) => { + const packet = toCartesiaOptions(this.#opts); + for await (const event of this.#tokenizer) { + ws.send( + JSON.stringify({ + ...packet, + context_id: requestId, + transcript: event.token + ' ', + continue: true, + }), + ); + } + + ws.send( + JSON.stringify({ + ...packet, + context_id: requestId, + transcript: ' ', + continue: false, + }), + ); + }; + + const inputTask = async () => { + for await (const data of this.input) { + if (data === SynthesizeStream.FLUSH_SENTINEL) { + this.#tokenizer.flush(); + continue; + } + this.#tokenizer.pushText(data); + } + this.#tokenizer.endInput(); + this.#tokenizer.close(); + }; + + const recvTask = async (ws: WebSocket) => { + const bstream = new AudioByteStream(this.#opts.sampleRate, NUM_CHANNELS); + + let lastFrame: AudioFrame | undefined; + const sendLastFrame = (segmentId: string, final: boolean) => { + if (lastFrame) { + this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + lastFrame = undefined; + } + }; + + ws.on('message', (data) => { + const json = JSON.parse(data.toString()); + const segmentId = json.context_id; + if ('data' in json) { + const data = new Int8Array(Buffer.from(json.data, 'base64').buffer); + for (const frame of bstream.write(data)) { + sendLastFrame(segmentId, false); + lastFrame = frame; + } + } else if ('done' in json) { + for (const frame of bstream.flush()) { + sendLastFrame(segmentId, false); + lastFrame = frame; + } + sendLastFrame(segmentId, true); + + if (segmentId === requestId) { + closing = true; + ws.close(); + return; + } + } + }); + ws.on('close', (code, reason) => { + if (!closing) { + this.#logger.error(`WebSocket closed with code ${code}: ${reason}`); + } + ws.removeAllListeners(); + }); + }; + + const url = `wss://api.cartesia.ai/tts/websocket?api_key=${this.#opts.apiKey}&cartesia_version=${VERSION}`; + const ws = new WebSocket(url); + + try { + await new Promise((resolve, reject) => { + ws.on('open', resolve); + ws.on('error', (error) => reject(error)); + ws.on('close', (code) => reject(`WebSocket returned ${code}`)); + }); + + await Promise.all([inputTask(), sentenceStreamTask(ws), recvTask(ws)]); + } catch (e) { + throw new Error(`failed to connect to Cartesia: ${e}`); + } + } +} const toCartesiaOptions = (opts: TTSOptions): { [id: string]: unknown } => { const voice: { [id: string]: unknown } = {}; diff --git a/plugins/test/src/tts.ts b/plugins/test/src/tts.ts index 6a6324f7..0cf697a7 100644 --- a/plugins/test/src/tts.ts +++ b/plugins/test/src/tts.ts @@ -35,7 +35,7 @@ export const tts = async ( await validate(frames, stt, TEXT, 0.2); }); - it('should properly stream synthesize tests', async () => { + it('should properly stream synthesize text', async () => { let stream: ttslib.SynthesizeStream; if (supports.streaming) { stream = tts.stream();