Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use event handler instead #12

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 77 additions & 30 deletions src/quartapp/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

from typing import Any
from typing import Any, AsyncGenerator
from quart import Blueprint, jsonify, request, Response, render_template, current_app

import asyncio
Expand All @@ -18,7 +18,7 @@
FileSearchTool,
AsyncToolSet,
FilePurpose,
AgentStreamEvent
AsyncAgentEventHandler
)

bp = Blueprint("chat", __name__, template_folder="templates", static_folder="static")
Expand Down Expand Up @@ -82,36 +82,83 @@ async def stop_server():
async def index():
return await render_template("index.html")

async def create_stream(thread_id: str, agent_id: str):
class MyEventHandler(AsyncAgentEventHandler):

def __init__(self, queue: asyncio.Queue):
super().__init__()
self.queue = queue
self.accumulated_text = ""

async def on_message_delta(self, delta: "MessageDeltaChunk") -> None:
for content_part in delta.delta.content:
if isinstance(content_part, MessageDeltaTextContent):
text_value = content_part.text.value if content_part.text else "No text"
self.accumulated_text += text_value
stream_data = json.dumps({'content': text_value, 'type': "message"})
print(f"Stream data: {stream_data}")
await self.queue.put(("message", text_value))

async def on_thread_message(self, message: "ThreadMessage") -> None:

if (message.status == "completed"):
stream_data = json.dumps({'content': self.accumulated_text, 'type': "completed_message"})
print(f"Stream data: {stream_data}")
await self.queue.put(("completed_message", self.accumulated_text))

async def on_thread_run(self, run: "ThreadRun") -> None:
print(f"ThreadRun status: {run.status}")

async def on_run_step(self, step: "RunStep") -> None:
print(f"RunStep type: {step.type}, Status: {step.status}")

async def on_error(self, data: str) -> None:
print(f"An error occurred. Data: {data}")
stream_data = json.dumps({'type': "stream_end"})
print(f"Stream data: {stream_data}")

async def on_done(self) -> None:
print("Stream completed.")
await self.queue.put(("stream_end", ""))

async def on_unhandled_event(self, event_type: str, event_data: Any) -> None:
print(f"Unhandled Event Type: {event_type}, Data: {event_data}")

async def create_stream(queue: asyncio.Queue, thread_id: str, agent_id: str):
async with await bp.ai_client.agents.create_stream(
thread_id=thread_id, assistant_id=agent_id
thread_id=thread_id, assistant_id=agent_id,
event_handler=MyEventHandler(queue)
) as stream:
accumulated_text = ""

async for event_type, event_data in stream:

stream_data = None
if isinstance(event_data, MessageDeltaChunk):
for content_part in event_data.delta.content:
if isinstance(content_part, MessageDeltaTextContent):
text_value = content_part.text.value if content_part.text else "No text"
accumulated_text += text_value
print(f"Text delta received: {text_value}")
stream_data = json.dumps({'content': text_value, 'type': "message"})

elif isinstance(event_data, ThreadMessage):
print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}")
if (event_data.status == "completed"):
stream_data = json.dumps({'content': accumulated_text, 'type': "completed_message"})

elif event_type == AgentStreamEvent.DONE:
print("Stream completed.")
stream_data = json.dumps({'type': "stream_end"})

if stream_data:
yield f"data: {stream_data}\n\n"

await stream.until_done()

async def get_result(thread_id: str, agent_id: str):

queue = asyncio.Queue()

task = asyncio.create_task(create_stream(queue, thread_id, agent_id))

while True:
try:
message_type, message = await queue.get()
if message_type == "message":
event_data = json.dumps({'content': message, 'type': message_type})
yield f"data: {event_data}\n\n"
elif message_type == "completed_message":
event_data = json.dumps({'content': message, 'type': message_type})
yield f"data: {event_data}\n\n"
elif message_type == "stream_end":
event_data = json.dumps({'content': message, 'type': message_type})
yield f"data: {event_data}\n\n"
await queue.task_done()
return
elif message_type == "function":
function_message = f"Function {message} called"
event_data = json.dumps({'content': function_message})
yield f"data: {event_data}\n\n"
except StopIteration:
break

await task

@bp.route('/chat', methods=['POST'])
async def chat():
thread_id = request.cookies.get('thread_id')
Expand Down Expand Up @@ -147,7 +194,7 @@ async def chat():
'Content-Type': 'text/event-stream'
}

response = Response(create_stream(thread_id, agent_id), headers=headers)
response = Response(get_result(thread_id, agent_id), headers=headers)
response.set_cookie('thread_id', thread_id)
response.set_cookie('agent_id', agent_id)
return response
Expand Down