Commit 4de130a

Merge pull request #665 from Undertone0809/hizeros/streaming
fix: error in streaming
Undertone0809 authored May 17, 2024
2 parents 176a3fa + fbbf113 commit 4de130a
Showing 5 changed files with 50 additions and 31 deletions.
13 changes: 9 additions & 4 deletions promptulate/chat.py
@@ -139,7 +139,6 @@ def run(
             "stream, tools and output_schema can't be True at the same time, "
             "because stream is used to return Iterator[BaseMessage]."
         )
-
         if self.agent:
             return self.agent.run(messages, output_schema=output_schema)

@@ -151,12 +150,18 @@
                 json_schema=output_schema, examples=examples
             )
             messages.messages[-1].content += f"\n{instruction}"

         logger.info(f"[pne chat] messages: {messages}")

-        response: AssistantMessage = self.llm.predict(messages, stream=stream, **kwargs)
+        response: Union[AssistantMessage, StreamIterator] = self.llm.predict(
+            messages, stream=stream, **kwargs
+        )

         if stream:
             return response

-        logger.info(f"[pne chat] response: {response.additional_kwargs}")
+        if isinstance(response, AssistantMessage):
+            # Access additional_kwargs only if response is AssistantMessage
+            logger.info(f"[pne chat] response: {response.additional_kwargs}")

         # return output format if provide
         if output_schema:
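With this change, chat.py hands the stream iterator back untouched when stream=True and reads additional_kwargs only from a fully materialized AssistantMessage. A minimal sketch of the two call shapes, assuming an OpenAI-style model name and message-shaped chunks as in the updated tests below (both illustrative, not guaranteed by this diff):

    import promptulate as pne

    # Non-streaming: one complete answer comes back.
    answer = pne.chat("what's weather tomorrow in shanghai?", model="gpt-3.5-turbo")
    print(answer)

    # Streaming: an iterator of chunks comes back instead, so the
    # caller drains it incrementally.
    for chunk in pne.chat(
        "what's weather tomorrow in shanghai?",
        model="gpt-3.5-turbo",
        stream=True,
    ):
        print(chunk.content, end="", flush=True)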
7 changes: 5 additions & 2 deletions promptulate/llms/_litellm.py
@@ -26,7 +26,7 @@ def parse_content(chunk) -> (str, str):
         ret_data: The additional data of the chunk.
     """
     content = chunk.choices[0].delta.content
-    ret_data = json.loads(chunk.json())
+    ret_data = json.loads(json.dumps(chunk.json()))
     return content, ret_data

@@ -42,7 +42,10 @@ def _predict(
     ) -> Union[AssistantMessage, StreamIterator]:
         logger.info(f"[pne chat] prompts: {messages.string_messages}")
         temp_response = litellm.completion(
-            model=self._model, messages=messages.listdict_messages, **self._model_config
+            model=self._model,
+            messages=messages.listdict_messages,
+            **self._model_config,
+            stream=stream,
         )

         if stream:
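The key fix here is that stream=stream is now forwarded to litellm.completion; previously the flag was accepted but never reached litellm, so the backend always produced one complete response. Downstream, each streamed chunk is an OpenAI-style delta, which is exactly what parse_content above unpacks. A small sketch against litellm directly (model name illustrative):

    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi there"}],
        stream=True,  # omit this and litellm returns one complete response
    )

    # Mirrors parse_content: each chunk carries an incremental delta.
    for chunk in response:
        content = chunk.choices[0].delta.content
        if content is not None:
            print(content, end="", flush=True)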
1 change: 1 addition & 0 deletions promptulate/llms/base.py
@@ -55,6 +55,7 @@ def predict(self, messages: MessageSet, *args, **kwargs) -> AssistantMessage:
         result = self._predict(messages, *args, **kwargs)
         if isinstance(result, AssistantMessage):
             Hook.call_hook(HookTable.ON_LLM_RESULT, self, result=result.content)
+
         return result

     @abstractmethod
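For context, the isinstance guard above (already in place before this commit) is what keeps the hook path safe now that _predict may return a stream: a generator has no .content to hand to ON_LLM_RESULT. A minimal illustration of the failure the guard avoids (hypothetical snippet, not repo code):

    def fake_stream():
        yield "chunk"

    result = fake_stream()
    # Raises AttributeError: 'generator' object has no attribute 'content'
    result.content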
17 changes: 1 addition & 16 deletions requirements.txt
@@ -45,19 +45,4 @@
 urllib3==2.2.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
 wcwidth==0.2.13 ; python_full_version >= "3.8.1" and python_version < "4.0"
 yarl==1.9.4 ; python_full_version >= "3.8.1" and python_version < "4.0"
-zipp==3.18.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
-
-promptulate~=1.13.1
-pytest~=7.4.4
-langchain~=0.1.4
-streamlit~=1.33.0
-typing_extensions~=4.9.0
-pyjwt~=2.8.0
-requests~=2.31.0
-numexpr~=2.8.6
-arxiv~=1.4.8
-pydantic~=1.10.14
-click~=8.1.7
-questionary~=2.0.1
-litellm~=1.23.14
-python-dotenv~=1.0.1
+zipp==3.18.1 ; python_full_version >= "3.8.1" and python_version < "4.0"
43 changes: 34 additions & 9 deletions tests/test_chat.py
@@ -1,7 +1,7 @@
 """
-TODO add test: test_stream, test pne's llm, test litellm llm
+TODO add test: test pne's llm, test litellm llm
 """
-from typing import Optional
+from typing import Generator, Optional, Union

 import pytest
@@ -17,6 +17,21 @@
 )


+class StreamLLM(BaseLLM):
+    def _predict(
+        self, messages: MessageSet, *args, **kwargs
+    ) -> Optional[type(BaseMessage)]:
+        messages: list = [
+            AssistantMessage(content="This", additional_kwargs={}),
+            AssistantMessage(content="is", additional_kwargs={}),
+            AssistantMessage(content="fake", additional_kwargs={}),
+            AssistantMessage(content="message", additional_kwargs={}),
+        ]
+
+        for message in messages:
+            yield message
+
+
 class FakeLLM(BaseLLM):
     llm_type: str = "fake"

@@ -26,13 +41,13 @@ def __init__(self, *args, **kwargs):
     def __call__(self, instruction: str, *args, **kwargs):
         return "fake response"

-    def _predict(self, messages: MessageSet, *args, **kwargs) -> BaseMessage:
+    def _predict(
+        self, messages: MessageSet, *args, **kwargs
+    ) -> Union[BaseMessage, Generator]:
         content = "fake response"

         if "Output format" in messages.messages[-1].content:
             content = """{"city": "Shanghai", "temperature": 25}"""

-        return AssistantMessage(content=content)
+        return AssistantMessage(content=content, additional_kwargs={})


 def mock_tool():
@@ -148,12 +163,22 @@ class LLMResponse(BaseModel):


 def test_streaming():
-    llm = FakeLLM()
+    llm = StreamLLM()

     # test stream output
-    answer = chat(
+    answer_stream = pne.chat(
         "what's weather tomorrow in shanghai?",
         custom_llm=llm,
         stream=True,
     )
-    assert answer == "fake response"
+
+    # check if the answer is a stream of response
+    answer: list = []
+    for item in answer_stream:
+        answer.append(item)
+
+    assert len(answer) == 4
+    assert answer[0].content == "This"
+    assert answer[1].content == "is"
+    assert answer[2].content == "fake"
+    assert answer[3].content == "message"

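The rewritten test drives a hand-rolled streaming LLM end to end and checks each chunk individually. The same pattern reassembles the full text from the stream; a short follow-on sketch, runnable in the test module's scope (the space-join fits this particular four-word fake stream):

    stream = pne.chat(
        "what's weather tomorrow in shanghai?",
        custom_llm=StreamLLM(),
        stream=True,
    )
    # Drain the iterator and stitch the chunks back together.
    full_text = " ".join(chunk.content for chunk in stream)
    assert full_text == "This is fake message"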