diff --git a/src/quartapp/chat.py b/src/quartapp/chat.py index 0d39eec..63768d2 100644 --- a/src/quartapp/chat.py +++ b/src/quartapp/chat.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. # Licensed under the MIT license. See LICENSE.md file in the project root for full license information. -from typing import Any +from typing import Any, AsyncGenerator from quart import Blueprint, jsonify, request, Response, render_template, current_app import asyncio @@ -18,7 +18,7 @@ FileSearchTool, AsyncToolSet, FilePurpose, - AgentStreamEvent + AsyncAgentEventHandler ) bp = Blueprint("chat", __name__, template_folder="templates", static_folder="static") @@ -82,36 +82,83 @@ async def stop_server(): async def index(): return await render_template("index.html") -async def create_stream(thread_id: str, agent_id: str): +class MyEventHandler(AsyncAgentEventHandler): + + def __init__(self, queue: asyncio.Queue): + super().__init__() + self.queue = queue + self.accumulated_text = "" + + async def on_message_delta(self, delta: "MessageDeltaChunk") -> None: + for content_part in delta.delta.content: + if isinstance(content_part, MessageDeltaTextContent): + text_value = content_part.text.value if content_part.text else "No text" + self.accumulated_text += text_value + stream_data = json.dumps({'content': text_value, 'type': "message"}) + print(f"Stream data: {stream_data}") + await self.queue.put(("message", text_value)) + + async def on_thread_message(self, message: "ThreadMessage") -> None: + + if (message.status == "completed"): + stream_data = json.dumps({'content': self.accumulated_text, 'type': "completed_message"}) + print(f"Stream data: {stream_data}") + await self.queue.put(("completed_message", self.accumulated_text)) + + async def on_thread_run(self, run: "ThreadRun") -> None: + print(f"ThreadRun status: {run.status}") + + async def on_run_step(self, step: "RunStep") -> None: + print(f"RunStep type: {step.type}, Status: {step.status}") + + async def on_error(self, data: str) -> None: + print(f"An error occurred. Data: {data}") + stream_data = json.dumps({'type': "stream_end"}) + print(f"Stream data: {stream_data}") + + async def on_done(self) -> None: + print("Stream completed.") + await self.queue.put(("stream_end", "")) + + async def on_unhandled_event(self, event_type: str, event_data: Any) -> None: + print(f"Unhandled Event Type: {event_type}, Data: {event_data}") + +async def create_stream(queue: asyncio.Queue, thread_id: str, agent_id: str): async with await bp.ai_client.agents.create_stream( - thread_id=thread_id, assistant_id=agent_id + thread_id=thread_id, assistant_id=agent_id, + event_handler=MyEventHandler(queue) ) as stream: - accumulated_text = "" - - async for event_type, event_data in stream: - - stream_data = None - if isinstance(event_data, MessageDeltaChunk): - for content_part in event_data.delta.content: - if isinstance(content_part, MessageDeltaTextContent): - text_value = content_part.text.value if content_part.text else "No text" - accumulated_text += text_value - print(f"Text delta received: {text_value}") - stream_data = json.dumps({'content': text_value, 'type': "message"}) - - elif isinstance(event_data, ThreadMessage): - print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}") - if (event_data.status == "completed"): - stream_data = json.dumps({'content': accumulated_text, 'type': "completed_message"}) - - elif event_type == AgentStreamEvent.DONE: - print("Stream completed.") - stream_data = json.dumps({'type': "stream_end"}) - - if stream_data: - yield f"data: {stream_data}\n\n" - + await stream.until_done() + +async def get_result(thread_id: str, agent_id: str): + + queue = asyncio.Queue() + task = asyncio.create_task(create_stream(queue, thread_id, agent_id)) + + while True: + try: + message_type, message = await queue.get() + if message_type == "message": + event_data = json.dumps({'content': message, 'type': message_type}) + yield f"data: {event_data}\n\n" + elif message_type == "completed_message": + event_data = json.dumps({'content': message, 'type': message_type}) + yield f"data: {event_data}\n\n" + elif message_type == "stream_end": + event_data = json.dumps({'content': message, 'type': message_type}) + yield f"data: {event_data}\n\n" + await queue.task_done() + return + elif message_type == "function": + function_message = f"Function {message} called" + event_data = json.dumps({'content': function_message}) + yield f"data: {event_data}\n\n" + except StopIteration: + break + + await task + @bp.route('/chat', methods=['POST']) async def chat(): thread_id = request.cookies.get('thread_id') @@ -147,7 +194,7 @@ async def chat(): 'Content-Type': 'text/event-stream' } - response = Response(create_stream(thread_id, agent_id), headers=headers) + response = Response(get_result(thread_id, agent_id), headers=headers) response.set_cookie('thread_id', thread_id) response.set_cookie('agent_id', agent_id) return response