Skip to content

Commit

Permalink
stream synthesis (still slightly broken)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbsp committed Dec 23, 2024
1 parent d2c60e0 commit 875c1bc
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 184 deletions.
2 changes: 1 addition & 1 deletion plugins/cartesia/src/tts.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import { describe } from 'vitest';
import { TTS } from './tts.js';

describe('Cartesia', async () => {
await tts(new TTS(), new STT(), { streaming: false });
await tts(new TTS(), new STT());
});
295 changes: 113 additions & 182 deletions plugins/cartesia/src/tts.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { AsyncIterableQueue, AudioByteStream, log, tokenize, tts } from '@livekit/agents';
import { AudioFrame } from '@livekit/rtc-node';
import { AudioByteStream, log, tokenize, tts } from '@livekit/agents';
import type { AudioFrame } from '@livekit/rtc-node';
import { randomUUID } from 'node:crypto';
import { request } from 'node:https';
import { URL } from 'node:url';
import { type RawData, WebSocket } from 'ws';
import { WebSocket } from 'ws';
import {
TTSDefaultVoiceId,
type TTSEncoding,
Expand Down Expand Up @@ -69,8 +68,7 @@ export class TTS extends tts.TTS {
}

stream(): tts.SynthesizeStream {
throw new Error();
// return new SynthesizeStream(this, this.#opts);
return new SynthesizeStream(this, this.#opts);
}
}

Expand Down Expand Up @@ -134,182 +132,115 @@ export class ChunkedStream extends tts.ChunkedStream {
}
}

// export class SynthesizeStream extends tts.SynthesizeStream {
// #opts: TTSOptions;
// #logger = log();
// label = 'cartesia.SynthesizeStream';
// readonly streamURL: URL;

// constructor(tts: TTS, opts: TTSOptions) {
// super(tts);
// this.#opts = opts;
// this.closed = false;

// // add trailing slash to URL if needed
// const baseURL = opts.baseURL + (opts.baseURL.endsWith('/') ? '' : '/');

// this.streamURL = new URL(`text-to-speech/${opts.voice.id}/stream-input`, baseURL);
// const params = {
// model_id: opts.modelID,
// output_format: opts.encoding,
// optimize_streaming_latency: `${opts.streamingLatency}`,
// enable_ssml_parsing: `${opts.enableSsmlParsing}`,
// };
// Object.entries(params).forEach(([k, v]) => this.streamURL.searchParams.append(k, v));
// this.streamURL.protocol = this.streamURL.protocol.replace('http', 'ws');

// this.#run();
// }

// async #run() {
// const segments = new AsyncIterableQueue<tokenize.WordStream>();

// const tokenizeInput = async () => {
// let stream: tokenize.WordStream | null = null;
// for await (const text of this.input) {
// if (text === SynthesizeStream.FLUSH_SENTINEL) {
// stream?.endInput();
// stream = null;
// } else {
// if (!stream) {
// stream = this.#opts.wordTokenizer.stream();
// segments.put(stream);
// }
// stream.pushText(text);
// }
// }
// segments.close();
// };

// const runStream = async () => {
// for await (const stream of segments) {
// await this.#runWS(stream);
// this.queue.put(SynthesizeStream.END_OF_STREAM);
// }
// };

// await Promise.all([tokenizeInput(), runStream()]);
// this.close();
// }

// async #runWS(stream: tokenize.WordStream, maxRetry = 3) {
// let retries = 0;
// let ws: WebSocket;
// while (true) {
// ws = new WebSocket(this.streamURL, {
// headers: { [AUTHORIZATION_HEADER]: this.#opts.apiKey },
// });

// try {
// await new Promise((resolve, reject) => {
// ws.on('open', resolve);
// ws.on('error', (error) => reject(error));
// ws.on('close', (code) => reject(`WebSocket returned ${code}`));
// });
// break;
// } catch (e) {
// if (retries >= maxRetry) {
// throw new Error(`failed to connect to ElevenLabs after ${retries} attempts: ${e}`);
// }

// const delay = Math.min(retries * 5, 5);
// retries++;

// this.#logger.warn(
// `failed to connect to ElevenLabs, retrying in ${delay} seconds: ${e} (${retries}/${maxRetry})`,
// );
// await new Promise((resolve) => setTimeout(resolve, delay * 1000));
// }
// }

// const requestId = randomUUID();
// const segmentId = randomUUID();

// ws.send(
// JSON.stringify({
// text: ' ',
// voice_settings: this.#opts.voice.settings,
// try_trigger_generation: true,
// chunk_length_schedule: this.#opts.chunkLengthSchedule,
// }),
// );
// let eosSent = false;

// const sendTask = async () => {
// let xmlContent: string[] = [];
// for await (const data of stream) {
// let text = data.token;

// if ((this.#opts.enableSsmlParsing && text.startsWith('<phoneme')) || xmlContent.length) {
// xmlContent.push(text);
// if (text.indexOf('</phoneme>') !== -1) {
// text = xmlContent.join(' ');
// xmlContent = [];
// } else {
// continue;
// }
// }

// ws.send(JSON.stringify({ text: text + ' ', try_trigger_generation: false }));
// }

// if (xmlContent.length) {
// this.#logger.warn('ElevenLabs stream ended with incomplete XML content');
// }

// ws.send(JSON.stringify({ text: '' }));
// eosSent = true;
// };

// let lastFrame: AudioFrame | undefined;
// const sendLastFrame = (segmentId: string, final: boolean) => {
// if (lastFrame) {
// this.queue.put({ requestId, segmentId, frame: lastFrame, final });
// lastFrame = undefined;
// }
// };

// const listenTask = async () => {
// while (!this.closed) {
// try {
// await new Promise<RawData>((resolve, reject) => {
// ws.removeAllListeners();
// ws.on('message', (data) => resolve(data));
// ws.on('close', (code, reason) => {
// if (!eosSent) {
// this.#logger.error(`WebSocket closed with code ${code}: ${reason}`);
// }
// reject();
// });
// }).then((msg) => {
// const json = JSON.parse(msg.toString());
// if ('audio' in json) {
// const data = new Int16Array(Buffer.from(json.audio, 'base64').buffer);
// const frame = new AudioFrame(
// data,
// sampleRateFromFormat(this.#opts.encoding),
// 1,
// data.length,
// );
// sendLastFrame(segmentId, false);
// lastFrame = frame;
// } else if ('isFinal' in json) {
// sendLastFrame(segmentId, true);
// }
// });
// } catch {
// break;
// }
// }
// };

// await Promise.all([sendTask(), listenTask()]);
// }
// }

const sampleRateFromFormat = (encoding: TTSEncoding): number => {
return Number(encoding.split('_')[1]);
};
export class SynthesizeStream extends tts.SynthesizeStream {
#opts: TTSOptions;
#logger = log();
#tokenizer = new tokenize.basic.SentenceTokenizer(undefined, BUFFERED_WORDS_COUNT).stream();
label = 'cartesia.SynthesizeStream';

constructor(tts: TTS, opts: TTSOptions) {
super(tts);
this.#opts = opts;
this.#run();
}

async #run() {
const requestId = randomUUID();
let closing = false;

const sentenceStreamTask = async (ws: WebSocket) => {
const packet = toCartesiaOptions(this.#opts);
for await (const event of this.#tokenizer) {
ws.send(
JSON.stringify({
...packet,
context_id: requestId,
transcript: event.token + ' ',
continue: true,
}),
);
}

ws.send(
JSON.stringify({
...packet,
context_id: requestId,
transcript: ' ',
continue: false,
}),
);
};

const inputTask = async () => {
for await (const data of this.input) {
if (data === SynthesizeStream.FLUSH_SENTINEL) {
this.#tokenizer.flush();
continue;
}
this.#tokenizer.pushText(data);
}
this.#tokenizer.endInput();
this.#tokenizer.close();
};

const recvTask = async (ws: WebSocket) => {
const bstream = new AudioByteStream(this.#opts.sampleRate, NUM_CHANNELS);

let lastFrame: AudioFrame | undefined;
const sendLastFrame = (segmentId: string, final: boolean) => {
if (lastFrame) {
this.queue.put({ requestId, segmentId, frame: lastFrame, final });
lastFrame = undefined;
}
};

ws.on('message', (data) => {
const json = JSON.parse(data.toString());
const segmentId = json.context_id;
if ('data' in json) {
const data = new Int8Array(Buffer.from(json.data, 'base64').buffer);
for (const frame of bstream.write(data)) {
sendLastFrame(segmentId, false);
lastFrame = frame;
}
} else if ('done' in json) {
for (const frame of bstream.flush()) {
sendLastFrame(segmentId, false);
lastFrame = frame;
}
sendLastFrame(segmentId, true);

if (segmentId === requestId) {
closing = true;
ws.close();
return;
}
}
});
ws.on('close', (code, reason) => {
if (!closing) {
this.#logger.error(`WebSocket closed with code ${code}: ${reason}`);
}
ws.removeAllListeners();
});
};

const url = `wss://api.cartesia.ai/tts/websocket?api_key=${this.#opts.apiKey}&cartesia_version=${VERSION}`;
const ws = new WebSocket(url);

try {
await new Promise((resolve, reject) => {
ws.on('open', resolve);
ws.on('error', (error) => reject(error));
ws.on('close', (code) => reject(`WebSocket returned ${code}`));
});

await Promise.all([inputTask(), sentenceStreamTask(ws), recvTask(ws)]);
} catch (e) {
throw new Error(`failed to connect to Cartesia: ${e}`);
}
}
}

const toCartesiaOptions = (opts: TTSOptions): { [id: string]: unknown } => {
const voice: { [id: string]: unknown } = {};
Expand Down
2 changes: 1 addition & 1 deletion plugins/test/src/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export const tts = async (
await validate(frames, stt, TEXT, 0.2);
});

it('should properly stream synthesize tests', async () => {
it('should properly stream synthesize text', async () => {
let stream: ttslib.SynthesizeStream;
if (supports.streaming) {
stream = tts.stream();
Expand Down

0 comments on commit 875c1bc

Please sign in to comment.