diff --git a/examples/foundational/07l-interruptible-Chinese.py b/examples/foundational/07l-interruptible-Chinese.py
new file mode 100644
index 000000000..33b6bffeb
--- /dev/null
+++ b/examples/foundational/07l-interruptible-Chinese.py
@@ -0,0 +1,119 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import aiohttp
+import os
+import sys
+
+from pipecat.frames.frames import Frame, LLMMessagesFrame, MetricsFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineParams, PipelineTask
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantResponseAggregator, LLMUserResponseAggregator)
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+from pipecat.services.chattts import ChatTTSTTSService
+from pipecat.services.doubao import DoubaoLLMService
+from pipecat.services.tencent import TencentSTTService
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+from pipecat.vad.silero import SileroVADAnalyzer
+
+from runner import configure
+
+from loguru import logger
+
+from dotenv import load_dotenv
+
+load_dotenv(override=True)
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+class MetricsLogger(FrameProcessor):
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        if isinstance(frame, MetricsFrame):
+            print(f"!!! MetricsFrame: {frame}, ttfb: {frame.ttfb}, "
+                  f"processing: {frame.processing}, tokens: {frame.tokens}, "
+                  f"characters: {frame.characters}")
+        await self.push_frame(frame, direction)
+
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        (room_url, token) = await configure(session)
+
+        transport = DailyTransport(
+            room_url,
+            token,
+            "Respond bot",
+            DailyParams(
+                audio_out_enabled=True,
+                audio_in_enabled=True,
+                # Daily transcription is off; TencentSTTService does the STT.
+                transcription_enabled=False,
+                vad_enabled=True,
+                vad_audio_passthrough=True,
+                vad_analyzer=SileroVADAnalyzer(),
+            )
+        )
+
+        stt = TencentSTTService()
+
+        # You need a running ChatTTS server for this service; see
+        # src/pipecat/services/chattts.py.
+        tts = ChatTTSTTSService(
+            aiohttp_session=session,
+            api_url=os.getenv("CHATTTS_API_URL", "http://localhost:8555/generate")
+        )
+
+        llm = DoubaoLLMService(
+            model=os.getenv("DOUBAO_MODEL_ID"),
+        )
+
+        messages = [
+            {
+                "role": "system",
+                # "You are a customer-service agent; answer users' questions
+                # in a friendly way."
+                "content": "你是一个智能客服, 友好的回答用户问题"
+            }
+        ]
+
+        ml = MetricsLogger()
+        tma_in = LLMUserResponseAggregator(messages)
+        tma_out = LLMAssistantResponseAggregator(messages)
+
+        pipeline = Pipeline([
+            transport.input(),   # Transport user input
+            stt,                 # STT
+            tma_in,              # User responses
+            llm,                 # LLM
+            tts,                 # TTS
+            ml,                  # Metrics logging
+            transport.output(),  # Transport bot output
+            tma_out              # Assistant spoken responses
+        ])
+
+        task = PipelineTask(pipeline, PipelineParams(
+            allow_interruptions=True,
+            enable_metrics=True,
+            enable_usage_metrics=True,
+            report_only_initial_ttfb=True,
+        ))
+
+        @transport.event_handler("on_first_participant_joined")
+        async def on_first_participant_joined(transport, participant):
+            # Kick off the conversation. ("Greet the user in at most
+            # 10 characters.")
+            messages.append(
+                {"role": "user", "content": "向用户问好.不超过10个字"})
+            await task.queue_frames([LLMMessagesFrame(messages)])
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
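+
+# A minimal .env sketch for this example. CHATTTS_API_URL and DOUBAO_MODEL_ID
+# are read above; the TENCENT_* names are read by pipecat.services.tencent;
+# DAILY_SAMPLE_ROOM_URL assumes runner.configure() matches the other
+# foundational examples. Values are placeholders:
+#
+#   DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom
+#   CHATTTS_API_URL=http://localhost:8555/generate
+#   DOUBAO_MODEL_ID=<your Ark endpoint id>
+#   TENCENT_APPID=<appid>
+#   TENCENT_SECRET_ID=<secret id>
+#   TENCENT_SECRET_KEY=<secret key>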
diff --git a/src/pipecat/services/chattts.py b/src/pipecat/services/chattts.py
new file mode 100644
index 000000000..e20fb5c36
--- /dev/null
+++ b/src/pipecat/services/chattts.py
@@ -0,0 +1,71 @@
+# ChatTTS TTS service. Streams raw audio from a self-hosted ChatTTS HTTP
+# server (see api_url below).
+
+from typing import AsyncGenerator
+
+import aiohttp
+from loguru import logger
+
+from pipecat.frames.frames import (
+    ErrorFrame,
+    Frame,
+    OutputAudioRawFrame,
+    TTSStartedFrame,
+    TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+
+
+class ChatTTSTTSService(TTSService):
+    def __init__(
+        self,
+        *,
+        api_url: str,
+        aiohttp_session: aiohttp.ClientSession,
+        sample_rate: int = 24000,
+        num_channels: int = 1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.api_url = api_url
+        self._aiohttp_session = aiohttp_session
+        self._settings = {
+            "sample_rate": sample_rate,
+            "num_channels": num_channels,
+        }
+
+    async def set_model(self, model: str):
+        pass
+
+    async def set_voice(self, voice: str):
+        pass
+
+    async def set_language(self, language: str):
+        pass
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        logger.debug(f"ChatTTS generating TTS: [{text}]")
+        payload = {
+            "texts": [text],
+        }
+        try:
+            await self.start_ttfb_metrics()
+            async with self._aiohttp_session.post(self.api_url, json=payload) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    yield ErrorFrame(f"Error getting audio: {error_text}")
+                    return
+
+                await self.start_tts_usage_metrics(text)
+                yield TTSStartedFrame()
+
+                async for chunk in response.content.iter_chunked(1024):
+                    if len(chunk) > 0:
+                        await self.stop_ttfb_metrics()
+                        yield OutputAudioRawFrame(
+                            audio=chunk,
+                            sample_rate=self._settings["sample_rate"],
+                            num_channels=self._settings["num_channels"],
+                        )
+
+                yield TTSStoppedFrame()
+
+        except aiohttp.ClientError as e:
+            yield ErrorFrame(f"Request to ChatTTS failed: {e}")
+        except Exception as e:
+            yield ErrorFrame(f"Unexpected error: {e}")
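+
+# Server contract this service assumes (hedged; adapt to your ChatTTS
+# deployment): POST {"texts": ["..."]} to api_url, and the server streams
+# raw PCM (presumably 16-bit, matching sample_rate/num_channels) in the
+# response body, which is forwarded chunk-by-chunk as OutputAudioRawFrames:
+#
+#   tts = ChatTTSTTSService(
+#       aiohttp_session=session,
+#       api_url="http://localhost:8555/generate",
+#   )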
diff --git a/src/pipecat/services/doubao.py b/src/pipecat/services/doubao.py
new file mode 100644
index 000000000..85f7fba50
--- /dev/null
+++ b/src/pipecat/services/doubao.py
@@ -0,0 +1,102 @@
+# Doubao is a Chinese LLM service from ByteDance.
+# Install the SDK with:
+#   pip install 'volcengine-python-sdk[ark]'
+
+import json
+from typing import AsyncGenerator
+
+from dotenv import load_dotenv
+from loguru import logger
+from volcenginesdkarkruntime import Ark
+
+from pipecat.frames.frames import (
+    Frame,
+    LLMFullResponseEndFrame,
+    LLMFullResponseStartFrame,
+    LLMMessagesFrame,
+    TextFrame,
+)
+from pipecat.processors.aggregators.openai_llm_context import (
+    OpenAILLMContext,
+    OpenAILLMContextFrame,
+)
+from pipecat.services.ai_services import LLMService
+
+load_dotenv()
+
+
+class DoubaoLLMService(LLMService):
+    def __init__(self, *, model: str, **kwargs):
+        super().__init__(model=model, **kwargs)
+        self._client = Ark()
+        self._model: str = model
+
+    async def get_chat_completions(
+            self, messages: list) -> AsyncGenerator[dict, None]:
+        try:
+            # NOTE: the Ark client is synchronous, so this stream blocks the
+            # event loop between chunks.
+            stream = self._client.chat.completions.create(
+                model=self._model,
+                messages=messages,
+                stream=True
+            )
+            for chunk in stream:
+                if not chunk.choices:
+                    continue
+                content = chunk.choices[0].delta.content
+                yield {"choices": [{"delta": {"content": content}}]}
+        except Exception as e:
+            logger.error(f"Error in Doubao API call: {e}")
+            yield {"choices": [{"delta": {"content": f"Error: {str(e)}"}}]}
+
+    def _convert_context_to_doubao_format(self, context: OpenAILLMContext) -> list:
+        doubao_messages = []
+        messages_json = context.get_messages_json()
+        try:
+            messages = json.loads(messages_json)
+            if isinstance(messages, list):
+                for message in messages:
+                    if isinstance(message, dict) and 'role' in message and 'content' in message:
+                        doubao_messages.append({
+                            "role": message['role'],
+                            "content": message['content']
+                        })
+                    else:
+                        logger.warning(f"Skipping invalid message format: {message}")
+            else:
+                logger.error(f"Unexpected type for messages after JSON parsing: {type(messages)}")
+        except json.JSONDecodeError as e:
+            logger.error(f"Error decoding JSON from get_messages_json(): {e}")
+        return doubao_messages
+
+    def can_generate_metrics(self) -> bool:
+        return self._enable_metrics
+
+    async def _process_context(self, context: OpenAILLMContext):
+        await self.start_ttfb_metrics()
+        messages = self._convert_context_to_doubao_format(context)
+        try:
+            async for chunk in self.get_chat_completions(messages):
+                await self.stop_ttfb_metrics()
+
+                if len(chunk["choices"]) == 0:
+                    continue
+
+                if "content" in chunk["choices"][0]["delta"]:
+                    await self.push_frame(TextFrame(chunk["choices"][0]["delta"]["content"]))
+        except Exception as e:
+            logger.error(f"Error in processing context: {e}")
+
+    async def process_frame(self, frame: Frame, direction):
+        await super().process_frame(frame, direction)
+
+        context = None
+        if isinstance(frame, OpenAILLMContextFrame):
+            context = frame.context
+        elif isinstance(frame, LLMMessagesFrame):
+            context = OpenAILLMContext.from_messages(frame.messages)
+        else:
+            # Forward frames we don't handle so the rest of the pipeline
+            # keeps flowing.
+            await self.push_frame(frame, direction)
+
+        if context:
+            try:
+                await self.push_frame(LLMFullResponseStartFrame())
+                await self.start_processing_metrics()
+                await self._process_context(context)
+                await self.stop_processing_metrics()
+                await self.push_frame(LLMFullResponseEndFrame())
+            except Exception as e:
+                logger.error(f"Error in processing context: {e}")
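+
+# Usage sketch (hedged): Ark() with no arguments reads its API key from the
+# environment -- commonly ARK_API_KEY; check your volcengine SDK version --
+# and `model` is the Ark endpoint/model id:
+#
+#   llm = DoubaoLLMService(model=os.getenv("DOUBAO_MODEL_ID"))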
diff --git a/src/pipecat/services/helpers/tencent/asr/__init__.py b/src/pipecat/services/helpers/tencent/asr/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/pipecat/services/helpers/tencent/asr/flash_recognizer.py b/src/pipecat/services/helpers/tencent/asr/flash_recognizer.py
new file mode 100644
index 000000000..418ec3342
--- /dev/null
+++ b/src/pipecat/services/helpers/tencent/asr/flash_recognizer.py
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+import base64
+import hashlib
+import hmac
+import time
+
+import requests
+
+# Used for Tencent's "flash" (one-shot, recorded audio) recognition API.
+
+
+class FlashRecognitionRequest:
+    def __init__(self, engine_type):
+        self.engine_type = engine_type
+        self.speaker_diarization = 0
+        self.hotword_id = ""
+        self.hotword_list = ""
+        self.input_sample_rate = 0
+        self.customization_id = ""
+        self.filter_dirty = 0
+        self.filter_modal = 0
+        self.filter_punc = 0
+        self.convert_num_mode = 1
+        self.word_info = 0
+        self.voice_format = ""
+        self.first_channel_only = 1
+        self.reinforce_hotword = 0
+        self.sentence_max_length = 0
+
+    def set_first_channel_only(self, first_channel_only):
+        self.first_channel_only = first_channel_only
+
+    def set_speaker_diarization(self, speaker_diarization):
+        self.speaker_diarization = speaker_diarization
+
+    def set_filter_dirty(self, filter_dirty):
+        self.filter_dirty = filter_dirty
+
+    def set_filter_modal(self, filter_modal):
+        self.filter_modal = filter_modal
+
+    def set_filter_punc(self, filter_punc):
+        self.filter_punc = filter_punc
+
+    def set_convert_num_mode(self, convert_num_mode):
+        self.convert_num_mode = convert_num_mode
+
+    def set_word_info(self, word_info):
+        self.word_info = word_info
+
+    def set_hotword_id(self, hotword_id):
+        self.hotword_id = hotword_id
+
+    def set_hotword_list(self, hotword_list):
+        self.hotword_list = hotword_list
+
+    def set_input_sample_rate(self, input_sample_rate):
+        self.input_sample_rate = input_sample_rate
+
+    def set_customization_id(self, customization_id):
+        self.customization_id = customization_id
+
+    def set_voice_format(self, voice_format):
+        self.voice_format = voice_format
+
+    def set_sentence_max_length(self, sentence_max_length):
+        self.sentence_max_length = sentence_max_length
+
+    def set_reinforce_hotword(self, reinforce_hotword):
+        self.reinforce_hotword = reinforce_hotword
+
+
+class FlashRecognizer:
+    '''
+    response fields:
+    field name      type
+    request_id      String
+    status          Integer
+    message         String
+    audio_duration  Integer
+    flash_result    Result Array
+
+    Result structure:
+    text            String
+    channel_id      Integer
+    sentence_list   Sentence Array
+
+    Sentence structure:
+    text            String
+    start_time      Integer
+    end_time        Integer
+    speaker_id      Integer
+    word_list       Word Array
+
+    Word structure:
+    word            String
+    start_time      Integer
+    end_time        Integer
+    stable_flag     Integer
+    '''
+
+    def __init__(self, appid, credential):
+        self.credential = credential
+        self.appid = appid
+
+    def _format_sign_string(self, param):
+        signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/"
+        for t in param:
+            if 'appid' in t:
+                signstr += str(t[1])
+                break
+        signstr += "?"
+        for x in param:
+            tmp = x
+            if 'appid' in x:
+                continue
+            for t in tmp:
+                signstr += str(t)
+                signstr += "="
+            signstr = signstr[:-1]
+            signstr += "&"
+        signstr = signstr[:-1]
+        return signstr
+
+    def _build_header(self):
+        header = dict()
+        header["Host"] = "asr.cloud.tencent.com"
+        return header
+
+    def _sign(self, signstr, secret_key):
+        hmacstr = hmac.new(secret_key.encode('utf-8'),
+                           signstr.encode('utf-8'), hashlib.sha1).digest()
+        s = base64.b64encode(hmacstr)
+        s = s.decode('utf-8')
+        return s
+
+    def _build_req_with_signature(self, secret_key, params, header):
+        query = sorted(params.items(), key=lambda d: d[0])
+        signstr = self._format_sign_string(query)
+        signature = self._sign(signstr, secret_key)
+        header["Authorization"] = signature
+        requrl = "https://"
+        # Drop the leading "POST" from the sign string to form the URL.
+        requrl += signstr[4::]
+        return requrl
+
+    def _create_query_arr(self, req):
+        query_arr = dict()
+        query_arr['appid'] = self.appid
+        query_arr['secretid'] = self.credential.secret_id
+        query_arr['timestamp'] = str(int(time.time()))
+        query_arr['engine_type'] = req.engine_type
+        query_arr['voice_format'] = req.voice_format
+        query_arr['speaker_diarization'] = req.speaker_diarization
+        if req.hotword_id != "":
+            query_arr['hotword_id'] = req.hotword_id
+        if req.hotword_list != "":
+            query_arr['hotword_list'] = req.hotword_list
+        if req.input_sample_rate != 0:
+            query_arr['input_sample_rate'] = req.input_sample_rate
+        query_arr['customization_id'] = req.customization_id
+        query_arr['filter_dirty'] = req.filter_dirty
+        query_arr['filter_modal'] = req.filter_modal
+        query_arr['filter_punc'] = req.filter_punc
+        query_arr['convert_num_mode'] = req.convert_num_mode
+        query_arr['word_info'] = req.word_info
+        query_arr['first_channel_only'] = req.first_channel_only
+        query_arr['reinforce_hotword'] = req.reinforce_hotword
+        query_arr['sentence_max_length'] = req.sentence_max_length
+        return query_arr
+
+    def recognize(self, req, data):
+        header = self._build_header()
+        query_arr = self._create_query_arr(req)
+        req_url = self._build_req_with_signature(self.credential.secret_key, query_arr, header)
+        r = requests.post(req_url, headers=header, data=data)
+        return r.text
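+
+# Usage sketch (hedged; engine and format values come from Tencent's ASR
+# docs, and `data` is the raw bytes of a recorded audio file):
+#
+#   from pipecat.services.helpers.tencent.common import credential
+#
+#   cred = credential.Credential(secret_id, secret_key)
+#   req = FlashRecognitionRequest("16k_zh")
+#   req.set_voice_format("wav")
+#   recognizer = FlashRecognizer(appid, cred)
+#   result_json = recognizer.recognize(req, data)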
diff --git a/src/pipecat/services/helpers/tencent/asr/speech_recognizer.py b/src/pipecat/services/helpers/tencent/asr/speech_recognizer.py
new file mode 100644
index 000000000..c87ccc7a8
--- /dev/null
+++ b/src/pipecat/services/helpers/tencent/asr/speech_recognizer.py
@@ -0,0 +1,322 @@
+# -*- coding: utf-8 -*-
+import asyncio
+import base64
+import hashlib
+import hmac
+import json
+import time
+import urllib.parse
+import uuid
+
+import websockets
+from loguru import logger
+
+
+# Used for Tencent's real-time (streaming) speech recognition API.
+class SpeechRecognitionListener():
+    '''
+    response fields:
+    on_recognition_start responses carry only voice_id.
+    on_fail responses carry only voice_id, code and message.
+    on_recognition_complete responses have no result field.
+    All other messages carry every field.
+    field name    type
+    code          Integer
+    message       String
+    voice_id      String
+    message_id    String
+    result        Result
+    final         Integer
+
+    Result structure:
+    slice_type      Integer
+    index           Integer
+    start_time      Integer
+    end_time        Integer
+    voice_text_str  String
+    word_size       Integer
+    word_list       Word Array
+
+    Word structure:
+    word          String
+    start_time    Integer
+    end_time      Integer
+    stable_flag   Integer
+    '''
+
+    async def on_recognition_start(self, response):
+        pass
+
+    async def on_sentence_begin(self, response):
+        pass
+
+    async def on_recognition_result_change(self, response):
+        pass
+
+    async def on_sentence_end(self, response):
+        pass
+
+    async def on_recognition_complete(self, response):
+        pass
+
+    async def on_fail(self, response):
+        pass
+
+
+# Recognizer connection states.
+NOTOPEN = 0
+STARTED = 1
+OPENED = 2
+FINAL = 3
+ERROR = 4
+CLOSED = 5
+
+
+class SpeechRecognizer:
+
+    def __init__(self, appid, credential, engine_model_type, listener):
+        self.result = ""
+        self.credential = credential
+        self.appid = appid
+        self.engine_model_type = engine_model_type
+        self.status = NOTOPEN
+        self.ws = None
+        self.voice_id = ""
+        self.new_start = 0
+        self.listener = listener
+        self.filter_dirty = 0
+        self.filter_modal = 0
+        self.filter_punc = 0
+        self.convert_num_mode = 0
+        self.word_info = 0
+        self.need_vad = 0
+        self.vad_silence_time = 0
+        self.hotword_id = ""
+        self.hotword_list = ""
+        self.reinforce_hotword = 0
+        self.noise_threshold = 0
+        self.voice_format = 4
+        self.nonce = ""
+
+    def set_filter_dirty(self, filter_dirty):
+        self.filter_dirty = filter_dirty
+
+    def set_filter_modal(self, filter_modal):
+        self.filter_modal = filter_modal
+
+    def set_filter_punc(self, filter_punc):
+        self.filter_punc = filter_punc
+
+    def set_convert_num_mode(self, convert_num_mode):
+        self.convert_num_mode = convert_num_mode
+
+    def set_word_info(self, word_info):
+        self.word_info = word_info
+
+    def set_need_vad(self, need_vad):
+        self.need_vad = need_vad
+
+    def set_vad_silence_time(self, vad_silence_time):
+        self.vad_silence_time = vad_silence_time
+
+    def set_hotword_id(self, hotword_id):
+        self.hotword_id = hotword_id
+
+    def set_hotword_list(self, hotword_list):
+        self.hotword_list = hotword_list
+
+    def set_voice_format(self, voice_format):
+        self.voice_format = voice_format
+
+    def set_nonce(self, nonce):
+        self.nonce = nonce
+
+    def set_reinforce_hotword(self, reinforce_hotword):
+        self.reinforce_hotword = reinforce_hotword
+
+    def set_noise_threshold(self, noise_threshold):
+        self.noise_threshold = noise_threshold
+
+    def format_sign_string(self, param):
+        signstr = "asr.cloud.tencent.com/asr/v2/"
+        for t in param:
+            if 'appid' in t:
+                signstr += str(t[1])
+                break
+        signstr += "?"
+        for x in param:
+            tmp = x
+            if 'appid' in x:
+                continue
+            for t in tmp:
+                signstr += str(t)
+                signstr += "="
+            signstr = signstr[:-1]
+            signstr += "&"
+        signstr = signstr[:-1]
+        return signstr
+
+    def create_query_string(self, param):
+        signstr = "wss://asr.cloud.tencent.com/asr/v2/"
+        for t in param:
+            if 'appid' in t:
+                signstr += str(t[1])
+                break
+        signstr += "?"
+        for x in param:
+            tmp = x
+            if 'appid' in x:
+                continue
+            for t in tmp:
+                signstr += str(t)
+                signstr += "="
+            signstr = signstr[:-1]
+            signstr += "&"
+        signstr = signstr[:-1]
+        return signstr
+
+    def sign(self, signstr, secret_key):
+        hmacstr = hmac.new(secret_key.encode('utf-8'),
+                           signstr.encode('utf-8'), hashlib.sha1).digest()
+        s = base64.b64encode(hmacstr)
+        s = s.decode('utf-8')
+        return s
+
+    def create_query_arr(self):
+        query_arr = dict()
+
+        query_arr['appid'] = self.appid
+        query_arr['sub_service_type'] = 1
+        query_arr['engine_model_type'] = self.engine_model_type
+        query_arr['filter_dirty'] = self.filter_dirty
+        query_arr['filter_modal'] = self.filter_modal
+        query_arr['filter_punc'] = self.filter_punc
+        query_arr['needvad'] = self.need_vad
+        query_arr['convert_num_mode'] = self.convert_num_mode
+        query_arr['word_info'] = self.word_info
+        if self.vad_silence_time != 0:
+            query_arr['vad_silence_time'] = self.vad_silence_time
+        if self.hotword_id != "":
+            query_arr['hotword_id'] = self.hotword_id
+        if self.hotword_list != "":
+            query_arr['hotword_list'] = self.hotword_list
+
+        query_arr['secretid'] = self.credential.secret_id
+        query_arr['voice_format'] = self.voice_format
+        query_arr['voice_id'] = self.voice_id
+        query_arr['timestamp'] = str(int(time.time()))
+        if self.nonce != "":
+            query_arr['nonce'] = self.nonce
+        else:
+            query_arr['nonce'] = query_arr['timestamp']
+        query_arr['expired'] = int(time.time()) + 24 * 60 * 60
+        query_arr['reinforce_hotword'] = self.reinforce_hotword
+        query_arr['noise_threshold'] = self.noise_threshold
+        return query_arr
+
+    async def stop(self):
+        if self.ws:
+            if self.status == OPENED:
+                msg = {"type": "end"}
+                await self.ws.send(json.dumps(msg))
+            await self.ws.close()
+            self.status = CLOSED
+
+    async def write(self, data):
+        if not self.ws or not self.ws.open:
+            logger.error("WebSocket is closed, unable to send data")
+            return
+
+        try:
+            await self.ws.send(data)
+        except websockets.exceptions.ConnectionClosedOK as e:
+            logger.error(f"WebSocket connection closed: {e}")
+            # Reconnect (or bail out) here if needed.
+        except Exception as e:
+            logger.error(f"Failed to send data over WebSocket: {e}")
+
+    async def connect_websocket(self, requrl):
+        try:
+            ws = await websockets.connect(requrl)
+            self.ws = ws
+            self.status = OPENED
+            response = {'voice_id': self.voice_id}
+            await self.listener.on_recognition_start(response)
+            logger.info(f"{self.voice_id} recognition start")
+        except Exception as e:
+            logger.error(f"WebSocket connection failed: {e}")
+            raise
+
+    async def start(self):
+        query_arr = self.create_query_arr()
+        if self.voice_id == "":
+            query_arr['voice_id'] = str(uuid.uuid1())
+            self.voice_id = query_arr['voice_id']
+        query = sorted(query_arr.items(), key=lambda d: d[0])
+        signstr = self.format_sign_string(query)
+
+        autho = self.sign(signstr, self.credential.secret_key)
+        requrl = self.create_query_string(query)
+        requrl += "&signature=%s" % urllib.parse.quote(autho)
+
+        try:
+            # Open the WebSocket connection...
+            await self.connect_websocket(requrl)
+            # ...and start the message-receiving task.
+            asyncio.create_task(self.listen_for_messages(self.ws))
+        except Exception as e:
+            logger.error(f"WebSocket connection failed: {e}")
+
+    async def listen_for_messages(self, ws):
+        # Message-handling loop.
+        try:
+            if ws.open:
+                async for message in ws:
+                    await self.on_message(message)
+        except Exception as e:
+            logger.error(f"WebSocket connection failed: {e}")
+
+    async def on_message(self, message):
+        response = json.loads(message)
+        response['voice_id'] = self.voice_id
+        if response['code'] != 0:
+            logger.error(f"{self.voice_id} server recognition fail {response['message']}")
+            await self.listener.on_fail(response)
+            return
+
+        if "final" in response and response["final"] == 1:
+            self.status = FINAL
+            self.result = message
+            await self.listener.on_recognition_complete(response)
+            logger.info(f"{self.voice_id} recognition complete")
+            return
+
+        if "result" in response.keys():
+            if response["result"]['slice_type'] == 0:
+                await self.listener.on_sentence_begin(response)
+            elif response["result"]["slice_type"] == 1:
+                await self.listener.on_recognition_result_change(response)
+            elif response["result"]["slice_type"] == 2:
+                await self.listener.on_sentence_end(response)
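+
+# Protocol summary (see the listener docstring above): the server pushes
+# JSON messages where result.slice_type tracks sentence progress --
+# 0 = sentence begin, 1 = partial result changed, 2 = sentence end -- and
+# final == 1 means the whole recognition is complete. A listener overrides
+# only the callbacks it needs, e.g.:
+#
+#   class MyListener(SpeechRecognitionListener):
+#       async def on_sentence_end(self, response):
+#           print(response["result"]["voice_text_str"])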
diff --git a/src/pipecat/services/helpers/tencent/common/__init__.py b/src/pipecat/services/helpers/tencent/common/__init__.py
new file mode 100644
index 000000000..f783418f0
--- /dev/null
+++ b/src/pipecat/services/helpers/tencent/common/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
diff --git a/src/pipecat/services/helpers/tencent/common/credential.py b/src/pipecat/services/helpers/tencent/common/credential.py
new file mode 100644
index 000000000..64d0aa96f
--- /dev/null
+++ b/src/pipecat/services/helpers/tencent/common/credential.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+class Credential:
+    def __init__(self, secret_id, secret_key, token=""):
+        self.secret_id = secret_id
+        self.secret_key = secret_key
+        self.token = token
diff --git a/src/pipecat/services/helpers/tencent/common/utils.py b/src/pipecat/services/helpers/tencent/common/utils.py
new file mode 100644
index 000000000..cf7d1f227
--- /dev/null
+++ b/src/pipecat/services/helpers/tencent/common/utils.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+import sys
+
+def is_python3():
+    if sys.version > '3':
+        return True
+    return False
\ No newline at end of file
diff --git a/src/pipecat/services/tencent.py b/src/pipecat/services/tencent.py
new file mode 100644
index 000000000..34206910b
--- /dev/null
+++ b/src/pipecat/services/tencent.py
@@ -0,0 +1,117 @@
+import json
+import os
+from typing import AsyncGenerator
+
+from dotenv import load_dotenv
+from loguru import logger
+
+from pipecat.frames.frames import (
+    CancelFrame,
+    EndFrame,
+    Frame,
+    InterimTranscriptionFrame,
+    TranscriptionFrame
+)
+from pipecat.services.ai_services import STTService
+from pipecat.services.helpers.tencent.asr import speech_recognizer
+from pipecat.services.helpers.tencent.common import credential
+from pipecat.utils.time import time_now_iso8601
+
+load_dotenv()
+
+APPID = os.getenv("TENCENT_APPID")
+SECRET_ID = os.getenv("TENCENT_SECRET_ID")
+SECRET_KEY = os.getenv("TENCENT_SECRET_KEY")
+ENGINE_MODEL_TYPE = os.getenv("TENCENT_ENGINE_MODEL_TYPE", "16k_zh")
+
+
+class TencentSTTService(STTService):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._recognizer = None
+        self._listener = None
+        self._credential = credential.Credential(SECRET_ID, SECRET_KEY)
+
+    async def set_model(self, model: str):
+        await super().set_model(model)
+        logger.debug(f"Switching STT model to: [{model}]")
+        global ENGINE_MODEL_TYPE
+        ENGINE_MODEL_TYPE = model
+        await self._reconnect()
+
+    async def start(self, frame: Frame):
+        await super().start(frame)
+        await self._initialize_recognizer()
+
+    async def stop(self, frame: EndFrame):
+        await super().stop(frame)
+        if self._recognizer:
+            await self._recognizer.stop()
+
+    async def cancel(self, frame: CancelFrame):
+        await super().cancel(frame)
+        if self._recognizer:
+            await self._recognizer.stop()
+
+    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
+        if self._recognizer.ws and self._recognizer.ws.open:
+            await self._recognizer.write(audio)
+        else:
+            await self._reconnect()
+
+        yield None
+
+    async def _initialize_recognizer(self):
+        # logger.debug("Initializing Tencent STT recognizer")
+        self._listener = MySpeechRecognitionListener(self)
+        self._recognizer = speech_recognizer.SpeechRecognizer(
+            APPID, self._credential, ENGINE_MODEL_TYPE, self._listener
+        )
+        self._recognizer.set_filter_modal(1)
+        self._recognizer.set_filter_punc(1)
+        self._recognizer.set_filter_dirty(1)
+        self._recognizer.set_need_vad(1)
+        self._recognizer.set_voice_format(1)
+        self._recognizer.set_word_info(1)
+        self._recognizer.set_convert_num_mode(1)
+        await self._recognizer.start()
+
+    async def _on_message(self, response):
+        result = response['result']
+        if result['slice_type'] == 1:
+            transcript = result['voice_text_str']
+            await self.push_frame(InterimTranscriptionFrame(transcript, "", time_now_iso8601()))
+        elif result['slice_type'] == 2:
+            transcript = result['voice_text_str']
+            logger.debug(f"Tencent STT: {transcript}")
+            await self.push_frame(TranscriptionFrame(transcript, "", time_now_iso8601()))
+
+    async def _reconnect(self):
+        if self._recognizer:
+            await self._recognizer.stop()
+        await self._initialize_recognizer()
+
+
+class MySpeechRecognitionListener(speech_recognizer.SpeechRecognitionListener):
+    def __init__(self, service):
+        self.service = service
+        self.partial_result = ""
+
+    async def on_recognition_start(self, response):
+        pass
+
+    async def on_sentence_begin(self, response):
+        pass
+
+    async def on_recognition_result_change(self, response):
+        await self.service._on_message(response)
+
+    async def on_sentence_end(self, response):
+        await self.service._on_message(response)
+
+    async def on_fail(self, response):
+        rsp_str = json.dumps(response, ensure_ascii=False)
+        logger.error(f"Recognition failed: {rsp_str}")
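+
+# Configuration sketch: the credentials above are read from the environment,
+# e.g. in .env (names match the os.getenv() calls at the top of this file):
+#
+#   TENCENT_APPID=<your appid>
+#   TENCENT_SECRET_ID=<your secret id>
+#   TENCENT_SECRET_KEY=<your secret key>
+#   TENCENT_ENGINE_MODEL_TYPE=16k_zh
+#
+# The service consumes raw transport audio rather than Daily transcription,
+# so enable vad_audio_passthrough=True on the transport, as in
+# examples/foundational/07l-interruptible-Chinese.py.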