Add services to better support Chinese conversation #476

Open · wants to merge 4 commits into main
Changes from 2 commits
126 changes: 126 additions & 0 deletions examples/foundational/07l-interruptible-Chinese.py
@@ -0,0 +1,126 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import aiohttp
import os
import sys

from pipecat.frames.frames import Frame, LLMMessagesFrame, MetricsFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.chattts import ChatTTSTTSService
from pipecat.services.doubao import DoubaoLLMService
from pipecat.services.ollama import OLLamaLLMService
from pipecat.services.openai import OpenAILLMService
from pipecat.services.tencentstt import TencentSTTService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer

from runner import configure

from loguru import logger

from dotenv import load_dotenv
load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


class MetricsLogger(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, MetricsFrame):
print(
f"!!! MetricsFrame: {frame}, ttfb: {frame.ttfb}, processing: {frame.processing}, tokens: {frame.tokens}, characters: {frame.characters}")
await self.push_frame(frame, direction)


async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)

transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
audio_in_enabled=True,
transcription_enabled=False,
vad_enabled=True,
vad_audio_passthrough=True,
vad_analyzer=SileroVADAnalyzer(),
)
)
stt = TencentSTTService()

# You need to set up a ChatTTS server and point CHATTTS_API_URL at it
tts = ChatTTSTTSService(
aiohttp_session=session,
api_url=os.getenv("CHATTTS_API_URL", "http://localhost:8555/generate")
)

# llm = OLLamaLLMService()

llm = DoubaoLLMService(
model=os.getenv("DOUBAO_MODEL_ID"),  # set DOUBAO_MODEL_ID in your environment
)

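# The Chinese system prompt below says: "You are an intelligent customer-service
# agent; answer user questions in a friendly way. Notes: use [uv_break] to mark
# pauses; where laughter fits, add [laugh], not [smile]; write numbers as Chinese
# numerals (for example 一, 三十一, 五百八十九), never Arabic digits." The bracket
# tags are ChatTTS prosody controls.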
messages = [
{
"role": "system",
"content": """你是一个智能客服, 友好的回答用户问题

注意:
用[uv_break]表示断句。
如果有需要笑的地方,请加[laugh],不要用[smile]。
如果有数字,输出中文数字,比如一,三十一,五百八十九。不要用阿拉伯数字。
"""
}
]

ml = MetricsLogger()
tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages)

pipeline = Pipeline([
transport.input(), # Transport user input
stt, # STT
tma_in, # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
tma_out # Assistant spoken responses
])

task = PipelineTask(pipeline, PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
))

@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append(
{"role": "user", "content": "向用户问好.不超过10个字"})  # "Greet the user; no more than 10 characters."
await task.queue_frames([LLMMessagesFrame(messages)])

runner = PipelineRunner()

await runner.run(task)


if __name__ == "__main__":
asyncio.run(main())
55 changes: 55 additions & 0 deletions src/pipecat/services/chattts.py
@@ -0,0 +1,55 @@


import aiohttp
from loguru import logger
from typing import AsyncGenerator

from pipecat.frames.frames import Frame, AudioRawFrame, TTSStartedFrame, TTSStoppedFrame, ErrorFrame
from pipecat.services.ai_services import TTSService


class ChatTTSTTSService(TTSService):
def __init__(self, *, api_url: str, aiohttp_session: aiohttp.ClientSession, **kwargs):
super().__init__(**kwargs)
self.api_url = api_url
self._aiohttp_session = aiohttp_session
logger.debug(f"ChatTTS server url: {self.api_url}")

async def set_model(self, model: str):
pass

async def set_voice(self, voice: str):
pass

async def set_language(self, language: str):
pass

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
try:

logger.debug(f"ChatTTS Generating TTS: [{text}]")
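# Assumed server contract: POST {"texts": [...]} to the ChatTTS endpoint;
# the response body streams raw 16-bit mono PCM at 24 kHz (see AudioRawFrame below).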
payload = {
"texts": [text],
}
await self.start_ttfb_metrics()
async with self._aiohttp_session.post(self.api_url, json=payload) as response:

if response.status != 200:
error_text = await response.text()
raise Exception(f"Error getting audio: {error_text}")
await self.start_tts_usage_metrics(text)
await self.push_frame(TTSStartedFrame())

async for chunk in response.content.iter_chunked(1024):
if len(chunk) > 0:
await self.stop_ttfb_metrics()
frame = AudioRawFrame(chunk, 24000, 1)  # raw PCM: 24 kHz sample rate, mono
yield frame

await self.push_frame(TTSStoppedFrame())

except aiohttp.ClientError as e:
yield ErrorFrame(f"Request to ChatTTS failed: {e}")

except Exception as e:
yield ErrorFrame(f"Unexpected error: {str(e)}")
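For local testing, a minimal stub server matching the contract that ChatTTSTTSService assumes (POST /generate with {"texts": [...]}, response streamed as raw 16-bit mono PCM at 24 kHz) could look like the sketch below. It emits a placeholder tone instead of real speech; a real deployment would call ChatTTS itself. Port 8555 matches the default CHATTTS_API_URL in the example. This stub is illustrative only and is not part of the PR.

import math
import struct

from aiohttp import web

SAMPLE_RATE = 24000


def placeholder_pcm(duration_s: float = 1.0) -> bytes:
    # A 440 Hz tone as stand-in audio so the pipeline can be exercised end to end.
    n = int(SAMPLE_RATE * duration_s)
    return b"".join(
        struct.pack("<h", int(12000 * math.sin(2 * math.pi * 440 * i / SAMPLE_RATE)))
        for i in range(n))


async def generate(request: web.Request) -> web.StreamResponse:
    body = await request.json()
    texts = body.get("texts", [])
    resp = web.StreamResponse()
    await resp.prepare(request)
    for _ in texts:
        await resp.write(placeholder_pcm())  # one chunk of PCM per input text
    await resp.write_eof()
    return resp


if __name__ == "__main__":
    app = web.Application()
    app.add_routes([web.post("/generate", generate)])
    web.run_app(app, port=8555)  # matches the default CHATTTS_API_URL in the example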
116 changes: 116 additions & 0 deletions src/pipecat/services/doubao.py
@@ -0,0 +1,116 @@

# Doubao is a Chinese LLM service from ByteDance.
# Install the Doubao (Volcengine Ark) SDK first:
# pip install 'volcengine-python-sdk[ark]'

import os
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
from typing import AsyncGenerator
from loguru import logger
import json
from pipecat.services.ai_services import LLMService
from pipecat.frames.frames import Frame, LLMMessagesFrame, TextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame, UserStartedSpeakingFrame
from pipecat.processors.frame_processor import FrameDirection
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame

load_dotenv()


class DoubaoLLMService(LLMService):
def __init__(self, *, model: str, **kwargs):
super().__init__(model=model, **kwargs)
self._client = Ark()
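# Note: Ark() is the synchronous SDK client; get_chat_completions iterates
# its stream on the event loop thread, so a long generation can block
# other pipeline tasks.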
self._model: str = model
# Detect whether the user starts a new utterance while the LLM is generating
self.llm_user_interrupt = False
# self._enable_metrics = True
# self._report_only_initial_ttfb = True

async def get_chat_completions(
self,
messages: list) -> AsyncGenerator[dict, None]:
try:
stream = self._client.chat.completions.create(
model=self._model,
messages=messages,
stream=True
)
for chunk in stream:
if not chunk.choices:
continue
content = chunk.choices[0].delta.content
yield {"choices": [{"delta": {"content": content}}]}
except Exception as e:
logger.error(f"Error in doubao API call: {e}")
yield {"choices": [{"delta": {"content": f"Error: {str(e)}"}}]}

def _convert_context_to_doubao_format(self, context: OpenAILLMContext) -> list:
doubao_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
}
]
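# Note: this default system message is always prepended; any system message
# already present in the context is appended after it by the loop below.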

messages_json = context.get_messages_json()
try:
messages = json.loads(messages_json)
if isinstance(messages, list):
for message in messages:
if isinstance(message, dict) and 'role' in message and 'content' in message:
doubao_messages.append({
"role": message['role'],
"content": message['content']
})
else:
logger.warning(f"Skipping invalid message format: {message}")
else:
logger.error(f"Unexpected type for messages after JSON parsing: {type(messages)}")
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON from get_messages_json(): {e}")

return doubao_messages

def can_generate_metrics(self) -> bool:
return self._enable_metrics

async def _process_context(self, context: OpenAILLMContext):
self.llm_user_interrupt = False
await self.start_ttfb_metrics()
messages = self._convert_context_to_doubao_format(context)
chunk_stream = self.get_chat_completions(messages)
try:
async for chunk in chunk_stream:
# Stop streaming if the user started a new utterance during generation
if self.llm_user_interrupt:
break
await self.stop_ttfb_metrics()

if len(chunk["choices"]) == 0:
continue

if "content" in chunk["choices"][0]["delta"]:
await self.push_frame(TextFrame(chunk["choices"][0]["delta"]["content"]))

except Exception as e:
logger.error(f"Error in processing context: {e}")

async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, UserStartedSpeakingFrame):
self.llm_user_interrupt = True
await super().process_frame(frame, direction)

context = None
if isinstance(frame, OpenAILLMContextFrame):
context: OpenAILLMContext = frame.context
if isinstance(frame, LLMMessagesFrame):
context = OpenAILLMContext.from_messages(frame.messages)
if context:
try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
await self._process_context(context)
await self.stop_processing_metrics()
await self.push_frame(LLMFullResponseEndFrame())
except Exception as e:
logger.error(f"Error in processing context: {e}")