Skip to content

Commit

Permalink
Fix streaming for sequential multi-function calling. Added streaming …
Browse files Browse the repository at this point in the history
…test.
  • Loading branch information
Alex-Karmazin committed Jul 25, 2024
1 parent 04186ee commit 00b2b68
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 30 deletions.
55 changes: 26 additions & 29 deletions just_agents/streaming/openai_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,39 +13,36 @@ async def resp_async_generator(self, memory: Memory,
options: dict,
available_tools: dict[str, Callable]
) -> AsyncGenerator[str, None]:
response: ModelResponse = rotate_completion(messages=memory.messages, stream=True, options=options)
parser: Optional[FunctionParser] = None
tool_messages: list[dict] = []
parsers: list[FunctionParser] = []
deltas: list[str] = []
for i, part in enumerate(response):
delta: str = part["choices"][0]["delta"].get("content") # type: ignore
if delta:
deltas.append(delta)
yield f"data: {self._get_chunk(i, delta, options)}\n\n"

tool_calls = part["choices"][0]["delta"].get("tool_calls")
if tool_calls and (available_tools is not None):
if not parser:
parser = FunctionParser(id = tool_calls[0].id)
if parser.parsed(tool_calls[0].function.name, tool_calls[0].function.arguments):
tool_messages.append(self._process_function(parser, available_tools))
parsers.append(parser)
parser = None #maybe Optional?

if len(tool_messages) > 0:
memory.add_message(self._get_tool_call_message(parsers))
for message in tool_messages:
memory.add_message(message)
response = rotate_completion(messages=memory.messages, stream=True, options=options)
deltas = []
proceed = True
while proceed:
proceed = False
response: ModelResponse = rotate_completion(messages=memory.messages, stream=True, options=options)
parser: Optional[FunctionParser] = None
tool_messages: list[dict] = []
parsers: list[FunctionParser] = []
deltas: list[str] = []
for i, part in enumerate(response):
delta: str = part["choices"][0]["delta"].get("content") # type: ignore
if delta:
deltas.append(delta)
yield f"data: {self._get_chunk(i, delta, options)}\n\n"
memory.add_message({"role":"assistant", "content":"".join(deltas)})
elif len(deltas) > 0:
memory.add_message({"role":"assistant", "content":"".join(deltas)})

tool_calls = part["choices"][0]["delta"].get("tool_calls")
if tool_calls and (available_tools is not None):
if not parser:
parser = FunctionParser(id = tool_calls[0].id)
if parser.parsed(tool_calls[0].function.name, tool_calls[0].function.arguments):
tool_messages.append(self._process_function(parser, available_tools))
parsers.append(parser)
parser = None #maybe Optional?

if len(tool_messages) > 0:
proceed = True
memory.add_message(self._get_tool_call_message(parsers))
for message in tool_messages:
memory.add_message(message)

if len(deltas) > 0:
memory.add_message({"role":"assistant", "content":"".join(deltas)})

yield "data: [DONE]\n\n"
2 changes: 1 addition & 1 deletion just_agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def len(self):
return len(self.keys)


def rotate_completion(messages: list[Message | dict], options: dict[str, str], stream: bool, remove_key_on_error: bool = True, max_tries: int = 2) -> ModelResponse:
def rotate_completion(messages: list[dict], options: dict[str, str], stream: bool, remove_key_on_error: bool = True, max_tries: int = 2) -> ModelResponse:
opt = options.copy()
key_getter: RotateKeys = opt.pop("key_getter", None)
backup_opt: dict = opt.pop("backup_options", None)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import just_agents.llm_options
from just_agents.llm_session import LLMSession
import asyncio


def get_current_weather(location: str):
Expand All @@ -31,6 +32,24 @@ def test_sync_llama_function_calling():
assert "22" in result
assert "10" in result

async def process_stream(async_generator):
    """Exhaust *async_generator*, discarding every yielded item.

    Drives a streaming response to completion when only the side
    effects of streaming matter, not the streamed chunks themselves.
    """
    while True:
        try:
            # async-for is equivalent to repeated __anext__ calls that
            # stop on StopAsyncIteration; spelled out explicitly here.
            await async_generator.__anext__()
        except StopAsyncIteration:
            break

def test_stream_llama_function_calling():
    """End-to-end check of streamed multi-function calling.

    Streams a three-city weather query through an LLMSession wired to
    the get_current_weather tool, drains the stream, then asserts the
    final assistant message contains all three tool-provided
    temperatures (72, 22, 10 — presumably one per city; values come
    from the get_current_weather stub defined earlier in this file).
    """
    load_dotenv()
    session: LLMSession = LLMSession(
        llm_options=just_agents.llm_options.LLAMA3_1,
        tools=[get_current_weather]
    )
    stream = session.stream("What's the weather like in San Francisco, Tokyo, and Paris?")
    # asyncio.run() creates and tears down a fresh event loop per call.
    # The previous get_event_loop()/run_until_complete() pair is
    # deprecated outside a running loop since Python 3.10 and raises
    # under some configurations.
    asyncio.run(process_stream(stream))
    result = session.memory.last_message["content"]
    assert "72" in result
    assert "22" in result
    assert "10" in result

@pytest.mark.skip(reason="so far qwen inference we are using has issues with json function calling")
def test_async_gwen2_function_calling():
load_dotenv()
Expand Down

0 comments on commit 00b2b68

Please sign in to comment.